Commit c8b1298401f3a2388c16aa30a0820d4f39f5522d
1 parent
eccf10ff
renamed spectral.py to stim_spectral.py; modified envi_batch_predict to return R…
…GB array in case user wants to save the fig
Showing
2 changed files
with
5 additions
and
56 deletions
Show diff stats
python/classify.py
@@ -12,7 +12,7 @@ import sklearn.metrics | @@ -12,7 +12,7 @@ import sklearn.metrics | ||
12 | import scipy | 12 | import scipy |
13 | import scipy.misc | 13 | import scipy.misc |
14 | import envi | 14 | import envi |
15 | -import spectral | 15 | +import stim_spectral |
16 | import random | 16 | import random |
17 | import progressbar | 17 | import progressbar |
18 | import matplotlib.pyplot as plt | 18 | import matplotlib.pyplot as plt |
@@ -176,11 +176,13 @@ def envi_batch_predict(E, C, batch=10000): | @@ -176,11 +176,13 @@ def envi_batch_predict(E, C, batch=10000): | ||
176 | else: | 176 | else: |
177 | Tv = numpy.concatenate((Tv, C.predict(Fv.transpose()).transpose()), 0) | 177 | Tv = numpy.concatenate((Tv, C.predict(Fv.transpose()).transpose()), 0) |
178 | tempmask = E.batchmask() | 178 | tempmask = E.batchmask() |
179 | - Lv = spectral.unsift2(Tv, tempmask) | 179 | + Lv = stim_spectral.unsift2(Tv, tempmask) |
180 | Cv = label2class(Lv.squeeze(), background=0) | 180 | Cv = label2class(Lv.squeeze(), background=0) |
181 | RGB = class2color(Cv) | 181 | RGB = class2color(Cv) |
182 | plt.imshow(RGB) | 182 | plt.imshow(RGB) |
183 | plt.pause(0.05) | 183 | plt.pause(0.05) |
184 | Fv = E.loadbatch(batch) | 184 | Fv = E.loadbatch(batch) |
185 | i = i + 1 | 185 | i = i + 1 |
186 | - bar.update(len(Tv)) | ||
187 | \ No newline at end of file | 186 | \ No newline at end of file |
187 | + bar.update(len(Tv)) | ||
188 | + | ||
189 | + return RGB | ||
188 | \ No newline at end of file | 190 | \ No newline at end of file |
python/spectral.py deleted
1 | -# -*- coding: utf-8 -*- | ||
2 | -""" | ||
3 | -Created on Sun Jul 23 13:52:22 2017 | ||
4 | - | ||
5 | -@author: david | ||
6 | -""" | ||
7 | -import numpy | ||
8 | - | ||
9 | -#sift a 2D hyperspectral image into a PxB matrix where P is the number of pixels and B is the number of bands | ||
10 | -def sift2(I, mask = []): | ||
11 | - | ||
12 | - #get the shape of the input array | ||
13 | - S = I.shape | ||
14 | - | ||
15 | - #convert that array into a 1D matrix | ||
16 | - M = numpy.reshape(I, (S[0], S[1] * S[2])) | ||
17 | - | ||
18 | - #gif no mask is provided, just return all pixels | ||
19 | - if mask == []: | ||
20 | - return M | ||
21 | - | ||
22 | - #if a mask is provided, only return pixels corresponding to that mask | ||
23 | - flatmask = numpy.reshape(mask, (S[1] * S[2])) | ||
24 | - i = numpy.flatnonzero(flatmask) #get the nonzero indices | ||
25 | - return M[:, i] #return pixels corresponding to the masked values | ||
26 | - | ||
27 | -def unsift2(M, mask): | ||
28 | - | ||
29 | - #get the size of the input matrix | ||
30 | - S = M.shape | ||
31 | - | ||
32 | - #count the number of nonzero values in the mask | ||
33 | - nnz = numpy.count_nonzero(mask) | ||
34 | - | ||
35 | - #the number of masked values should be the same as the number of pixels in the input matrix | ||
36 | - if len(S) == 1: | ||
37 | - if not S[0] == nnz: | ||
38 | - print("ERROR: expected " + str(nnz) + " pixels based on the mask but there are " + str(S[0]) + " in the matrix.") | ||
39 | - elif not S[1] == nnz: | ||
40 | - print("ERROR: expected " + str(nnz) + " pixels based on the mask but there are " + str(S[1]) + " in the matrix.") | ||
41 | - | ||
42 | - | ||
43 | - i = numpy.nonzero(mask) | ||
44 | - | ||
45 | - if len(S) == 1: | ||
46 | - I = numpy.zeros((1, mask.shape[0], mask.shape[1]), dtype=M.dtype) | ||
47 | - else: | ||
48 | - I = numpy.zeros((M.shape[0], mask.shape[0], mask.shape[1]), dtype=M.dtype) | ||
49 | - I[:, i[0], i[1]] = M | ||
50 | - return I | ||
51 | - | ||
52 | -#create a function that sifts a color image | ||
53 | -#input: image name, mask | ||
54 | \ No newline at end of file | 0 | \ No newline at end of file |