Commit 59f31ea4d0e84e7381294b0e49b3db9b26604fc2

Authored by David Mayerich
1 parent c6ad7004

fixed bugs in envi.py and added an envi classification function to classify.py

Showing 2 changed files with 33 additions and 9 deletions   Show diff stats
python/classify.py
... ... @@ -12,7 +12,10 @@ import sklearn.metrics
12 12 import scipy
13 13 import scipy.misc
14 14 import envi
  15 +import spectral
15 16 import random
  17 +import progressbar
  18 +import matplotlib.pyplot as plt
16 19  
17 20 #generate a 2D color class map using a stack of binary class images
18 21 #input: C is a C x Y x X binary image
... ... @@ -40,7 +43,7 @@ def class2color(C):
40 43 #input: list of class image names
41 44 #output: C x Y x X binary image specifying class/pixel membership
42 45 #example: image2class(("class_coll.bmp", "class_epith.bmp"))
43   -def image2class(masks):
  46 +def filenames2class(masks):
44 47 #get num of mask file names
45 48 num_masks = len(masks)
46 49  
... ... @@ -50,9 +53,11 @@ def image2class(masks):
50 53 return
51 54  
52 55 classimages = []
53   - for m in masks:
54   - img = scipy.misc.imread(m, flatten=True).astype(numpy.bool)
  56 + bar = progressbar.ProgressBar(max_value=num_masks)
  57 + for m in range(0, num_masks):
  58 + img = scipy.misc.imread(masks[m], flatten=True).astype(numpy.bool)
55 59 classimages.append(img)
  60 + bar.update(m+1)
56 61  
57 62 result = numpy.stack(classimages)
58 63 sum_images = numpy.sum(result.astype(numpy.uint32), 0)
... ... @@ -141,7 +146,24 @@ def random_mask(M, n):
141 146 new_mask[numpy.unravel_index(new_idx[0:n], new_mask.shape)] = True
142 147 return new_mask
143 148  
144   -
145   -#Function to convert a set of class labels to a matrix of neuron responses for an ANN
146   -
147   -#Function CNN extraction function
148 149 \ No newline at end of file
  150 +def envi_batch_predict(E, C, batch=10000):
  151 +
  152 + Fv = E.loadbatch(batch).transpose()
  153 + i = 0
  154 + Tv = []
  155 + plt.ion()
  156 + bar = progressbar.ProgressBar(max_value=numpy.count_nonzero(E.mask))
  157 + while not Fv == []:
  158 + if i == 0:
  159 + Tv = C.predict(Fv)
  160 + else:
  161 + Tv = numpy.concatenate((Tv, C.predict(Fv).transpose()), 0)
  162 + tempmask = E.batchmask()
  163 + Lv = spectral.unsift2(Tv, tempmask)
  164 + Cv = label2class(Lv.squeeze(), background=0)
  165 + RGB = class2color(Cv)
  166 + plt.imshow(RGB)
  167 + plt.pause(0.05)
  168 + Fv = E.loadbatch(batch).transpose()
  169 + i = i + 1
  170 + bar.update(len(Tv))
149 171 \ No newline at end of file
... ...
python/envi.py
... ... @@ -10,6 +10,7 @@ import numpy
10 10 import scipy
11 11 import matplotlib.pyplot as plt
12 12 import progressbar
  13 +import sys
13 14  
14 15 class envi_header:
15 16 def __init__(self, filename = ""):
... ... @@ -245,7 +246,7 @@ class envi:
245 246 self.file.readinto(spectrum)
246 247 M[:, p] = spectrum
247 248 bar.update(p+1)
248   - if self.header.interleave == "bsq":
  249 + elif self.header.interleave == "bsq":
249 250 band = numpy.zeros(mask.shape, dtype=self.header.data_type)
250 251 i = numpy.nonzero(mask)
251 252 bar = progressbar.ProgressBar(max_value=B)
... ... @@ -254,7 +255,7 @@ class envi:
254 255 self.file.readinto(band)
255 256 M[b, :] = band[i]
256 257 bar.update(b+1)
257   - if self.header.interleave == "bil":
  258 + elif self.header.interleave == "bil":
258 259 plane = numpy.zeros((B, X), dtype=self.header.data_type)
259 260 p = 0
260 261 bar = progressbar.ProgressBar(max_value=Y)
... ... @@ -294,6 +295,7 @@ class envi:
294 295 F = []
295 296 T = []
296 297 for c in range(0, C):
  298 + print("\nLoading class " + str(c+1) + "...")
297 299 f = self.loadmask(classimages[c, :, :]) #load the feature matrix for class c
298 300 t = numpy.ones((f.shape[1])) * (c+1) #generate a target array
299 301 F.append(f)
... ...