import numpy
import matplotlib.pyplot as p
import channel
reload(channel)
import lab1
reload(lab1)
import lab2
reload(lab2)
import lab2_2
reload(lab2_2)
import lab2_3
reload(lab2_3)

p.ion()

def plot_with_title(data, title):
    """
    Plot data as a line on the current axes and label the axes.

    No new figure is created: the caller decides where the plot goes
    (for example with p.figure() / p.subplot()).

    Arguments:
      data  -- sequence of voltage samples, plotted against sample index.
      title -- title string for the axes.
    """
    p.plot(data)   # line plot, matplotlib's default style
    p.title(title)
    p.xlabel('Sample data')
    p.ylabel('Voltage')
    # Pass True explicitly: a bare grid() call *toggles* the grid, so a
    # second plot drawn on the same axes would turn it off again.
    p.grid(True)


class scale_ir:
    """
    IR channel wrapper that calibrates the received signal.

    Each call prepends a known mark (mark_size zeros, mark_size ones,
    mark_size zeros) to the caller's samples.  The received copy of the
    mark is used to find where the payload begins and to shift/scale the
    received voltages so that the received mark's lowest value maps to 0
    and its highest value maps to 1, independent of the IR reflector
    position.
    """

    # Samples added to the analytically derived start offset in find_start.
    # 0 assumes the channel preserves the mark's shape; if measurements on
    # the hardware show the detected start is consistently off by a few
    # samples, set this to the observed correction.
    calibration_adjust = 0

    def __init__(self, mark_size=800):
        # Get the ir channel.
        # Alternative channel configuration (kept for reference):
        #self.irchannel = channel.channel(channelid='2',random_tails=100)
        self.irchannel = channel.channel(channelid='ir')

        # Calibration mark: mark_size 0s, then mark_size 1s, then mark_size 0s.
        self.mark = numpy.zeros(3*mark_size)
        self.mark[mark_size:2*mark_size] = 1

    def __call__(self, input):
        """
        Send input through the IR channel and return the calibrated result.

        Arguments:
          input -- numpy.array (or sequence) of voltage samples to transmit.

        Returns:
          numpy.array of the received samples that follow the mark.  The
          samples received before the mark and the mark itself are removed.
          The result is shifted and scaled so that the received mark's
          minimum maps to 0 and its maximum maps to 1.

        Raises:
          ValueError -- if the received signal (or the received mark) is
            flat, or the mark cannot be located.
          AssertionError -- propagated from find_start if the coarse
            normalization does not satisfy its range check.
        """
        samples = numpy.asarray(input, dtype=float)
        received = numpy.asarray(
            self.irchannel(numpy.concatenate((self.mark, samples))),
            dtype=float)

        # Coarse pass: normalize by the overall extremes so the data meets
        # find_start's "already scaled to 0..1" requirement.
        coarse = self._normalize(received, received)
        start = self.find_start(coarse)
        if start < 0:
            raise ValueError("Could not locate the calibration mark")

        # Fine pass: take the 0 and 1 levels from the received mark alone,
        # so the calibration does not depend on the payload's extremes.
        mark_end = start + len(self.mark)
        scaled = self._normalize(received, received[start:mark_end])

        # Drop everything up to and including the mark.
        return scaled[mark_end:]

    @staticmethod
    def _normalize(data, reference):
        """Shift/scale data so min(reference) -> 0 and max(reference) -> 1."""
        low = numpy.min(reference)
        high = numpy.max(reference)
        if not high > low:
            raise ValueError(
                "Received signal is flat; cannot calibrate the IR channel")
        return (data - low) / (high - low)

    def find_start(self, scaled_data):
        """
        Return the index in scaled_data at which the calibration mark starts.

        Arguments:
          scaled_data -- received samples already shifted/scaled so that
            they span (approximately) 0 to 1.

        Returns:
          int index of the mark's first sample, plus calibration_adjust.

        Raises:
          AssertionError -- if the minimum of scaled_data is not within 0.05
            of 0, or its maximum is not within 0.05 of 1.
        """
        low = numpy.min(scaled_data)
        high = numpy.max(scaled_data)
        # Raised explicitly rather than with ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        if not max(abs(low), abs(high - 1.0)) < 0.05:
            raise AssertionError("Unscaled data!")

        # Removing the mean makes the template sum to zero.  A constant
        # stretch of data (e.g. a long run of ones) then correlates to 0
        # instead of producing a large response, so only the 0->1->0 shape
        # of the mark scores high.
        template = self.mark - numpy.average(self.mark)
        correlation = numpy.convolve(template, scaled_data)

        # argmax returns the FIRST index of the maximum, so ties resolve to
        # the earliest match.
        peak = int(numpy.argmax(correlation))

        # numpy.convolve reverses the template, but the mark is symmetric, so
        # this is a plain correlation.  In the default 'full' mode, output
        # index n combines the template with scaled_data[n-len(mark)+1 : n+1],
        # so the best match is reported at the mark's LAST sample.  Subtract
        # len(mark) - 1 to get its first sample.
        offset = 1 - len(self.mark) + self.calibration_adjust
        return peak + offset
    
        
if __name__ == '__main__':

    # Calibrated channel under test, plus a raw IR channel for comparison.
    myir = scale_ir()
    my_un_ir = channel.channel(channelid='ir')

    # Push the same alternating bit pattern through both channels.
    pattern = [1, 0] * 4
    samples_per_bit = 500
    test_samples = lab1.bits_to_samples(pattern, samples_per_bit,
                                        npostamble=100)
    out_samples = myir(test_samples)
    out_un_samples = my_un_ir(test_samples)

    # Stack input, raw output and calibrated output in a single figure.
    panels = (
        (311, test_samples, "Input"),
        (312, out_un_samples, "Unscaled and Shifted"),
        (313, out_samples, "Scaled"),
    )
    p.figure()
    p.subplots_adjust(hspace=0.6)
    for position, signal, caption in panels:
        p.subplot(position)
        plot_with_title(signal, caption)

    # The unit-sample-response checks below are disabled.  To run them,
    # delete the surrounding triple quotes; the enclosed lines already
    # carry the 4-space indentation of this block.
"""
    # Compute the usr and average
    ausr = lab2_2.unit_sample_response(myir,max_length=400,tol=0.0)
    for i in range(9):
        ausr += lab2_2.unit_sample_response(myir,max_length=400,tol=0.0)
    ausr /= 10.0

    # Plot the unit_sample response
    lab2.plot_unit_sample_response(ausr, "scale_ir")

    # Demonstrate the ability to predict
    lab2_3.compare_usr_chan(ausr, myir, "scale_ir",samples_per_bit=100)


    # when ready for checkoff, enable the following line
    # BUT BE SURE You still connected to the IR system you tested.
    #lab2.checkoff(scale_ir,'L2_4')   

"""