Commit 71ee1c23951161871058ee9f50e377eb2481ef70

Authored by David Mayerich
1 parent f275c01c

added additional functions to allow batch reading of ENVI files and conversion o…

…f label images to class images
Showing 2 changed files with 48 additions and 7 deletions   Show diff stats
python/classify.py
... ... @@ -16,7 +16,7 @@ import envi
16 16 #generate a 2D color class map using a stack of binary class images
17 17 #input: C is a C x Y x X binary image
18 18 #output: an RGB color image with a unique color for each class
19   -def classcolor2(C):
  19 +def class2color(C):
20 20  
21 21 #determine the number of classes
22 22 nc = C.shape[0]
... ... @@ -119,6 +119,17 @@ def prob2roc(P, t_vals, mask=[]):
119 119  
120 120 return sklearn.metrics.roc_curve(labels, scores)
121 121  
  122 +#convert a label image to a C x Y x X class image
  123 +def label2class(L):
  124 + unique = numpy.unique(L)
  125 + s = L.shape
  126 + s = numpy.append(numpy.array((len(unique)-1)), s)
  127 + print(s)
  128 + C = numpy.zeros(s, dtype=numpy.bool)
  129 + for i in range(1, len(unique)):
  130 + C[i-1, :, :] = L == unique[i]
  131 + return C
  132 +
122 133 #Function to convert a set of class labels to a matrix of neuron responses for an ANN
123 134  
124 135 #Function CNN extraction function
125 136 \ No newline at end of file
... ...
python/envi.py
... ... @@ -182,12 +182,13 @@ class envi_header:
182 182 f.close()
183 183  
184 184 class envi:
185   - def __init__(self, filename, headername = "", maskname = ""):
  185 + def __init__(self, filename, headername = "", mask = []):
186 186 self.open(filename, headername)
187   - if maskname == "":
  187 + if mask == []:
188 188 self.mask = numpy.ones((self.header.samples, self.header.lines), dtype=numpy.bool)
189 189 else:
190   - self.mask = scipy.misc.imread(maskname, flatten=True).astype(numpy.bool)
  190 + self.mask = mask
  191 + self.idx = 0 #initialize the batch IDX to 0 for batch reading
191 192  
192 193 def open(self, filename, headername = ""):
193 194 if headername == "":
... ... @@ -232,6 +233,7 @@ class envi:
232 233 M = numpy.zeros((B, P), dtype=self.header.data_type)
233 234 type_bytes = numpy.dtype(self.header.data_type).itemsize
234 235  
  236 + prev_pos = self.file.tell()
235 237 self.file.seek(0)
236 238 if self.header.interleave == "bip":
237 239 spectrum = numpy.zeros(B, dtype=self.header.data_type)
... ... @@ -262,6 +264,7 @@ class envi:
262 264 M[:, p:p+i.shape[0]] = plane[:, i]
263 265 p = p + i.shape[0]
264 266 bar.update(l+1)
  267 + self.file.seek(prev_pos)
265 268 return M
266 269  
267 270 def loadband(self, n):
... ... @@ -271,11 +274,12 @@ class envi:
271 274  
272 275 band = numpy.zeros((Y, X), dtype=self.header.data_type)
273 276 type_bytes = numpy.dtype(self.header.data_type).itemsize
274   -
  277 +
  278 + prev_pos = self.file.tell()
275 279 if self.header.interleave == "bsq":
276 280 self.file.seek(n * X * Y * type_bytes)
277 281 self.file.readinto(band)
278   -
  282 + self.file.seek(prev_pos)
279 283 return band
280 284  
281 285 #create a set of feature/target pairs for classification
... ... @@ -296,7 +300,33 @@ class envi:
296 300 T.append(t)
297 301  
298 302 return numpy.concatenate(F, 1).transpose(), numpy.concatenate(T)
299   -
  303 +
  304 + #read a batch of data based on the mask
  305 + def loadbatch(self, npixels):
  306 + i = numpy.flatnonzero(self.mask) #get the indices of valid pixels
  307 + npixels = min(npixels, len(i) - self.idx - 1) #if there aren't enough pixels, change the batch size
  308 + B = self.header.bands
  309 +
  310 + batch = numpy.zeros((B, npixels), dtype=self.header.data_type) #allocate space for the batch
  311 + pixel = numpy.zeros((B), dtype=self.header.data_type) #allocate space for a single pixel
  312 + type_bytes = numpy.dtype(self.header.data_type).itemsize #calculate the size of a single value
  313 + if self.header.interleave == "bip":
  314 + for n in range(0, npixels): #for each pixel in the batch
  315 + self.file.seek(i[self.idx] * B * type_bytes) #seek to the current pixel in the file
  316 + self.file.readinto(pixel) #read a single pixel
  317 + batch[:, n] = pixel #save the pixel into the batch matrix
  318 + self.idx = self.idx + 1
  319 + #print("idx: " + str(self.idx))
  320 + #print("tell(): " + str(self.file.tell()))
  321 + return batch
  322 + elif self.header.interleave == "bsq":
  323 + print("ERROR: BSQ batch loading isn't implemented yet!")
  324 + elif self.header.interleave == "bil":
  325 + print("ERROR: BIL batch loading isn't implemented yet!")
  326 +
  327 + #returns the current batch index
  328 + def getidx(self):
  329 + return self.idx
300 330  
301 331 def __del__(self):
302 332 self.file.close()
303 333 \ No newline at end of file
... ...