# -*- coding: utf-8 -*-
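"""
Align the slices of an image stack by translation (OpenCV ECC), let the user
review and nudge the alignment interactively, then write the aligned stack.

Created on Fri May  4 17:15:07 2018

@author: david
"""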

import imagestack
import numpy
import matplotlib.pyplot as plt
import cv2
import progressbar
import scipy.ndimage
import sys
import os

def press(event):
    """Keyboard handler for the interactive alignment viewer.

    Keys:
      'z' / 'x'            -- step to the previous / next slice (clamped to
                              the valid range 0 .. S.shape[0] - 1).
      'up' / 'down'        -- shift the y-translation (row 1 of the warp) of
                              the current slice and every later slice by +1 / -1.
      'left' / 'right'     -- shift the x-translation (row 0 of the warp) of
                              the current slice and every later slice by +1 / -1.

    Shifting all later slices too keeps their alignment relative to the
    current slice.  After every key press the current slice is re-warped,
    drawn into ``ax`` and the title is updated.

    Uses the module-level globals ``S`` (image stack), ``warps`` (one 2x3
    affine matrix per slice), ``i`` (current slice index), ``fig`` and ``ax``.
    """
    global i
    if event.key == 'z':
        i = max(i - 1, 0)
    elif event.key == 'x':
        # S.shape[0] itself is out of range; the last valid slice is shape - 1
        i = min(i + 1, S.shape[0] - 1)
    elif event.key in ("up", "down", "left", "right"):
        row = 1 if event.key in ("up", "down") else 0   # 1: y offset, 0: x offset
        step = 1 if event.key in ("up", "left") else -1
        for n in range(i, S.shape[0]):
            warps[n][row, 2] = warps[n][row, 2] + step
    aligned = cv2.warpAffine(S[i, :, :, :], warps[i], (S.shape[2], S.shape[1]), flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP)
    # show the re-warped slice, reusing the existing image artist when present
    if ax.images:
        ax.images[0].set_data(aligned)
    else:
        ax.imshow(aligned)
    ax.set_title('Slice ' + str(i))
    fig.canvas.draw_idle()

def apply_alignment(I, t):
    """Apply the per-slice warp matrices in t to the image stack I.

    I -- image stack with slices along axis 0, e.g. (slices, rows, cols,
         channels) or (slices, rows, cols).
    t -- sequence of 2x3 affine matrices, one per slice; each is applied
         with cv2.WARP_INVERSE_MAP and bilinear interpolation.

    Returns a new stack with the same shape and dtype as I.
    """
    # keep the input dtype: a hard-coded uint8 would truncate float stacks
    A = numpy.zeros_like(I)
    dsize = (I.shape[2], I.shape[1])  # OpenCV expects (width, height)
    for n in range(I.shape[0]):
        A[n] = cv2.warpAffine(I[n], t[n], dsize, flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP)
    return A

def align(A, B, max_power=5):
    """Estimate the translation that aligns image B to image A.

    Both images are blurred with a Gaussian of sigma 2**p, then OpenCV's ECC
    optimiser (cv2.findTransformECC, translation-only motion model) is run.
    p starts at 0; whenever ECC fails with cv2.error, the fit is retried with
    the next blur level, up to and including p = max_power.

    A, B -- 2-D images of the same size and dtype (ECC accepts uint8 or
            float32).
    max_power -- largest blur exponent to try.

    Returns a 2x3 float32 warp matrix (only the translation column is
    estimated).  If no blur level succeeds, the identity (no shift) is
    returned.
    """
    # Define the motion model
    warp_mode = cv2.MOTION_TRANSLATION
    # Stop after this many iterations, or when the correlation coefficient
    # changes by less than termination_eps between iterations
    number_of_iterations = 5000
    termination_eps = 1e-10
    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)
    # attempt to fit the image, increasing the blur as necessary
    for p in range(max_power + 1):
        sigma = 2 ** p
        im1_blur = scipy.ndimage.gaussian_filter(A, (sigma, sigma), mode='constant')
        im2_blur = scipy.ndimage.gaussian_filter(B, (sigma, sigma), mode='constant')
        # start each attempt from the identity transform
        warp_matrix = numpy.eye(2, 3, dtype=numpy.float32)
        try:
            _, warp_matrix = cv2.findTransformECC(im1_blur, im2_blur, warp_matrix, warp_mode, criteria)
        except cv2.error:
            # ECC failed at this blur level; retry with a stronger blur
            continue
        return warp_matrix
    # no blur level converged: fall back to "no shift"
    return numpy.eye(2, 3, dtype=numpy.float32)

fmask = "Z:/jack/TinkParaffinLung0.005S/*.png"
out_dir = "Z:/jack/TinkParaffinLung0.005S/aligned"

# create the output directory if it does not exist yet
os.makedirs(out_dir, exist_ok=True)

#read the image stack for alignment
print("Loading image stack...")
S = imagestack.load(fmask, dtype=numpy.uint8)
if S.shape[0] == 0:
    sys.exit("No images found matching " + fmask)

#convert to grayscale
G = imagestack.rgb2gray(S.astype(numpy.float32))

print("\n\nCalculating alignment transformations...")
bar = progressbar.ProgressBar(max_value=G.shape[0])
#for each image in the stack, calculate a transformation matrix to align it to the previous image
warps = [numpy.eye(2, 3, dtype=numpy.float32)]
for i in range(1, G.shape[0]):
    warps.append(align(G[i-1, :, :], G[i, :, :]))
    bar.update(i)
bar.finish()

print("\n\nApplying alignment transformation...")
# Accumulate the pairwise translations: warps[i] aligns slice i to slice i-1,
# so adding the (already accumulated) offset of slice i-1 registers every
# slice to the first one.
for i in range(1, G.shape[0]):
    warps[i][:, 2] += warps[i-1][:, 2]

# Interactive review: show slice 0 with its warp applied.  press() lets the
# user step through slices and nudge the translations in `warps`.
i = 0
aligned = cv2.warpAffine(S[i, :, :, :], warps[i], (S.shape[2], S.shape[1]), flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP)

fig, ax = plt.subplots()
ax.imshow(aligned)

fig.canvas.mpl_connect('key_press_event', press)

ax.set_title('Slice ' + str(i))
ax.set_xlabel("Press 'z' and 'x' to change slices and arrow keys to align")

# Wait for the window to close so manual adjustments are included below.
# NOTE(review): with a non-blocking/inline backend this returns immediately.
plt.show()

I = apply_alignment(S, warps)
# NOTE(review): assumes imagestack.save(stack, filename_prefix, extension)
# writes one file per slice -- confirm against the imagestack module.
imagestack.save(I, out_dir + "/aligned_", ".bmp")