Commit c8b1298401f3a2388c16aa30a0820d4f39f5522d

Authored by sberisha
1 parent eccf10ff

renamed spectral.py to stim_spectral.py; modified envi_batch_predict to return R…

…GB array in case user wants to save the fig
Showing 2 changed files with 5 additions and 56 deletions   Show diff stats
python/classify.py
@@ -12,7 +12,7 @@ import sklearn.metrics @@ -12,7 +12,7 @@ import sklearn.metrics
12 import scipy 12 import scipy
13 import scipy.misc 13 import scipy.misc
14 import envi 14 import envi
15 -import spectral 15 +import stim_spectral
16 import random 16 import random
17 import progressbar 17 import progressbar
18 import matplotlib.pyplot as plt 18 import matplotlib.pyplot as plt
@@ -176,11 +176,13 @@ def envi_batch_predict(E, C, batch=10000): @@ -176,11 +176,13 @@ def envi_batch_predict(E, C, batch=10000):
176 else: 176 else:
177 Tv = numpy.concatenate((Tv, C.predict(Fv.transpose()).transpose()), 0) 177 Tv = numpy.concatenate((Tv, C.predict(Fv.transpose()).transpose()), 0)
178 tempmask = E.batchmask() 178 tempmask = E.batchmask()
179 - Lv = spectral.unsift2(Tv, tempmask) 179 + Lv = stim_spectral.unsift2(Tv, tempmask)
180 Cv = label2class(Lv.squeeze(), background=0) 180 Cv = label2class(Lv.squeeze(), background=0)
181 RGB = class2color(Cv) 181 RGB = class2color(Cv)
182 plt.imshow(RGB) 182 plt.imshow(RGB)
183 plt.pause(0.05) 183 plt.pause(0.05)
184 Fv = E.loadbatch(batch) 184 Fv = E.loadbatch(batch)
185 i = i + 1 185 i = i + 1
186 - bar.update(len(Tv))  
187 \ No newline at end of file 186 \ No newline at end of file
  187 + bar.update(len(Tv))
  188 +
  189 + return RGB
188 \ No newline at end of file 190 \ No newline at end of file
python/spectral.py deleted
1 -# -*- coding: utf-8 -*-  
2 -"""  
3 -Created on Sun Jul 23 13:52:22 2017  
4 -  
5 -@author: david  
6 -"""  
7 -import numpy  
8 -  
9 -#sift a 2D hyperspectral image into a PxB matrix where P is the number of pixels and B is the number of bands  
10 -def sift2(I, mask = []):  
11 -  
12 - #get the shape of the input array  
13 - S = I.shape  
14 -  
15 - #convert that array into a 1D matrix  
16 - M = numpy.reshape(I, (S[0], S[1] * S[2]))  
17 -  
18 - #gif no mask is provided, just return all pixels  
19 - if mask == []:  
20 - return M  
21 -  
22 - #if a mask is provided, only return pixels corresponding to that mask  
23 - flatmask = numpy.reshape(mask, (S[1] * S[2]))  
24 - i = numpy.flatnonzero(flatmask) #get the nonzero indices  
25 - return M[:, i] #return pixels corresponding to the masked values  
26 -  
27 -def unsift2(M, mask):  
28 -  
29 - #get the size of the input matrix  
30 - S = M.shape  
31 -  
32 - #count the number of nonzero values in the mask  
33 - nnz = numpy.count_nonzero(mask)  
34 -  
35 - #the number of masked values should be the same as the number of pixels in the input matrix  
36 - if len(S) == 1:  
37 - if not S[0] == nnz:  
38 - print("ERROR: expected " + str(nnz) + " pixels based on the mask but there are " + str(S[0]) + " in the matrix.")  
39 - elif not S[1] == nnz:  
40 - print("ERROR: expected " + str(nnz) + " pixels based on the mask but there are " + str(S[1]) + " in the matrix.")  
41 -  
42 -  
43 - i = numpy.nonzero(mask)  
44 -  
45 - if len(S) == 1:  
46 - I = numpy.zeros((1, mask.shape[0], mask.shape[1]), dtype=M.dtype)  
47 - else:  
48 - I = numpy.zeros((M.shape[0], mask.shape[0], mask.shape[1]), dtype=M.dtype)  
49 - I[:, i[0], i[1]] = M  
50 - return I  
51 -  
52 -#create a function that sifts a color image  
53 -#input: image name, mask  
54 \ No newline at end of file 0 \ No newline at end of file