#Daniel Stoffregen djstoff15@earlham.edu
#A script that creates a GUI for applying image analysis filters (NDVI, VARI, TGI) to images.
#TODO: credit the source of the NDVI code.
import matplotlib
matplotlib.use('TkAgg')
from tkinter import * 
from tkinter import filedialog

import  PIL
from PIL import Image, ImageTk, ImageEnhance

import getopt
import sys

import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib import ticker
from matplotlib.colors import LinearSegmentedColormap

class NDVI(object):
    """Render an NDVI false-colour map of an image and save it to disk.

    convert() treats the image's first channel as near-infrared (NIR) and
    combines the third (blue) and second (green) channels into a visible
    term (VIS).
    """

    def __init__(self, file_path, output_file=False, colors=False):
        """
        file_path   -- image to read with plt.imread.
        output_file -- file name to write; 'NDVI.jpg' is used when falsy
                       (False, None or an empty string).
        colors      -- colour list for the colormap; red -> orange -> yellow
                       -> green -> blue is used when falsy.
        """
        self.image = plt.imread(file_path)
        self.output_name = output_file or 'NDVI.jpg'
        # An alternative palette that was tried: black, gray, red, yellow, green.
        self.colors = colors or ['red', 'orange', 'yellow', 'green', 'blue']

    def create_colormap(self, *args):
        """Return a LinearSegmentedColormap that blends through the given colours."""
        return LinearSegmentedColormap.from_list(name='custom1', colors=args)

    def create_colorbar(self, fig, image):
        """Add a small horizontal colorbar labelled "NDVI" for *image* to *fig*."""
        # Figure-fraction rectangle [left, bottom, width, height]; note that it
        # extends past the right edge of the figure (0.825 + 0.2 > 1).
        position = fig.add_axes([0.825, 0.19, 0.2, 0.05])
        # NOTE(review): when a mappable is passed, matplotlib's Colorbar takes
        # its norm from the mappable, so this -1..1 norm probably has no
        # effect -- confirm against the installed matplotlib version.
        norm = colors.Normalize(vmin=-1.0, vmax=1.0)
        cbar = fig.colorbar(image,
                            cax=position,
                            orientation='horizontal',
                            norm=norm)
        cbar.ax.tick_params(labelsize=20)
        # At most 3 tick intervals keep the small bar legible.
        cbar.locator = ticker.MaxNLocator(nbins=3)
        cbar.update_ticks()
        cbar.set_label("NDVI", fontsize=20, x=0.5, y=0.5, labelpad=50)

    def convert(self):
        """
        Compute NDVI = (NIR - VIS) / (NIR + VIS) per pixel, render it with the
        custom colormap plus a colorbar, and save the figure to
        self.output_name.

        VIS is (blue + green)**2 / (blue - green)**2. Its denominator is forced
        to 1 wherever blue == green, to avoid dividing by zero.
        """
        nir = self.image[:, :, 0].astype('float')
        blue = self.image[:, :, 2].astype('float')
        green = self.image[:, :, 1].astype('float')
        bottom = (blue - green) ** 2
        bottom[bottom == 0] = 1  # replace zeros so the division below is defined
        vis = (blue + green) ** 2 / bottom
        # Pixels where nir + vis == 0 give 0/0 = NaN (numpy emits a RuntimeWarning).
        ndvi = (nir - vis) / (nir + vis)

        fig, ax = plt.subplots(figsize=(25, 25))
        try:
            image = ax.imshow(ndvi, cmap=self.create_colormap(*self.colors))
            ax.axis('off')

            self.create_colorbar(fig, image)

            # Crop the saved file to the image axes' extent, converted to inches.
            extent = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
            fig.savefig(self.output_name, dpi=600, transparent=True, bbox_inches=extent, pad_inches=0)
        finally:
            # pyplot keeps every figure alive until it is closed. Without this,
            # each conversion leaks a 25x25-inch figure.
            plt.close(fig)

#VARI (Visible Atmospherically Resistant Index) = (G - R) / (R + G - B + 0.001)
class VARI(object):
    """Render a VARI false-colour map of an RGB image and save it to disk.

    VARI = (G - R) / (R + G - B + 0.001), computed per pixel in convert_VARI().
    """

    def __init__(self, file_path, output_file=False, colors=False):
        """
        file_path   -- image to read with plt.imread.
        output_file -- file name to write; 'VARI.jpg' is used when falsy
                       (False, None or an empty string).
        colors      -- colour list for the colormap; red -> orange -> yellow
                       -> green -> blue is used when falsy.
        """
        self.image = plt.imread(file_path)
        self.output_name = output_file or 'VARI.jpg'
        self.colors = colors or ['red', 'orange', 'yellow', 'green', 'blue']

    def create_colormap_VARI(self, *args):
        """Return a LinearSegmentedColormap that blends through the given colours."""
        return LinearSegmentedColormap.from_list(name='custom1', colors=args)

    def create_colorbar_VARI(self, fig, image):
        """Add a small horizontal colorbar labelled "VARI" for *image* to *fig*."""
        # Figure-fraction rectangle [left, bottom, width, height]; note that it
        # extends past the right edge of the figure (0.825 + 0.2 > 1).
        position = fig.add_axes([0.825, 0.19, 0.2, 0.05])
        # NOTE(review): when a mappable is passed, matplotlib's Colorbar takes
        # its norm from the mappable, so this -1..1 norm probably has no
        # effect -- confirm against the installed matplotlib version.
        norm = colors.Normalize(vmin=-1.0, vmax=1.0)
        cbar = fig.colorbar(image,
                            cax=position,
                            orientation='horizontal',
                            norm=norm)
        cbar.ax.tick_params(labelsize=20)
        # At most 3 tick intervals keep the small bar legible.
        cbar.locator = ticker.MaxNLocator(nbins=3)
        cbar.update_ticks()
        cbar.set_label("VARI", fontsize=20, x=0.5, y=0.5, labelpad=50)

    def convert_VARI(self):
        """
        Compute VARI per pixel, render it with the custom colormap plus a
        colorbar, and save the figure to self.output_name.
        """
        red = self.image[:, :, 0].astype('float')
        blue = self.image[:, :, 2].astype('float')
        green = self.image[:, :, 1].astype('float')
        # The 0.001 offset keeps the denominator non-zero when R + G - B is
        # integer-valued (e.g. 8-bit channels). Float-valued inputs could in
        # principle still hit zero -- TODO confirm expected input types.
        vari = (green - red) / (red + green - blue + .001)

        fig, ax = plt.subplots(figsize=(25, 25))
        try:
            image = ax.imshow(vari, cmap=self.create_colormap_VARI(*self.colors))
            ax.axis('off')

            self.create_colorbar_VARI(fig, image)

            # Crop the saved file to the image axes' extent, converted to inches.
            extent = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
            fig.savefig(self.output_name, dpi=600, transparent=True, bbox_inches=extent, pad_inches=0)
        finally:
            # pyplot keeps every figure alive until it is closed. Without this,
            # each conversion leaks a 25x25-inch figure.
            plt.close(fig)

#TGI (Triangular Greenness Index) = G - 0.39*R - 0.61*B
class TGI(object):
    """Render a TGI false-colour map of an RGB image and save it to disk.

    TGI = G - 0.39*R - 0.61*B, computed per pixel in convert_TGI().
    """

    def __init__(self, file_path, output_file=False, colors=False):
        """
        file_path   -- image to read with plt.imread.
        output_file -- file name to write; 'TGI.jpg' is used when falsy
                       (False, None or an empty string).
        colors      -- colour list for the colormap; red -> orange -> yellow
                       -> green -> blue is used when falsy.
        """
        self.image = plt.imread(file_path)
        # Default was previously copy-pasted as 'VARI.jpg', which made a TGI
        # run overwrite the VARI output.
        self.output_name = output_file or 'TGI.jpg'
        self.colors = colors or ['red', 'orange', 'yellow', 'green', 'blue']

    def create_colormap_TGI(self, *args):
        """Return a LinearSegmentedColormap that blends through the given colours."""
        return LinearSegmentedColormap.from_list(name='custom1', colors=args)

    def create_colorbar_TGI(self, fig, image):
        """Add a small horizontal colorbar labelled "TGI" for *image* to *fig*."""
        # Figure-fraction rectangle [left, bottom, width, height]; note that it
        # extends past the right edge of the figure (0.825 + 0.2 > 1).
        position = fig.add_axes([0.825, 0.19, 0.2, 0.05])
        # NOTE(review): when a mappable is passed, matplotlib's Colorbar takes
        # its norm from the mappable, so this -1..1 norm probably has no
        # effect -- confirm against the installed matplotlib version.
        norm = colors.Normalize(vmin=-1.0, vmax=1.0)
        cbar = fig.colorbar(image,
                            cax=position,
                            orientation='horizontal',
                            norm=norm)
        cbar.ax.tick_params(labelsize=20)
        # At most 3 tick intervals keep the small bar legible.
        cbar.locator = ticker.MaxNLocator(nbins=3)
        cbar.update_ticks()
        cbar.set_label("TGI", fontsize=20, x=0.5, y=0.5, labelpad=50)

    def convert_TGI(self):
        """
        Compute TGI per pixel, render it with the custom colormap plus a
        colorbar, and save the figure to self.output_name.
        """
        red = self.image[:, :, 0].astype('float')
        blue = self.image[:, :, 2].astype('float')
        green = self.image[:, :, 1].astype('float')
        tgi = (green - (.39 * red)) - (.61 * blue)

        fig, ax = plt.subplots(figsize=(25, 25))
        try:
            image = ax.imshow(tgi, cmap=self.create_colormap_TGI(*self.colors))
            ax.axis('off')

            self.create_colorbar_TGI(fig, image)

            # Crop the saved file to the image axes' extent, converted to inches.
            extent = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
            fig.savefig(self.output_name, dpi=600, transparent=True, bbox_inches=extent, pad_inches=0)
        finally:
            # pyplot keeps every figure alive until it is closed. Without this,
            # each conversion leaks a 25x25-inch figure.
            plt.close(fig)

class GUI:
    def __init__(self, master):
        """
        Build the main window. It holds an image-picker button, one radio
        button per algorithm, an entry for the output file name, a convert
        button and a close button.

        Internal state:
          select_algo    -- Tk IntVar bound to the radio buttons; 0 = NDVI,
                            1 = VARI, 2 = TGI (defaults to 0).
          input_path     -- path chosen via the "Input Image" button, or None.
          red/green/blue -- per-channel offsets added to the enhanced preview.
        """
        self.master = master
        # select_algo is a Tk IntVar so it can back the radio buttons; call
        # select_algo.get() to read the plain int value.
        self.select_algo = IntVar()
        # Path picked in inIm(); kept separate from the inIm method name so the
        # method is never shadowed on the instance.
        self.input_path = None
        self.blue = 0
        self.green = 0
        self.red = 0
        master.title("NDVI Converter")

        self.label = Label(master, text="Select an Image, the algorithm to apply to the image, and a name you want to create a new image with")
        self.label.pack()

        self.inIm_button = Button(master, text="Input Image", command=self.inIm)
        self.inIm_button.pack()

        # Create the algorithm radio buttons.
        self.algo()

        self.outname = Label(master, text="New File Name. Include file extension, acceptable extensions include .jpg and .tif")
        self.outname.pack()
        self.newname = Entry(master)
        self.newname.pack()

        self.convert_button = Button(master, text="convert", command=self.con_button)
        self.convert_button.pack()

        self.close_button = Button(master, text="Close", command=master.quit)
        self.close_button.pack()

        # Placeholder for the enhanced preview; put_to_screen() replaces it.
        self.display_space = Label(master)

    def inIm(self):
        """
        Ask the user for an input image and show a 250x250 thumbnail of it.

        The extension filter can be extended, but the file type must be
        readable by PIL (for the thumbnail) and by plt.imread (used by the
        converters). Cancelling the dialog keeps the previous selection.
        """
        path = filedialog.askopenfilename(initialdir=".", title="Select file", filetypes=(("jpg files", "*.jpg *.tif *.DNG *.png"), ("all files", "*.*")))
        if not path:
            # The dialog returns an empty value when the user cancels.
            return
        self.input_path = path
        self.master.filename = path

        im = PIL.Image.open(path)
        im = im.resize((250, 250), resample=Image.NEAREST)
        photo = PIL.ImageTk.PhotoImage(im)

        label = Label(self.master, image=photo)
        label.image = photo  # keep a reference so the image is not garbage-collected
        label.pack(side=LEFT)
        print(path)

    def algo(self):
        """Create one radio button per algorithm, all bound to select_algo.

        Values 0, 1 and 2 select NDVI, VARI and TGI respectively, matching
        the dispatch in con_button().
        """
        algorithms = [("NDVI", 0), ("VARI", 1), ("TGI", 2)]
        for name, value in algorithms:
            Radiobutton(self.master, text=name, padx=20, value=value, variable=self.select_algo).pack()

    def con_button(self):
        """
        Run the selected algorithm on the chosen image and build enhancement
        controls for the result.

        Every press adds its own result thumbnail, save/reset buttons and
        set of sliders to the window. This lets several conversions be
        compared in the same session.

        The conversion is written to the name typed in the entry, or to the
        converter's default name when the entry is blank. It is then shown as
        a 250x250 thumbnail, and an enhanced preview follows the sliders.
        """
        if not self.input_path:
            print("Select an input image first")
            return

        algorithm = self.select_algo.get()
        output_file = str(self.newname.get())
        if algorithm == 0:
            converter = NDVI(self.input_path, output_file, colors=False)
            converter.convert()
        elif algorithm == 1:
            converter = VARI(self.input_path, output_file, colors=False)
            converter.convert_VARI()
        elif algorithm == 2:
            converter = TGI(self.input_path, output_file, colors=False)
            converter.convert_TGI()
        else:
            print("Unknown algorithm selection:", algorithm)
            return
        # A blank entry makes the converter fall back to its default file
        # name, so continue with the name that was actually written.
        output_file = converter.output_name

        displayIm = PIL.Image.open(output_file)
        # The previews work on this 250x250 copy to keep the sliders
        # responsive; it is not the final output (save() re-reads the file).
        preview = displayIm.resize((250, 250), resample=Image.NEAREST)
        photo = PIL.ImageTk.PhotoImage(preview)

        label = Label(self.master, image=photo)
        label.image = photo  # keep a reference so the image is not garbage-collected
        label.pack(side=RIGHT)

        colordifIm = PIL.ImageEnhance.Color(preview)
        contrastdifIm = PIL.ImageEnhance.Contrast(preview)
        brightnessdifIm = PIL.ImageEnhance.Brightness(preview)
        sharpnessdifIm = PIL.ImageEnhance.Sharpness(preview)

        # Each enhancement slider writes its preview here and put_to_screen()
        # reads it back, so the preview shows the last-moved slider's effect.
        test_change = "enhanced" + output_file

        def save(*args):
            """
            PIL.ImageEnhance applies one enhancement at a time. This applies
            colour, brightness, contrast and sharpness in sequence to the
            full-size image using the current slider positions, and writes
            the result to "final_enhanced_" + output_file.

            NOTE: the red/green/blue offsets shown in the preview are not
            applied here.
            """
            outFull = PIL.Image.open(output_file)
            colorDifConverter = ImageEnhance.Color(outFull)
            outFull_a = colorDifConverter.enhance(self.colordif_slider.get())

            brightnessConverter = ImageEnhance.Brightness(outFull_a)
            outFull_b = brightnessConverter.enhance(self.brightness_slider.get())

            contrastConverter = ImageEnhance.Contrast(outFull_b)
            outFull_c = contrastConverter.enhance(self.contrast_slider.get())

            sharpnessConverter = ImageEnhance.Sharpness(outFull_c)
            outFull_final = sharpnessConverter.enhance(self.sharpness_slider.get())

            outFull_final.save("final_enhanced_" + output_file)

        self.save_button = Button(self.master, text="save changes", command=save)
        self.save_button.pack()

        def put_to_screen(*args):
            """Show the latest enhanced preview with the RGB offsets added."""
            enhanced_photo = PIL.Image.open(test_change)

            ep_data = enhanced_photo.getdata()
            colorChange = [(pixel[0] + self.red, pixel[1] + self.green, pixel[2] + self.blue) for pixel in ep_data]
            enhanced_photo.putdata(colorChange)

            labelx2 = ImageTk.PhotoImage(enhanced_photo)

            try:
                self.display_space.destroy()
            except TclError:
                print("Doesn't exist yet")

            self.display_space = Label(self.master, image=labelx2, text="enhanced")
            self.display_space.image = labelx2
            self.display_space.pack(side=RIGHT, padx=5, pady=5)

        def colorDif(*args):
            # Apply ImageEnhance.Color at the slider's value to the preview.
            colordifIm.enhance(self.colordif_slider.get()).save(test_change)
            put_to_screen()

        def contrastDif(*args):
            # Apply ImageEnhance.Contrast at the slider's value to the preview.
            contrastdifIm.enhance(self.contrast_slider.get()).save(test_change)
            put_to_screen()

        def brightnessDif(*args):
            # Apply ImageEnhance.Brightness at the slider's value to the preview.
            brightnessdifIm.enhance(self.brightness_slider.get()).save(test_change)
            put_to_screen()

        def sharpnessDif(*args):
            # Apply ImageEnhance.Sharpness at the slider's value to the preview.
            sharpnessdifIm.enhance(self.sharpness_slider.get()).save(test_change)
            put_to_screen()

        # Per-channel offset handlers.
        def adjRed(*args):
            self.red = self.red_slider.get()
            put_to_screen()

        def adjGreen(*args):
            self.green = self.green_slider.get()
            put_to_screen()

        def adjBlue(*args):
            self.blue = self.blue_slider.get()
            put_to_screen()

        self.colordif_slider = Scale(self.master, from_=0.0, to=3.0, command=colorDif, orient="horizontal", label="colordif", digits=3, resolution=0.01)
        self.colordif_slider.set(1.0)
        self.colordif_slider.pack(side=LEFT)

        self.contrast_slider = Scale(self.master, from_=0.0, to=1.0, command=contrastDif, orient="horizontal", label="contrastdif", digits=3, resolution=0.01)
        self.contrast_slider.set(1.0)
        self.contrast_slider.pack(side=LEFT)

        self.brightness_slider = Scale(self.master, from_=0.0, to=1.0, command=brightnessDif, orient="horizontal", label="brightnessdif", digits=3, resolution=0.01)
        self.brightness_slider.set(1.0)
        self.brightness_slider.pack()

        self.sharpness_slider = Scale(self.master, from_=0.0, to=2.0, command=sharpnessDif, orient="horizontal", label="sharpnessdif", digits=3, resolution=0.01)
        self.sharpness_slider.set(1.0)
        self.sharpness_slider.pack()

        # Colour offset sliders.
        self.red_slider = Scale(self.master, from_=-255, to=255, command=adjRed, orient="horizontal", label="red")
        self.red_slider.set(0)
        self.red_slider.pack()

        self.green_slider = Scale(self.master, from_=-255, to=255, command=adjGreen, orient="horizontal", label="green")
        self.green_slider.set(0)
        self.green_slider.pack()

        self.blue_slider = Scale(self.master, from_=-255, to=255, command=adjBlue, orient="horizontal", label="blue")
        self.blue_slider.set(0)
        self.blue_slider.pack()

        def reset(*args):
            # Restore every slider to its neutral value and refresh the preview.
            self.colordif_slider.set(1.0)
            self.contrast_slider.set(1.0)
            self.brightness_slider.set(1.0)
            self.sharpness_slider.set(1.0)
            self.red_slider.set(0)
            self.green_slider.set(0)
            self.blue_slider.set(0)
            put_to_screen()

        self.reset_button = Button(self.master, text="Reset sliders values", command=reset)
        self.reset_button.pack()


# Launch the GUI only when run as a script, so importing this module does not
# open a window and block in mainloop().
if __name__ == "__main__":
    root = Tk()
    my_gui = GUI(root)
    root.mainloop()