# DO NOT EDIT this file, you'll break the checkoff process!

import math
import numpy
import matplotlib.pyplot as p
import lab1
reload(lab1)
import im1d
reload(im1d)
import channel
reload(channel)

def demo_deconvolver(input, mychannel, h, deconver, images=False):
    """Send samples through a channel, deconvolve them, and plot each stage.

    input     -- sequence of samples to transmit
    mychannel -- callable applied to the zero-padded samples; its result
                 is treated as the received samples
    h         -- unit-sample response handed to deconver; len(h) zeros are
                 appended to input before transmission
    deconver  -- callable invoked as deconver(received, h)
    images    -- if true, draw each stage with im1d.im1dshow in a 1x3 row;
                 otherwise draw line plots stacked in a 3x1 column
    """
    # Transmit the data through the channel, padded with len(h) zeros
    padded = numpy.append(input, numpy.zeros(len(h)))
    rcvd = mychannel(padded)

    # Deconvolve the received samples
    deconv = deconver(rcvd, h)

    # Pick the subplot layout and the drawing routine
    if images:
        layout, draw = [131, 132, 133], im1d.im1dshow
    else:
        layout, draw = [311, 312, 313], p.plot

    # The received and deconvolved stages are clipped to len(input) samples
    nshow = len(input)
    stages = [
        ("Input Samples", input),
        ("Received Samples", rcvd[0:nshow]),
        ("Deconvolved Samples", deconv[0:nshow]),
    ]

    p.figure()
    p.subplots_adjust(hspace = 0.6)
    for position, (caption, samples) in zip(layout, stages):
        p.subplot(position)
        draw(samples)
        p.title(caption)
    

# for Task #3
def plot_unit_sample_response(response,name):
    """Draw a stem plot of a channel's unit-sample response in a new figure.

    response -- sequence of response samples, plotted against 0..len-1
    name     -- channel name interpolated into the plot title
    """
    p.figure()
    sample_numbers = range(len(response))
    p.stem(sample_numbers, response)
    p.xlabel('Sample number')
    p.title('Unit-sample response of channel %s' % name)


def testrun(channel,message,samples_per_bit,npreamble=0,npostamble=0,hlen=127):
    x = [0.0]*(npreamble + len(message)*samples_per_bit + npostamble)
    index = npreamble
    one_bit_samples = [1.0]*samples_per_bit
    for i in range(len(message)):
        next_index = index + samples_per_bit
        if message[i] == 1:
            x[index:next_index] = one_bit_samples
        index = next_index
    y = channel(x)
    h = channel([1.0]+hlen*[0.0])
    return (x,y,h)

def maxdiff(x,w):
    """Return the largest absolute elementwise difference between x and w.

    Only the overlapping prefix -- the first min(len(x), len(w)) elements --
    is compared; trailing elements of the longer sequence are ignored.

    Raises ValueError when either sequence is empty, because there is then
    nothing to take a maximum over.
    """
    # zip stops at the shorter sequence, giving the min-length comparison
    # without index arithmetic (and without the Python-2-only xrange).
    return max(abs(a - b) for a, b in zip(x, w))

def test_deconvolver(deconvolver,channel,message,samples_per_bit):
    """Run message through channel, deconvolve the result, report and plot.

    deconvolver     -- callable invoked as deconvolver(y, h)
    channel         -- channel passed to testrun; its .id and .noise
                       attributes are used in the report and plot titles
    message         -- sequence of bits to transmit
    samples_per_bit -- samples per transmitted bit

    Prints the maximum difference between the channel input and the
    deconvolver output, then draws input, channel output and deconvolved
    output as three stacked plots sharing the same axis limits.
    """
    # send message through the channel
    x,y,h = testrun(channel,message,samples_per_bit)

    # deconvolve
    w = deconvolver(y,h)

    # A single pre-formatted argument prints identically whether print is
    # a statement or a function.
    print("Using channel %s: the input and output of deconvolver differ by %s"
          % (channel.id, maxdiff(x,w)))

    # compute vertical axis bounds: overall range widened by 10% each side
    signals = (x, y, w)
    vmax = max(max(s) for s in signals)
    vmin = min(min(s) for s in signals)
    margin = 0.1 * (vmax - vmin)
    # every panel uses the input's length for its horizontal extent
    bounds = [-1, len(x)+1, vmin - margin, vmax + margin]

    panels = (
        (x, 'Input to channel %s' % channel.id),
        (y, 'Output from channel %s' % channel.id),
        (w, 'Deconvolved output from channel %s (noise=%g)' % (channel.id,channel.noise)),
    )

    p.figure()
    p.subplots_adjust(hspace=0.6)
    for row, (samples, caption) in enumerate(panels):
        p.subplot(3,1,row+1)
        p.plot(samples)
        p.axis(bounds)
        p.xlabel('Sample number')
        p.ylabel('Volts')
        p.title(caption)


def verify_task4(f):
    """Score Task 4 by sending a test pattern through the channel f() returns.

    f -- zero-argument callable returning a channel; that channel is called
         with the transmitted samples and its result is converted to a
         numpy array

    Four independent checks are each worth 0.5 points; a message is
    printed for every check that fails.  Returns the accumulated points
    (0 when no check passes).
    """
    score = 0

    ir_channel = f()
    samples_per_bit = 500
    pattern = [1,0,1,0,1,0,1,0]
    sent = lab1.bits_to_samples(pattern,samples_per_bit,npostamble=100)
    rcvd = numpy.array(ir_channel(sent))

    # Check 1: received maximum within 0.05 of 1.0 and minimum within
    # 0.05 of 0
    peak_ok = abs(numpy.max(rcvd) - 1.0) < 0.05
    if peak_ok and abs(numpy.min(rcvd)) < 0.05:
        score += 0.5
    else:
        print("The IR channel scalling is incorrect")

    # Convolve the zero-mean transmitted samples with the received samples;
    # the peak of that result drives the next two checks.
    half_energy = numpy.sum(sent)/2.0
    zero_mean = sent - numpy.average(sent)
    conv = numpy.convolve(zero_mean,rcvd)
    peak = numpy.max(conv)

    # Check 2: peak is within 20% of half the sum of the transmitted samples
    if numpy.abs(half_energy - peak)/half_energy  < 0.2:
        score += 0.5
    else:
        print("Received samples do not seem to match transmitted samples.")

    # Check 3: first occurrence of the peak lies before len(sent) + 200
    first_peak = numpy.nonzero(conv == peak)[0][0]
    if first_peak < len(sent) + 200:
        score += 0.5
    else:
        print("You are leaving lots the leading junk in the received data!")

    # Check 4: the very first received sample is below 0.02
    if rcvd[0] < 0.02:
        score += 0.5
    else:
        print("You are probably removing to much leading data.")
        print("Be sure you have a paper cover to screen the florescent lights")

    return score

def verify_task5(f):
    """Score Task 5 by testing deconvolver f on channels '1' and '2'.

    f -- callable invoked as f(y, h) with a channel's output and its
         unit-sample response, as produced by testrun

    Each channel on which maxdiff between the transmitted samples and
    f's result is below 1e-3 earns 1.5 points.  Returns the total.
    """
    message = [1,0,1,1,0,1,1,0,0,1,0,0]
    samples_per_bit = 8

    points = 0
    for channel_id in ('1', '2'):
        test_channel = channel.channel(channel_id)
        x,y,h = testrun(test_channel,message,samples_per_bit)
        if maxdiff(x,f(y,h)) < 1e-3:
            points += 1.5

    return points

##################################################
##
## Code to submit task to server.  Do not change.
## Task-specific code is in verify_task4() and verify_task5(), defined above.
##
##################################################

import Tkinter
class Dialog(Tkinter.Toplevel):
    """Base class for simple dialogs shown on top of a parent window.

    Subclasses customize the dialog by overriding body() (the content
    area), buttonbox() (the button row), validate() and apply().
    Constructing an instance displays the dialog and does not return
    until the dialog window has been destroyed.
    """

    def __init__(self, parent, title = None):
        """Build the dialog over parent, show it, and wait for it to close.

        parent -- window the dialog is attached to and positioned against
        title  -- optional window title
        """
        Tkinter.Toplevel.__init__(self, parent)
        self.transient(parent)
        if title: self.title(title)
        self.parent = parent

        # body() fills the frame and may return the widget that should
        # receive the initial keyboard focus
        body = Tkinter.Frame(self)
        self.initial_focus = self.body(body)
        body.pack(padx=5, pady=5)

        self.buttonbox()
        self.grab_set()

        # fall back to the dialog itself when body() returned nothing
        if not self.initial_focus:
            self.initial_focus = self

        # closing the window via the window manager behaves like cancel
        self.protocol("WM_DELETE_WINDOW", self.cancel)
        # place the dialog 50 pixels right of and below the parent's origin
        self.geometry("+%d+%d" % (parent.winfo_rootx()+50,parent.winfo_rooty()+50))
        
        self.initial_focus.focus_set()
        # block here until this window is destroyed
        self.wait_window(self)

    def body(self, master):
        """Create the dialog content inside master; override in subclasses.

        Returns the widget that should get initial focus, or None.
        """
        return None

    # add standard button box
    def buttonbox(self):
        """Create the button row; this default has a single Ok button."""
        box = Tkinter.Frame(self)
        w = Tkinter.Button(box, text="Ok", width=10, command=self.ok, default=Tkinter.ACTIVE)
        w.pack(side=Tkinter.LEFT, padx=5, pady=5)
        box.pack()
        
    # standard button semantics
    def ok(self, event=None):
        """Accept the dialog: validate, hide, apply(), then close.

        If validate() returns a false value the dialog stays open and
        focus goes back to the initial-focus widget.
        """
        if not self.validate():
            self.initial_focus.focus_set() # put focus back
            return
        # hide the window before apply() runs
        self.withdraw()
        self.update_idletasks()
        self.apply()
        # cancel() performs the shared teardown (refocus parent, destroy)
        self.cancel()
        
    def cancel(self, event=None):
        """Close the dialog without calling apply()."""
        # put focus back to the parent window
        self.parent.focus_set()
        self.destroy()
        
    # command hooks
    def validate(self):
        """Return a true value if the dialog's input is acceptable."""
        return 1 # override

    def apply(self):
        """Hook called by ok() after validation; default does nothing."""
        pass   # override

# ask user for Athena username and MIT ID
class SubmitDialog(Dialog):
    """Dialog that collects an Athena username and an MIT ID.

    Once the constructor returns, athena_name and mit_id hold the text
    that was entered if Submit was pressed, and remain None if the dialog
    was cancelled.  An error string, when given, is displayed in red above
    the entry fields.
    """

    def __init__(self,parent,error=None,title = None):
        # All three attributes must exist before Dialog.__init__ runs,
        # since it calls body(), which reads self.error.
        self.athena_name = None
        self.mit_id = None
        self.error = error
        Dialog.__init__(self,parent,title=title)

    def body(self, master):
        """Lay out the optional error label and the two labelled entries.

        Returns the username entry so that it receives initial focus.
        """
        next_row = 0
        if self.error:
            error_label = Tkinter.Label(master,text=self.error,
                                        anchor=Tkinter.W,justify=Tkinter.LEFT,fg="red")
            error_label.grid(row=next_row,sticky=Tkinter.W,columnspan=2)
            next_row += 1

        # one row per field: right-aligned caption in column 0, entry in column 1
        entries = []
        for offset, caption in enumerate(("Athena username:", "MIT ID:")):
            Tkinter.Label(master, text=caption).grid(row=next_row+offset,sticky=Tkinter.E)
            entry = Tkinter.Entry(master)
            entry.grid(row=next_row+offset, column=1)
            entries.append(entry)
        self.e1, self.e2 = entries

        return self.e1 # initial focus

    # add standard button box
    def buttonbox(self):
        """Create a Submit / Cancel button row."""
        box = Tkinter.Frame(self)
        side_by_side = dict(side=Tkinter.LEFT, padx=5, pady=5)

        submit_button = Tkinter.Button(box, text="Submit", width=10, command=self.ok,
                                       default=Tkinter.ACTIVE)
        submit_button.pack(**side_by_side)

        cancel_button = Tkinter.Button(box, text="Cancel", width=10, command=self.cancel)
        cancel_button.pack(**side_by_side)

        box.pack()

    def apply(self):
        """Copy the text of both entries into athena_name and mit_id."""
        self.athena_name, self.mit_id = self.e1.get(), self.e2.get()

# Let user know what server said
class MessageDialog(Dialog):
    """Dialog that displays a text message with the default Ok button."""

    def __init__(self, parent,message = '',title = None):
        # Store the text first: Dialog.__init__ calls body(), which reads it.
        self.message = message
        Dialog.__init__(self,parent,title=title)

    def body(self, master):
        """Show the message in a left-justified label.

        Returns nothing, so no particular widget is requested for focus.
        """
        label_options = dict(text=self.message,anchor=Tkinter.W,justify=Tkinter.LEFT)
        Tkinter.Label(master, **label_options).grid(row=0)

# return contents of file as a string
def file_contents(fname):
    """Return the entire contents of the file named fname as a string.

    The file is read with universal-newline translation, so CRLF and CR
    line endings come back as a single newline character.  The file is
    closed even if reading fails.  Errors from open() (for example a
    missing file) propagate to the caller.
    """
    # use universal mode to ensure cross-platform consistency in hash
    try:
        f = open(fname,'U')
    except ValueError:
        # Python 3.11+ rejects the 'U' flag; plain text mode there already
        # applies the same newline translation.
        f = open(fname,'r')
    with f:
        return f.read()

import hashlib
def digest(s):
    """Return the MD5 digest of s as a hexadecimal string."""
    return hashlib.md5(s).hexdigest()

# if the task's verifier (or its fixed score) indicates points have been
# earned, submit results to server if requested to do so
import inspect,os,urllib,urllib2
def checkoff(f,task='???',submit=True):
    """Score a lab task and, if points were earned, offer to submit it.

    f      -- the work being checked off.  For tasks L2_4 and L2_5 it is
              passed to verify_task4 / verify_task5.  It also identifies
              the source file to upload: a string is used as a file name,
              anything else is resolved with inspect.getsourcefile.
    task   -- one of 'L2_1', 'L2_2', 'L2_3', 'L2_4', 'L2_5'
    submit -- when true and the score is non-zero, open a dialog asking
              for credentials and post the submission

    Tasks L2_1, L2_2 and L2_3 carry fixed scores of 1, 2 and 1 and are not
    verified here.  Raises ValueError for any other task name.  Returns
    None.
    """
    if task == 'L2_1':
        points = 1
    elif task == 'L2_2':
        points = 2
    elif task == 'L2_3':
        points = 1
    elif task == 'L2_4':
        points = verify_task4(f)
    elif task == 'L2_5':
        points = verify_task5(f)
    else:
        raise ValueError("task must be one of L2_1, L2_2, L2_3, L2_4, L2_5")

    if not (submit and points):
        return

    root = Tkinter.Tk()
    try:
        error = None
        # Keep prompting until a submission is accepted or the user gives
        # up; a server or network error is shown in the next dialog.
        while True:
            sd = SubmitDialog(root,error=error,title="Submit Task %s?"%task)
            if not sd.athena_name:
                # dialog was cancelled or the username was left empty
                break

            if isinstance(f,str): fname = os.path.abspath(f)
            else: fname = os.path.abspath(inspect.getsourcefile(f))
            # 'digest' is the MD5 of the file that defines checkoff (this
            # module), so any change to this file changes the value sent.
            this_file = os.path.abspath(inspect.getsourcefile(checkoff))
            post = {
                'user': sd.athena_name,
                'id': sd.mit_id,
                'task': task,
                'digest': digest(file_contents(this_file)),
                'points': points,
                'filename': fname,
                'file': file_contents(fname)
                }
            # NOTE(review): the endpoint is plain http, so the username and
            # ID travel unencrypted -- confirm whether an https URL exists.
            try:
                response = urllib2.urlopen('http://scripts.mit.edu/~6.02/currentsemester/submit_task.cgi',
                                           urllib.urlencode(post)).read()
            except Exception as e:
                # report any failure through the same path as a server error
                response = 'Error\n'+str(e)

            if response.startswith('Error\n'):
                # strip the 'Error\n' prefix and re-prompt with the message
                error = response[6:]
            else:
                MessageDialog(root,message=response,title='Submission response')
                break
    finally:
        # tear down the Tk root even if building the submission raised
        root.destroy()