Commit 35534291ba7760d96ef0239a0303a75f0481e2f8

Authored by sberisha
1 parent db331d8a

envi.py: added function to load balanced data for training

Showing 3 changed files with 104 additions and 25 deletions   Show diff stats
matlab/stimBrewerMap.m
1   -function result = stimBrewerColormap(R)
  1 +function result = stimBrewerMap(R)
2 2  
3 3 %returns a Brewer colormap with the specified resolution R
4 4  
5   -ctrlPts = zeros(11, 3);
6   -
7   -ctrlPts(1, :) = [0.192157, 0.211765, 0.584314];
8   -ctrlPts(2, :) = [0.270588, 0.458824, 0.705882];
9   -ctrlPts(3, :) = [0.454902, 0.678431, 0.819608];
10   -ctrlPts(4, :) = [0.670588, 0.85098, 0.913725];
11   -ctrlPts(5, :) = [0.878431, 0.952941, 0.972549];
12   -ctrlPts(6, :) = [1, 1, 0.74902];
13   -ctrlPts(7, :) = [0.996078, 0.878431, 0.564706];
14   -ctrlPts(8, :) = [0.992157, 0.682353, 0.380392];
15   -ctrlPts(9, :) = [0.956863, 0.427451, 0.262745];
16   -ctrlPts(10, :) = [0.843137, 0.188235, 0.152941];
17   -ctrlPts(11, :) = [0.647059, 0, 0.14902];
18   -
19   -X = 1:11;
20   -
21   -r = 1:11/R:11;
  5 +ctrlPts = zeros(12, 3);
  6 +ctrlPts(1, :) = [0, 0, 0];
  7 +ctrlPts(2, :) = [0.192157, 0.211765, 0.584314];
  8 +ctrlPts(3, :) = [0.270588, 0.458824, 0.705882];
  9 +ctrlPts(4, :) = [0.454902, 0.678431, 0.819608];
  10 +ctrlPts(5, :) = [0.670588, 0.85098, 0.913725];
  11 +ctrlPts(6, :) = [0.878431, 0.952941, 0.972549];
  12 +ctrlPts(7, :) = [1, 1, 0.74902];
  13 +ctrlPts(8, :) = [0.996078, 0.878431, 0.564706];
  14 +ctrlPts(9, :) = [0.992157, 0.682353, 0.380392];
  15 +ctrlPts(10, :) = [0.956863, 0.427451, 0.262745];
  16 +ctrlPts(11, :) = [0.843137, 0.188235, 0.152941];
  17 +ctrlPts(12, :) = [0.647059, 0, 0.14902];
  18 +
  19 +X = 1:12;
  20 +
  21 +r = 1:12/R:12;
22 22  
23 23 R = interp1(X, ctrlPts(:, 1), r);
24 24 G = interp1(X, ctrlPts(:, 2), r);
... ...
python/classify.py
... ... @@ -14,7 +14,7 @@ import scipy.misc
14 14 import envi
15 15 import hyperspectral
16 16 import random
17   -import progressbar
  17 +import pyprind
18 18 import matplotlib.pyplot as plt
19 19  
20 20 #generate N qualitative colors and return the value for color c
... ... @@ -66,11 +66,13 @@ def filenames2class(masks):
66 66 return
67 67  
68 68 classimages = []
69   - bar = progressbar.ProgressBar(max_value=num_masks)
  69 + #bar = progressbar.ProgressBar(max_value=num_masks)
  70 + bar = pyprind.ProgBar(num_masks)
70 71 for m in range(0, num_masks):
71 72 img = scipy.misc.imread(masks[m], flatten=True).astype(numpy.bool)
72 73 classimages.append(img)
73   - bar.update(m+1)
  74 + #bar.update(m+1)
  75 + bar.update()
74 76  
75 77 result = numpy.stack(classimages)
76 78 sum_images = numpy.sum(result.astype(numpy.uint32), 0)
... ... @@ -169,7 +171,8 @@ def envi_batch_predict(E, C, batch=10000):
169 171 i = 0
170 172 Tv = []
171 173 plt.ion()
172   - bar = progressbar.ProgressBar(max_value=numpy.count_nonzero(E.mask))
  174 + #bar = progressbar.ProgressBar(max_value=numpy.count_nonzero(E.mask))
  175 + bar = pyprind.ProgBar(numpy.count_nonzero(E.mask))
173 176 while not Fv == []:
174 177 Fv = numpy.nan_to_num(Fv) #remove infinite values
175 178 if i == 0:
... ... @@ -184,4 +187,5 @@ def envi_batch_predict(E, C, batch=10000):
184 187 plt.pause(0.05)
185 188 Fv = E.loadbatch(batch)
186 189 i = i + 1
187   - bar.update(len(Tv))
  190 + #bar.update(len(Tv))
  191 + bar.update()
188 192 \ No newline at end of file
... ...
python/envi.py
... ... @@ -9,8 +9,10 @@ import os
9 9 import numpy
10 10 import scipy
11 11 import matplotlib.pyplot as plt
12   -import progressbar
  12 +#import pyprind
13 13 import sys
  14 +from math import floor
  15 +import progressbar
14 16  
15 17 class envi_header:
16 18 def __init__(self, filename = ""):
... ... @@ -272,30 +274,36 @@ class envi:
272 274 flatmask = numpy.reshape(mask, (X * Y))
273 275 i = numpy.flatnonzero(flatmask)
274 276 bar = progressbar.ProgressBar(max_value = P)
  277 + #bar = pyprind.ProgBar(P)
275 278 for p in range(0, P):
276 279 self.file.seek(i[p] * B * type_bytes)
277 280 self.file.readinto(spectrum)
278 281 M[:, p] = spectrum
279 282 bar.update(p+1)
  283 + #bar.update()
280 284 elif self.header.interleave == "bsq":
281 285 band = numpy.zeros(mask.shape, dtype=self.header.data_type)
282 286 i = numpy.nonzero(mask)
283 287 bar = progressbar.ProgressBar(max_value=B)
  288 + #bar = pyprind.ProgBar(P)
284 289 for b in range(0, B):
285 290 self.file.seek(b * X * Y * type_bytes)
286 291 self.file.readinto(band)
287 292 M[b, :] = band[i]
288 293 bar.update(b+1)
  294 + #bar.update()
289 295 elif self.header.interleave == "bil":
290 296 plane = numpy.zeros((B, X), dtype=self.header.data_type)
291 297 p = 0
292 298 bar = progressbar.ProgressBar(max_value=Y)
  299 + #bar = pyprind.ProgBar(P)
293 300 for l in range(0, Y):
294 301 i = numpy.flatnonzero(mask[l, :])
295 302 self.file.readinto(plane)
296 303 M[:, p:p+i.shape[0]] = plane[:, i]
297 304 p = p + i.shape[0]
298 305 bar.update(l+1)
  306 + #bar.update()
299 307 self.file.seek(prev_pos)
300 308 return M
301 309  
... ... @@ -334,6 +342,73 @@ class envi:
334 342 T.append(t)
335 343  
336 344 return numpy.nan_to_num(numpy.concatenate(F, 1).transpose()), numpy.concatenate(T)
  345 +
  346 +
  347 + #create a set of feature/target pairs for classification with balanced data
  348 + #input: envi file object, stack of class masks C x Y x X, number of samples per class
  349 + #output: feature matrix (pixels x features), target array (one class label per pixel)
  350 + #example: F, T = E.loadtrain_balance(classimages, num_samples=10000)
  351 + # verify that there are no NaN or Inf values
  352 + def loadtrain_balance(self, classimages, num_samples=None):
  353 +
  354 + # get number of classes
  355 + C = classimages.shape[0]
  356 +
  357 + F = []
  358 + T = []
  359 +
  360 + # get number of samples per class
  361 + samples_per_class = numpy.zeros(C, dtype=numpy.int32)
  362 + for c in range(0, C):
  363 + if num_samples is None:
  364 + samples_per_class[c] = numpy.count_nonzero(classimages[c, :, :])
  365 + else:
  366 + # if user has specified a max number of samples per class
  367 + if num_samples > numpy.count_nonzero(classimages[c, :, :]):
  368 + samples_per_class[c] = numpy.count_nonzero(classimages[c, :, :])
  369 + else:
  370 + samples_per_class[c] = num_samples
  371 +
  372 + for c in range(0, C):
  373 + print("\nLoading class " + str(c+1) + "...")
  374 + # row, col index of valid pixels
  375 + temp = classimages[c,:]
  376 + flat_temp = numpy.reshape(temp, temp.shape[0]*temp.shape[1])
  377 +
  378 + idx = numpy.flatnonzero(temp) # indices of nonzero values
  379 + if num_samples:
  380 + # use specific number of samples for training
  381 + numpy.random.shuffle(idx)
  382 + idx = idx[0:samples_per_class[c]]
  383 +
  384 + # increase number of samples by copying them over multiple times
  385 + max_samples = numpy.amax(samples_per_class)
  386 + # num of times to copy for even division
  387 + copy_times = int(floor(max_samples / samples_per_class[c]))
  388 + rem = max_samples % samples_per_class[c] # remaining samples
  389 +
  390 + for i in range(0, copy_times):
  391 + numpy.random.shuffle(idx)
  392 + shuffle_temp = numpy.zeros(flat_temp.shape, dtype=bool)
  393 + shuffle_temp[idx] = flat_temp[idx]
  394 + f = self.loadmask(numpy.reshape(shuffle_temp, (temp.shape[0], temp.shape[1]))) # load the feature matrix for class c
  395 + t = numpy.ones((f.shape[1])) * (c+1) # generate a target array
  396 + F.append(f)
  397 + T.append(t)
  398 +
  399 + # copy the remaning samples so the total matches the max number of samples chosen by user
  400 + if rem > 0:
  401 + numpy.random.shuffle(idx)
  402 + idx = idx[0:rem]
  403 + shuffle_temp = numpy.zeros(flat_temp.shape, dtype=bool)
  404 + shuffle_temp[idx] = flat_temp[idx]
  405 + f = self.loadmask(numpy.reshape(shuffle_temp, (temp.shape[0], temp.shape[1]))) # load the feature matrix for class c
  406 + t = numpy.ones((f.shape[1])) * (c+1) # generate a target array
  407 + F.append(f)
  408 + T.append(t)
  409 +
  410 + return numpy.nan_to_num(numpy.concatenate(F, 1).transpose()), numpy.concatenate(T)
  411 +
337 412  
338 413 #read a batch of data based on the mask
339 414 def loadbatch(self, npixels):
... ... @@ -390,4 +465,4 @@ def save_envi(A, fname):
390 465 #save the raw data
391 466 file = open(fname, "wb")
392 467 file.write(bytearray(A))
393   - file.close()
394 468 \ No newline at end of file
  469 + file.close()
... ...