#Daniel Stoffregen djstoff15@earlham.edu
#a script to create a GUI that applies image analysis filters to images
#TODO: credit the source of the NDVI code
import matplotlib
matplotlib.use('TkAgg')
from Tkinter import * 
import Tkinter, Tkconstants, tkFileDialog
import  PIL
from PIL import Image, ImageTk

import getopt
import sys
import cv2

import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib import ticker
from matplotlib.colors import LinearSegmentedColormap

class NDVI(object):
    """Render an NDVI-style false-colour map of an image and save it.

    Channel 0 of the image is used as NIR (presumably the red channel of an
    NIR-converted camera -- confirm), channels 1 and 2 as green and blue.
    The index is drawn with a custom colormap plus a colorbar labelled
    "NDVI" and written to ``output_name``.
    """

    # File name used when no output name is supplied.
    DEFAULT_OUTPUT = 'NDVI.jpg'

    def __init__(self, file_path, output_file=False, colors=False):
        """Load the image to convert.

        file_path   -- path of the input image, read with plt.imread.
        output_file -- name of the image to write; DEFAULT_OUTPUT when falsy.
        colors      -- colours for the colormap from low to high values;
                       defaults to red, orange, yellow, green, blue.
        """
        self.image = plt.imread(file_path)
        self.output_name = output_file or self.DEFAULT_OUTPUT
        self.colors = colors or ['red', 'orange', 'yellow', 'green', 'blue']

    def create_colormap(self, *args):
        """Return a colormap interpolating linearly through the colours in ``args``."""
        return LinearSegmentedColormap.from_list(name='custom1', colors=args)

    def create_colorbar(self, fig, image):
        """Add a horizontal colorbar labelled "NDVI" for ``image`` to ``fig``."""
        position = fig.add_axes([0.825, 0.19, 0.2, 0.05])
        # NOTE(review): matplotlib's colorbar takes its norm from ``image``,
        # so this -1..1 norm probably has no effect on the bar -- confirm.
        norm = colors.Normalize(vmin=-1.0, vmax=1.0)
        cbar = plt.colorbar(image,
                            cax=position,
                            orientation='horizontal',
                            norm=norm)
        cbar.ax.tick_params(labelsize=20)
        # Limit the colorbar to at most three tick intervals.
        cbar.locator = ticker.MaxNLocator(nbins=3)
        cbar.update_ticks()
        cbar.set_label("NDVI", fontsize=20, x=0.5, y=0.5, labelpad=50)

    @staticmethod
    def _ndvi(image):
        """Return the per-pixel index drawn by convert() as a float array.

        ``image`` is indexed [row, col, channel]; channel 0 is used as NIR,
        channel 1 as green and channel 2 as blue. The input is not modified.
        """
        nir = image[:, :, 0].astype('float')
        green = image[:, :, 1].astype('float')
        blue = image[:, :, 2].astype('float')
        bottom = (blue - green) ** 2
        # Where blue == green the squared difference is 0; use 1 instead so
        # the visible term below does not divide by zero.
        bottom[bottom == 0] = 1
        vis = (blue + green) ** 2 / bottom
        # Pixels with nir + vis == 0 (e.g. pure black) evaluate to NaN.
        return (nir - vis) / (nir + vis)

    def convert(self):
        """Compute the index for self.image and save the plot to self.output_name.

        Returns None. The figure is closed after saving (even on error) so
        repeated conversions do not accumulate open pyplot figures.
        """
        ndvi = self._ndvi(self.image)

        fig, ax = plt.subplots(figsize=(25, 25))
        try:
            image = ax.imshow(ndvi, cmap=self.create_colormap(*self.colors))
            ax.axis('off')

            self.create_colorbar(fig, image)

            # Save only the region covered by the main axes (in inches).
            extent = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
            fig.savefig(self.output_name, dpi=600, transparent=True, bbox_inches=extent, pad_inches=0)
        finally:
            plt.close(fig)

#VARI (u.g-u.r)/(u.r+u.g-u.b+0.001)
class VARI(object):
    """Render a VARI map of an image and save it with a colorbar.

    VARI is computed per pixel as (green - red) / (red + green - blue + 0.001)
    using channels 0, 1 and 2 as red, green and blue.
    """

    # File name used when no output name is supplied.
    DEFAULT_OUTPUT = 'VARI.jpg'

    def __init__(self, file_path, output_file=False, colors=False):
        """Load the image to convert.

        file_path   -- path of the input image, read with plt.imread.
        output_file -- name of the image to write; DEFAULT_OUTPUT when falsy.
        colors      -- colours for the colormap from low to high values;
                       defaults to red, orange, yellow, green, blue.
        """
        self.image = plt.imread(file_path)
        self.output_name = output_file or self.DEFAULT_OUTPUT
        self.colors = colors or ['red', 'orange', 'yellow', 'green', 'blue']

    def create_colormap_VARI(self, *args):
        """Return a colormap interpolating linearly through the colours in ``args``."""
        return LinearSegmentedColormap.from_list(name='custom1', colors=args)

    def create_colorbar_VARI(self, fig, image):
        """Add a horizontal colorbar labelled "VARI" for ``image`` to ``fig``."""
        position = fig.add_axes([0.825, 0.19, 0.2, 0.05])
        # NOTE(review): matplotlib's colorbar takes its norm from ``image``,
        # so this -1..1 norm probably has no effect on the bar -- confirm.
        norm = colors.Normalize(vmin=-1.0, vmax=1.0)
        cbar = plt.colorbar(image,
                            cax=position,
                            orientation='horizontal',
                            norm=norm)
        cbar.ax.tick_params(labelsize=20)
        # Limit the colorbar to at most three tick intervals.
        cbar.locator = ticker.MaxNLocator(nbins=3)
        cbar.update_ticks()
        cbar.set_label("VARI", fontsize=20, x=0.5, y=0.5, labelpad=50)

    @staticmethod
    def _vari(image):
        """Return the per-pixel VARI of ``image`` (indexed [row, col, channel]) as floats."""
        red = image[:, :, 0].astype('float')
        green = image[:, :, 1].astype('float')
        blue = image[:, :, 2].astype('float')
        # The 0.001 offset prevents division by zero for integer-valued
        # pixels where red + green == blue.
        return (green - red) / (red + green - blue + .001)

    def convert_VARI(self):
        """Compute VARI for self.image and save the plot to self.output_name.

        Returns None. The figure is closed after saving (even on error) so
        repeated conversions do not accumulate open pyplot figures.
        """
        vari = self._vari(self.image)

        fig, ax = plt.subplots(figsize=(25, 25))
        try:
            image = ax.imshow(vari, cmap=self.create_colormap_VARI(*self.colors))
            ax.axis('off')

            self.create_colorbar_VARI(fig, image)

            # Save only the region covered by the main axes (in inches).
            extent = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
            fig.savefig(self.output_name, dpi=600, transparent=True, bbox_inches=extent, pad_inches=0)
        finally:
            plt.close(fig)

#TGI (u.g-(0.39*u.r)-(0.61*u.b))
class TGI(object):
    """Render a TGI map of an image and save it with a colorbar.

    TGI is computed per pixel as green - 0.39 * red - 0.61 * blue using
    channels 0, 1 and 2 as red, green and blue.
    """

    # File name used when no output name is supplied (was wrongly
    # 'VARI.jpg', which silently overwrote VARI output).
    DEFAULT_OUTPUT = 'TGI.jpg'

    def __init__(self, file_path, output_file=False, colors=False):
        """Load the image to convert.

        file_path   -- path of the input image, read with plt.imread.
        output_file -- name of the image to write; DEFAULT_OUTPUT when falsy.
        colors      -- colours for the colormap from low to high values;
                       defaults to red, orange, yellow, green, blue.
        """
        self.image = plt.imread(file_path)
        self.output_name = output_file or self.DEFAULT_OUTPUT
        self.colors = colors or ['red', 'orange', 'yellow', 'green', 'blue']

    def create_colormap_TGI(self, *args):
        """Return a colormap interpolating linearly through the colours in ``args``."""
        return LinearSegmentedColormap.from_list(name='custom1', colors=args)

    def create_colorbar_TGI(self, fig, image):
        """Add a horizontal colorbar labelled "TGI" for ``image`` to ``fig``."""
        position = fig.add_axes([0.825, 0.19, 0.2, 0.05])
        # NOTE(review): matplotlib's colorbar takes its norm from ``image``,
        # so this -1..1 norm probably has no effect on the bar -- confirm.
        norm = colors.Normalize(vmin=-1.0, vmax=1.0)
        cbar = plt.colorbar(image,
                            cax=position,
                            orientation='horizontal',
                            norm=norm)
        cbar.ax.tick_params(labelsize=20)
        # Limit the colorbar to at most three tick intervals.
        cbar.locator = ticker.MaxNLocator(nbins=3)
        cbar.update_ticks()
        cbar.set_label("TGI", fontsize=20, x=0.5, y=0.5, labelpad=50)

    @staticmethod
    def _tgi(image):
        """Return the per-pixel TGI of ``image`` (indexed [row, col, channel]) as floats."""
        red = image[:, :, 0].astype('float')
        green = image[:, :, 1].astype('float')
        blue = image[:, :, 2].astype('float')
        return (green - (.39 * red)) - (.61 * blue)

    def convert_TGI(self):
        """Compute TGI for self.image and save the plot to self.output_name.

        Returns None. The figure is closed after saving (even on error) so
        repeated conversions do not accumulate open pyplot figures.
        """
        tgi = self._tgi(self.image)

        fig, ax = plt.subplots(figsize=(25, 25))
        try:
            image = ax.imshow(tgi, cmap=self.create_colormap_TGI(*self.colors))
            ax.axis('off')

            self.create_colorbar_TGI(fig, image)

            # Save only the region covered by the main axes (in inches).
            extent = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
            fig.savefig(self.output_name, dpi=600, transparent=True, bbox_inches=extent, pad_inches=0)
        finally:
            plt.close(fig)

class GUI:
    """Tk window for applying the NDVI, VARI or TGI filter to an image.

    Widgets, top to bottom: an instruction label, buttons to pick a JPG or
    DNG input image, one radio button per algorithm, an entry for the
    output file name, and "test", "convert" and "Close" buttons. A 250x250
    thumbnail of the input is shown on the left and of the generated image
    on the right.

    Attributes:
        select_algo -- IntVar holding the chosen algorithm as an index into
                       ALGORITHMS (0 = NDVI, 1 = VARI, 2 = TGI); defaults to 0.
        red_adj, blue_adj, green_adj -- IntVars meant to receive the values
                       of colour-adjustment sliders (not created yet).
        input_file  -- path of the selected input image, None until chosen.
    """

    # Radio-button labels; a label's position in this tuple is the value
    # stored in select_algo and the value con_button dispatches on.
    ALGORITHMS = ("NDVI", "VARI", "TGI")

    def __init__(self, master):
        """Create and pack all widgets inside ``master`` (the Tk root window)."""
        self.master = master
        # Tk variables wrap an int: use .get()/.set() to read/write the value.
        self.select_algo = IntVar()
        self.red_adj = IntVar()
        self.blue_adj = IntVar()
        self.green_adj = IntVar()
        # Stored separately from the inIm method (assigning self.inIm used
        # to shadow the method with a string).
        self.input_file = None
        # Preview labels are created on first use and then reused so a new
        # image replaces the old thumbnail instead of piling up.
        self._input_preview = None
        self._output_preview = None
        master.title("NDVI Converter")

        self.label = Label(master, text="Select an Image, the algorithm to apply to the image, and a name you want to create a new image with")
        self.label.pack()

        self.inIm_button = Button(master, text="JPG Input Image", command=self.inIm)
        self.inIm_button.pack()

        self.inIMDNG_button = Button(master, text="DNG Input Image", command=self.inImDNG)
        self.inIMDNG_button.pack()

        # One radio button per entry in ALGORITHMS.
        self.algo()

        # TODO: add red/blue/green Scale widgets stored as red_slide,
        # blue_slide and green_slide with commands adj_red, adj_blue and
        # adj_green; those handlers do nothing until the sliders exist.

        self.outname = Label(master, text="New File Name. Include file extension, acceptable extensions include .jpg and .tif")
        self.outname.pack()
        self.newname = Entry(master)
        self.newname.pack()

        self.test_button = Button(master, text="test", command=self.adj_red)
        self.test_button.pack()

        self.convert_button = Button(master, text="convert", command=self.con_button)
        self.convert_button.pack()

        self.close_button = Button(master, text="Close", command=master.quit)
        self.close_button.pack()

    def _show_preview(self, path, label, side):
        """Show a 250x250 thumbnail of ``path`` in ``label`` and return the label.

        When ``label`` is None a new Label is created and packed on ``side``;
        otherwise the existing label's image is replaced.
        """
        im = PIL.Image.open(path)
        im = im.resize((250, 250), resample=Image.NEAREST)
        photo = PIL.ImageTk.PhotoImage(im)
        if label is None:
            label = Label(self.master, image=photo)
            label.pack(side=side)
        else:
            label.configure(image=photo)
        label.image = photo  # keep a reference or Tk drops the image
        return label

    def _choose_input(self, filetypes):
        """Ask for an input image of ``filetypes``, remember it and preview it.

        A cancelled dialog (empty result) leaves the previous selection alone.
        """
        path = tkFileDialog.askopenfilename(initialdir=".", title="Select file", filetypes=filetypes)
        if not path:
            return
        self.input_file = path
        self.master.filename = path
        self._input_preview = self._show_preview(path, self._input_preview, LEFT)
        print(path)

    def inIm(self):
        """Select a JPG input image and show its thumbnail on the left."""
        self._choose_input((("jpg files", "*.jpg"), ("all files", "*.*")))

    def inImDNG(self):
        """Select a DNG input image and show its thumbnail on the left."""
        self._choose_input((("DNG files", "*.DNG"), ("all files", "*.*")))

    def algo(self):
        """Create one radio button per algorithm, bound to select_algo."""
        for value, name in enumerate(self.ALGORITHMS):
            Tkinter.Radiobutton(self.master, text=name, padx=20, value=value,
                                variable=self.select_algo).pack()

    def _read_slider(self, slider_name, var):
        """Copy the value of slider attribute ``slider_name`` into ``var`` if it exists."""
        slider = getattr(self, slider_name, None)
        if slider is not None:
            var.set(slider.get())

    def adj_red(self, *args):
        """Store the red slider value in red_adj (no-op while no slider exists)."""
        self._read_slider('red_slide', self.red_adj)

    def adj_blue(self, *args):
        """Store the blue slider value in blue_adj (no-op while no slider exists)."""
        self._read_slider('blue_slide', self.blue_adj)

    def adj_green(self, *args):
        """Store the green slider value in green_adj (no-op while no slider exists)."""
        self._read_slider('green_slide', self.green_adj)

    def con_button(self):
        """Run the selected algorithm on the input image and preview the result.

        Does nothing (besides printing a message) if no input image has been
        chosen yet. When the name entry is blank the converter's default
        output name is used, and that is the file previewed.
        """
        if self.input_file is None:
            print("No input image selected")
            return
        output_file = str(self.newname.get())
        choice = self.select_algo.get()
        if choice == 0:
            converter = NDVI(self.input_file, output_file)
            converter.convert()
        elif choice == 1:
            converter = VARI(self.input_file, output_file)
            converter.convert_VARI()
        elif choice == 2:
            converter = TGI(self.input_file, output_file)
            converter.convert_TGI()
        else:
            return

        # Preview the file actually written (may be the default name).
        self._output_preview = self._show_preview(converter.output_name, self._output_preview, RIGHT)


# Launch the GUI only when run as a script, so the NDVI/VARI/TGI classes
# can be imported without opening a Tk window.
if __name__ == "__main__":
    root = Tk()
    my_gui = GUI(root)
    root.mainloop()