Blame view

python/digitalstain.py 2.33 KB
ac5ae422   David Mayerich   added digital sta...
1
2
3
4
5
6
7
  # -*- coding: utf-8 -*-
  """
  Created on Tue Jul 25 16:28:37 2017
  
  @author: david
  """
  
9b3cbdda   David Mayerich   changed the name ...
8
  import hyperspectral
ac5ae422   David Mayerich   added digital sta...
9
10
11
12
13
14
15
16
17
18
19
20
  import envi
  import classify
  import numpy
  import scipy
  import scipy.misc
  import sklearn
  import sklearn.naive_bayes
  import sklearn.neural_network
  import glob
  import matplotlib.pyplot as plt
  import random
  
9b3cbdda   David Mayerich   changed the name ...
21
22
23
24
25
26
27
  def generate_stain(envifile, stainfile, maskfile="", trainmask="", N=5000, batch_size=10000, validate=True):
      """Train a regressor that maps per-pixel spectra to stain colors.

      Pixels are drawn with ``classify.random_mask`` from the ENVI file's
      mask. The spectra at those pixels are paired with the values of the
      same pixels in ``stainfile``, and a scikit-learn ``MLPRegressor`` is fit
      to the pairs. When ``validate`` is true, the regressor is then applied
      to the ENVI image batch by batch, and the predicted stain is displayed
      progressively with matplotlib.

      Side effects:
          * writes the random training mask to ``random.bmp`` in the current
            working directory;
          * prints progress messages;
          * when validating, turns on matplotlib interactive mode and draws
            into the current figure.

      Parameters
      ----------
      envifile : str
          Path of the ENVI hyperspectral image, opened with ``envi.envi``.
      stainfile : str
          Path of the stained reference image. Its axis 2 is moved to the
          front, so this assumes a (rows, cols, channels) image -- TODO
          confirm for non-RGB inputs.
      maskfile : str, optional
          Mask image restricting the pixels predicted during validation. When
          empty (or None), validation reuses the ENVI file opened for
          training, including any ``trainmask``.
      trainmask : str, optional
          Mask image passed as ``mask`` when opening the ENVI file for
          training. When empty (or None), the file is opened without a mask.
      N : int
          Passed to ``classify.random_mask``. Presumably the number of
          training pixels sampled -- TODO confirm.
      batch_size : int
          Passed to ``E.loadbatch`` for each validation batch.
      validate : bool
          If false, skip validation and return only the fitted regressor.

      Returns
      -------
      MLPRegressor, or (MLPRegressor, numpy.ndarray or None)
          The fitted regressor alone when ``validate`` is false. Otherwise, a
          tuple of the regressor and the last displayed image (``numpy.ubyte``,
          channels last, values clipped to [0, 255]). The image is ``None``
          if the first validation batch was already empty.
      """
      # NOTE(review): scipy.misc.imread / scipy.misc.imsave were deprecated in
      # SciPy 1.0 and removed in later releases; this function needs an old
      # SciPy (with Pillow) until they are replaced.
      if not trainmask:
          E = envi.envi(envifile)
      else:
          mask = scipy.misc.imread(trainmask, flatten=True)
          E = envi.envi(envifile, mask=mask)

      # Sample the training pixels from the file's mask; save them for inspection.
      mask = classify.random_mask(E.mask, N)
      scipy.misc.imsave("random.bmp", mask)

      # Transposed so that rows are samples, as sklearn's fit() expects.
      # This assumes loadmask returns (bands, pixels) -- TODO confirm.
      Ft = E.loadmask(mask).transpose()

      stain = numpy.rollaxis(scipy.misc.imread(stainfile), 2)    # move axis 2 (channels) to the front
      Tt = hyperspectral.sift2(stain, mask).transpose()          # target values at the same pixels

      print("Training MLPRegressor...")
      # hidden_layer_sizes=() means no hidden layers: a linear map from spectra to colors.
      CLASS = sklearn.neural_network.MLPRegressor(solver='lbfgs', alpha=1e-5, hidden_layer_sizes=(), random_state=1, verbose=True)
      CLASS.fit(Ft, Tt)

      if not validate:
          E.close()                                              # release the ENVI file before returning
          return CLASS

      print("Validating Stain...")
      plt.ion()
      if maskfile:
          E.close()                                              # close the training ENVI file
          mask = scipy.misc.imread(maskfile, flatten=True)
          print(numpy.count_nonzero(mask))
          E = envi.envi(envifile, mask=mask)

      Tv = None                                                  # predictions accumulated across batches
      RGB = None                                                 # stays None if there is no batch at all
      Fv = E.loadbatch(batch_size)                               # load the first batch
      # An empty batch marks the end. numpy.size() works for both an empty
      # list and an empty ndarray, unlike comparing an ndarray to [].
      while Fv is not None and numpy.size(Fv) > 0:
          prediction = CLASS.predict(Fv.transpose()).transpose()
          # Same as numpy.append(Tv, prediction, 1): append this batch's columns.
          Tv = prediction if Tv is None else numpy.concatenate((Tv, prediction), axis=1)
          COLORS = hyperspectral.unsift2(Tv, E.batchmask())      # convert the prediction matrix back to an image
          # Regression output is unbounded. Clip before the uint8 cast so
          # out-of-range values saturate instead of wrapping around.
          RGB = numpy.rollaxis(numpy.clip(COLORS, 0, 255), 0, 3).astype(numpy.ubyte)
          plt.imshow(RGB)                                        # display it
          plt.pause(0.05)
          Fv = E.loadbatch(batch_size)                           # load the next batch
      E.close()
      return CLASS, RGB