# digitalstain.py
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 25 16:28:37 2017

@author: david
"""

import hyperspectral
import envi
import classify
import numpy
import scipy
import scipy.misc
import sklearn
import sklearn.naive_bayes
import sklearn.neural_network
import glob
import matplotlib.pyplot as plt
import random

def generate_stain(envifile, stainfile, maskfile="", trainmask="", N=5000, batch_size=10000, validate=True):
    """Train a regressor that predicts stain colors from the spectra in an ENVI file.

    A training mask is drawn with ``classify.random_mask`` from the mask of the
    opened ENVI file. The spectra loaded through that mask are the features and
    the matching pixels of the stain image are the targets used to fit an
    ``MLPRegressor``. When ``validate`` is set, the fitted model is applied to
    the file batch by batch and the accumulated prediction is displayed with
    matplotlib after every batch.

    Side effect: the random training mask is written to ``random.bmp`` in the
    current working directory.

    NOTE(review): ``scipy.misc.imread`` / ``imsave`` were removed in SciPy 1.2;
    this function needs an older SciPy (with Pillow) until they are replaced.

    Parameters
    ----------
    envifile
        Passed to ``envi.envi`` (presumably the path of the ENVI file).
    stainfile
        Image read with ``scipy.misc.imread``. Its third axis is moved to the
        front, so it must be a multi-channel image.
    maskfile
        Optional mask image. When non-empty, the ENVI file is reopened with
        this mask before validation; otherwise the training handle is reused.
    trainmask
        Optional mask image applied when the ENVI file is opened for training.
    N
        Forwarded to ``classify.random_mask`` (presumably the number of
        training pixels -- confirm against that helper).
    batch_size
        Forwarded to ``loadbatch`` for every validation batch.
    validate
        When false, return right after training.

    Returns
    -------
    The fitted regressor when ``validate`` is false; otherwise a tuple
    ``(regressor, RGB)`` where ``RGB`` is the last displayed ubyte image, or
    ``None`` if no batch could be loaded.
    """
    if trainmask == "":
        E = envi.envi(envifile)
    else:
        mask = scipy.misc.imread(trainmask, flatten=True)
        E = envi.envi(envifile, mask=mask)

    try:
        mask = classify.random_mask(E.mask, N)
        scipy.misc.imsave("random.bmp", mask)

        # Features: spectra of the randomly selected pixels, one row per pixel.
        Ft = E.loadmask(mask).transpose()

        # Targets: the stain image values at the same pixels (channel axis first).
        stain = numpy.rollaxis(scipy.misc.imread(stainfile), 2)
        Tt = hyperspectral.sift2(stain, mask).transpose()

        print("Training MLPRegressor...")
        # hidden_layer_sizes=() means no hidden layer between input and output.
        CLASS = sklearn.neural_network.MLPRegressor(solver='lbfgs', alpha=1e-5, hidden_layer_sizes=(), random_state=1, verbose=True)
        CLASS.fit(Ft, Tt)

        if not validate:
            return CLASS

        print("Validating Stain...")
        plt.ion()
        if maskfile != "":
            E.close()                                   # close the training handle
            E = None                                    # nothing left to release if reopening fails
            mask = scipy.misc.imread(maskfile, flatten=True)
            print(numpy.count_nonzero(mask))
            E = envi.envi(envifile, mask=mask)

        Tv = None                                       # predictions accumulated over all batches
        RGB = None                                      # stays None when no batch is available
        Fv = E.loadbatch(batch_size)                    # load the first batch
        # numpy.size() is 0 for both an empty list and an empty array, which
        # avoids the ambiguous array-versus-list comparison ``Fv == []``.
        while numpy.size(Fv) > 0:
            predicted = CLASS.predict(Fv.transpose()).transpose()
            if Tv is None:
                Tv = predicted
            else:
                Tv = numpy.append(Tv, predicted, 1)     # append this batch to the previous ones
            COLORS = hyperspectral.unsift2(Tv, E.batchmask())
            # Regression output is unbounded; clip before the ubyte cast so
            # out-of-range values saturate instead of wrapping around.
            RGB = numpy.clip(numpy.rollaxis(COLORS, 0, 3), 0, 255).astype(numpy.ubyte)
            plt.imshow(RGB)                             # display the prediction so far
            plt.pause(0.05)
            Fv = E.loadbatch(batch_size)                # load the next batch
        return CLASS, RGB
    finally:
        # Release the ENVI handle on every exit path, including errors.
        if E is not None:
            E.close()