import numpy
import matplotlib
from matplotlib import pyplot as p
from numpy.linalg import norm as norm

def im1dread(filename):
    """Read ``filename + ".png"`` and return a flat, peak-normalised luminance array.

    For colour images (3 or 4 channels) each pixel's value is the Euclidean
    norm of its first three channels; a fourth (alpha) channel is ignored.
    2-D (single-channel) images are flattened as-is. The result is divided by
    its maximum so the brightest pixel becomes 1.0. An all-zero image is
    returned unscaled rather than turned into NaNs by a 0/0 division.

    Parameters
    ----------
    filename : str
        Image path *without* the ``.png`` extension (it is appended here).

    Returns
    -------
    numpy.ndarray
        1-D array with one value per pixel, in row-major (C) order.

    Raises
    ------
    AssertionError
        If the image has a third axis whose size is neither 3 nor 4.
    """
    imI = p.imread(filename + ".png")
    if imI.ndim > 2:
        assert imI.shape[2] in (3, 4), "Image format unsupported"
        # One row per pixel; take the norm over the colour columns (alpha
        # excluded) in a single vectorised call instead of a Python loop.
        pixels = imI.reshape((-1, imI.shape[2]))
        imO1d = norm(pixels[:, :3], axis=1)
    else:
        imO1d = imI.reshape(-1)

    peak = imO1d.max()
    # Skip scaling for an all-zero image: dividing by 0 would yield NaNs.
    if peak != 0:
        imO1d /= peak
    return imO1d
    
def im1dshow(image1d, rows = 512, clip = True):
    """Display a flattened image as a grayscale picture with ``rows`` rows.

    Parameters
    ----------
    image1d : numpy.ndarray
        1-D pixel array in row-major order; its length must be a multiple
        of ``rows``. It is not modified.
    rows : int
        Number of image rows used to reshape ``image1d`` for display.
    clip : bool
        If True, values are clipped to [0, 1]. Otherwise the data is scaled
        by its largest absolute value, so it lies in [-1, 1]. An all-zero
        image is left unscaled.

    Raises
    ------
    AssertionError
        If ``len(image1d)`` is not divisible by ``rows``.
    """
    n_pixels = len(image1d)
    assert n_pixels % rows == 0, \
        "Image length %d is not divisible by %d rows" % (n_pixels, rows)
    p.gray()
    # Use a float copy so the caller's array is never touched and the
    # in-place operations below also work for integer input.
    p_im = numpy.array(image1d, dtype=float)
    if clip:
        # numpy.clip returns a new array unless out= is given. The original
        # code discarded that result, so clipping never took effect.
        numpy.clip(p_im, 0.0, 1.0, out=p_im)
    else:
        peak = numpy.abs(p_im).max()
        if peak != 0:
            p_im /= peak
    # NOTE: imshow still autoscales the colour range to the data's min/max
    # (no vmin/vmax are passed), so clipping only matters for out-of-range data.
    p.imshow(p_im.reshape((rows, -1)))