# ImageProcessing.py
# A collection of image processing functions
from Processing import *

# Compute the "lightness" of a color: the midpoint between its largest
# and smallest RGB components.
#   c: a color value that red(), green() and blue() can read
# Returns the lightness as a float.
def lightness( c ):
    channels = (red(c), green(c), blue(c))
    return 0.5 * (max(channels) + min(channels))

# Compute the arithmetic mean of the three color components.
#   c: a color value that red(), green() and blue() can read
# Returns the mean as a float (the sum is divided by 3.0).
def average( c ):
    total = red(c) + green(c) + blue(c)
    return total / 3.0

# Compute luminance with the model used for HDTV: a weighted sum of the
# red, green and blue components (weights 0.2126, 0.7152 and 0.0722).
#   c: a color value that red(), green() and blue() can read
# Returns the weighted sum as a float.
def luminance( c ):
    weighted_red = 0.2126 * red(c)
    weighted_green = 0.7152 * green(c)
    weighted_blue = 0.0722 * blue(c)
    return weighted_red + weighted_green + weighted_blue

# Write a grayscale version of img1 into img2.
# Every position (x, y) inside img1's width and height is read from img1,
# reduced to a single gray level with luminance(), and stored at the same
# position in img2 as color(gray).
#   img1: source image (read with getPixel)
#   img2: destination image (written with setPixel, then updatePixels)
# NOTE(review): the loop bounds come from img1 only, so img2 is assumed to be
# at least as large as img1 -- confirm against callers.
def grayscale(img1, img2):
    img1.loadPixels()
    img2.loadPixels()

    width = int(img1.width())
    height = int(img1.height())

    for x in range(width):
        for y in range(height):
            # luminance() is the conversion in use; lightness() and average()
            # from this file take the same argument and could be swapped in.
            level = luminance(img1.getPixel(x, y))
            img2.setPixel(x, y, color(level))

    img2.updatePixels()

# Threshold img1 into a two-level image stored in img2.
# A pixel whose luminance() is greater than or equal to cutoff becomes
# color(255); every other pixel becomes color(0).
#   cutoff: luminance value at or above which a pixel turns white
#   img1:   source image (read with getPixel)
#   img2:   destination image (written with setPixel, then updatePixels)
def threshold(cutoff, img1, img2):
    img1.loadPixels()
    img2.loadPixels()

    width = int(img1.width())
    height = int(img1.height())

    for x in range(width):
        for y in range(height):
            source = img1.getPixel(x, y)
            # 255 (white) at or above the cutoff, 0 (black) below it
            level = 255 if luminance(source) >= cutoff else 0
            img2.setPixel(x, y, color(level))

    img2.updatePixels()

# Compute the negative of img1 and store it in img2.
# Each of the red, green and blue components of a source pixel is replaced
# by 255 minus its value.
#   img1: source image (read with getPixel)
#   img2: destination image (written with setPixel, then updatePixels)
def negative( img1, img2 ):
    img1.loadPixels()
    img2.loadPixels()

    width = int(img1.width())
    height = int(img1.height())

    for x in range(width):
        for y in range(height):
            source = img1.getPixel(x, y)
            inverted = color(255 - red(source),
                             255 - green(source),
                             255 - blue(source))
            img2.setPixel(x, y, inverted)

    img2.updatePixels()

# Sepia filter: recolor img1 with a fixed set of sepia weights and store the
# result in img2.
# Each output channel is a weighted sum of the source pixel's red, green and
# blue components, truncated with int() and then passed through
# constrain(value, 0, 255).
#   img1: source image (read with getPixel)
#   img2: destination image (written with setPixel, then updatePixels)
def sepia( img1, img2 ):
    # (red, green, blue) source weights used for each output channel
    red_weights = (0.393, 0.769, 0.189)
    green_weights = (0.349, 0.686, 0.168)
    blue_weights = (0.272, 0.534, 0.131)

    img1.loadPixels()
    img2.loadPixels()

    width = int(img1.width())
    height = int(img1.height())

    for x in range(width):
        for y in range(height):
            source = img1.getPixel(x, y)
            src_r = red(source)
            src_g = green(source)
            src_b = blue(source)

            # Weighted sums, truncated to whole numbers
            out_r = int(src_r * red_weights[0] + src_g * red_weights[1] + src_b * red_weights[2])
            out_g = int(src_r * green_weights[0] + src_g * green_weights[1] + src_b * green_weights[2])
            out_b = int(src_r * blue_weights[0] + src_g * blue_weights[1] + src_b * blue_weights[2])

            # The red and green weights sum to more than 1, so bright pixels
            # can exceed 255 before constrain() is applied.
            out_r = constrain(out_r, 0, 255)
            out_g = constrain(out_g, 0, 255)
            out_b = constrain(out_b, 0, 255)

            img2.setPixel(x, y, color(out_r, out_g, out_b))

    img2.updatePixels()

# Perform spatial filtering of img1 with a weight matrix and store the
# result in img2.
#
# For every pixel far enough from the border, the red, green and blue
# components of the surrounding pixels are multiplied by the matching
# matrix weights and summed; each sum is passed through
# constrain(total, 0, 255) and the resulting color is written to img2.
#
#   matrix: nested sequence of numeric weights, indexed as
#           matrix[x offset][y offset] (the first index walks across the
#           image horizontally). For symmetric kernels such as the samples
#           below this is the same as reading the matrix row by row.
#           The window size is taken from the matrix itself, so kernels
#           other than 3x3 (e.g. 5x5) are supported; odd sizes keep the
#           window centred on the pixel.
#   img1:   source image (read with getPixel)
#   img2:   destination image (written with setPixel, then updatePixels)
#
# Border pixels that the window cannot fully cover are not written, so img2
# keeps whatever it already held there. An empty matrix raises IndexError.
#
# Sample filter matrices
#
# Sharpen
#matrix = [[ -1., -1., -1.],
#          [ -1.,  9., -1.],
#          [ -1., -1., -1. ]]
#
# Laplacian Edge Detection
#matrix = [[ 0.,  1.,  0. ],
#          [ 1., -4.,  1. ],
#          [ 0.,  1.,  0. ]]
#
# Average
#matrix = [[ 1./9., 1./9., 1./9. ],
#          [ 1./9., 1./9., 1./9. ],
#          [ 1./9., 1./9., 1./9. ]]
#
# Gaussian Blur
#matrix = [[ 1./16., 2./16., 1./16. ],
#          [ 2./16., 4./16., 2./16. ],
#          [ 1./16., 2./16., 1./16. ]]
def spatial( matrix, img1, img2 ):

    img1.loadPixels()
    img2.loadPixels()

    w, h = int(img1.width()), int(img1.height())

    # Window extent comes from the matrix instead of a hard-coded 3x3,
    # so larger kernels are applied in full rather than silently truncated.
    xsize = len(matrix)        # number of x offsets (first index)
    ysize = len(matrix[0])     # number of y offsets (second index)

    # Distance from the window's first cell to the pixel being computed
    xoff = xsize // 2
    yoff = ysize // 2

    # Margins on the far side of the window; equal to xoff/yoff for odd sizes
    xfar = xsize - 1 - xoff
    yfar = ysize - 1 - yoff

    # Visit only pixels whose whole window lies inside the image
    for i in range(xoff, w - xfar):
        for j in range(yoff, h - yfar):

            rtotal, gtotal, btotal = 0.0, 0.0, 0.0

            # Loop through filter matrix
            for c in range(xsize):
                cc = i + c - xoff
                weights = matrix[c]
                for r in range(ysize):
                    rr = j + r - yoff

                    # Apply the filter weight to the neighbouring pixel
                    tc = img1.getPixel(cc, rr)
                    mul = weights[r]
                    rtotal += red(tc) * mul
                    gtotal += green(tc) * mul
                    btotal += blue(tc) * mul

            # Make sure RGB is within range
            rtotal = constrain(rtotal, 0, 255)
            gtotal = constrain(gtotal, 0, 255)
            btotal = constrain(btotal, 0, 255)

            # Resulting color
            img2.setPixel(i, j, color(rtotal, gtotal, btotal))

    img2.updatePixels()

# Perform erosion on img1 and save the result to img2.
# Each interior pixel is replaced by the color of the lowest-luminance pixel
# in its 3x3 neighbourhood (the pixel itself included). When several
# neighbours tie, the first one visited wins; if no neighbour has a
# luminance below 255, color(255) is written.
# The one-pixel border of img2 is not written.
#   img1: source image (read with getPixel)
#   img2: destination image (written with setPixel, then updatePixels)
def erode(img1, img2):

    img1.loadPixels()
    img2.loadPixels()

    width = int(img1.width())
    height = int(img1.height())
    offsets = (-1, 0, 1)

    for x in range(1, width - 1):
        for y in range(1, height - 1):

            # Start from white so any darker neighbour replaces it
            darkest_lum = 255
            darkest_clr = color(255)

            # Scan the 3x3 neighbourhood, x offset outermost
            for dx in offsets:
                for dy in offsets:
                    candidate = img1.getPixel(x + dx, y + dy)
                    candidate_lum = luminance( candidate )

                    # Keep the neighbour only if it is strictly darker
                    if candidate_lum < darkest_lum:
                        darkest_lum = candidate_lum
                        darkest_clr = candidate

            # Store the darkest neighbour in img2
            img2.setPixel(x, y, darkest_clr)

    img2.updatePixels()

# Perform dilation on img1 and save the result to img2.
# Each interior pixel is replaced by the color of the highest-luminance pixel
# in its 3x3 neighbourhood (the pixel itself included). When several
# neighbours tie, the first one visited wins; if no neighbour has a
# luminance above 0, color(0) is written.
# The one-pixel border of img2 is not written.
#   img1: source image (read with getPixel)
#   img2: destination image (written with setPixel, then updatePixels)
def dilate(img1, img2):

    img1.loadPixels()
    img2.loadPixels()

    width = int(img1.width())
    height = int(img1.height())
    offsets = (-1, 0, 1)

    for x in range(1, width - 1):
        for y in range(1, height - 1):

            # Start from black so any brighter neighbour replaces it
            brightest_lum = 0
            brightest_clr = color(0)

            # Scan the 3x3 neighbourhood, x offset outermost
            for dx in offsets:
                for dy in offsets:
                    candidate = img1.getPixel(x + dx, y + dy)
                    candidate_lum = luminance( candidate )

                    # Keep the neighbour only if it is strictly brighter
                    if candidate_lum > brightest_lum:
                        brightest_lum = candidate_lum
                        brightest_clr = candidate

            # Store the brightest neighbour in img2
            img2.setPixel(x, y, brightest_clr)

    img2.updatePixels()
