Commit 71ee1c23951161871058ee9f50e377eb2481ef70

Authored by David Mayerich
1 parent f275c01c

added additional functions to allow batch reading of ENVI files and conversion o…

…f label images to class images
Showing 2 changed files with 48 additions and 7 deletions   Show diff stats
python/classify.py
@@ -16,7 +16,7 @@ import envi @@ -16,7 +16,7 @@ import envi
16 #generate a 2D color class map using a stack of binary class images 16 #generate a 2D color class map using a stack of binary class images
17 #input: C is a C x Y x X binary image 17 #input: C is a C x Y x X binary image
18 #output: an RGB color image with a unique color for each class 18 #output: an RGB color image with a unique color for each class
19 -def classcolor2(C): 19 +def class2color(C):
20 20
21 #determine the number of classes 21 #determine the number of classes
22 nc = C.shape[0] 22 nc = C.shape[0]
@@ -119,6 +119,17 @@ def prob2roc(P, t_vals, mask=[]): @@ -119,6 +119,17 @@ def prob2roc(P, t_vals, mask=[]):
119 119
120 return sklearn.metrics.roc_curve(labels, scores) 120 return sklearn.metrics.roc_curve(labels, scores)
121 121
  122 +#convert a label image to a C x Y x X class image
  123 +def label2class(L):
  124 + unique = numpy.unique(L)
  125 + s = L.shape
  126 + s = numpy.append(numpy.array((len(unique)-1)), s)
  127 + print(s)
  128 + C = numpy.zeros(s, dtype=numpy.bool)
  129 + for i in range(1, len(unique)):
  130 + C[i-1, :, :] = L == unique[i]
  131 + return C
  132 +
122 #Function to convert a set of class labels to a matrix of neuron responses for an ANN 133 #Function to convert a set of class labels to a matrix of neuron responses for an ANN
123 134
124 #Function CNN extraction function 135 #Function CNN extraction function
125 \ No newline at end of file 136 \ No newline at end of file
@@ -182,12 +182,13 @@ class envi_header: @@ -182,12 +182,13 @@ class envi_header:
182 f.close() 182 f.close()
183 183
184 class envi: 184 class envi:
185 - def __init__(self, filename, headername = "", maskname = ""): 185 + def __init__(self, filename, headername = "", mask = []):
186 self.open(filename, headername) 186 self.open(filename, headername)
187 - if maskname == "": 187 + if mask == []:
188 self.mask = numpy.ones((self.header.samples, self.header.lines), dtype=numpy.bool) 188 self.mask = numpy.ones((self.header.samples, self.header.lines), dtype=numpy.bool)
189 else: 189 else:
190 - self.mask = scipy.misc.imread(maskname, flatten=True).astype(numpy.bool) 190 + self.mask = mask
  191 + self.idx = 0 #initialize the batch IDX to 0 for batch reading
191 192
192 def open(self, filename, headername = ""): 193 def open(self, filename, headername = ""):
193 if headername == "": 194 if headername == "":
@@ -232,6 +233,7 @@ class envi: @@ -232,6 +233,7 @@ class envi:
232 M = numpy.zeros((B, P), dtype=self.header.data_type) 233 M = numpy.zeros((B, P), dtype=self.header.data_type)
233 type_bytes = numpy.dtype(self.header.data_type).itemsize 234 type_bytes = numpy.dtype(self.header.data_type).itemsize
234 235
  236 + prev_pos = self.file.tell()
235 self.file.seek(0) 237 self.file.seek(0)
236 if self.header.interleave == "bip": 238 if self.header.interleave == "bip":
237 spectrum = numpy.zeros(B, dtype=self.header.data_type) 239 spectrum = numpy.zeros(B, dtype=self.header.data_type)
@@ -262,6 +264,7 @@ class envi: @@ -262,6 +264,7 @@ class envi:
262 M[:, p:p+i.shape[0]] = plane[:, i] 264 M[:, p:p+i.shape[0]] = plane[:, i]
263 p = p + i.shape[0] 265 p = p + i.shape[0]
264 bar.update(l+1) 266 bar.update(l+1)
  267 + self.file.seek(prev_pos)
265 return M 268 return M
266 269
267 def loadband(self, n): 270 def loadband(self, n):
@@ -271,11 +274,12 @@ class envi: @@ -271,11 +274,12 @@ class envi:
271 274
272 band = numpy.zeros((Y, X), dtype=self.header.data_type) 275 band = numpy.zeros((Y, X), dtype=self.header.data_type)
273 type_bytes = numpy.dtype(self.header.data_type).itemsize 276 type_bytes = numpy.dtype(self.header.data_type).itemsize
274 - 277 +
  278 + prev_pos = self.file.tell()
275 if self.header.interleave == "bsq": 279 if self.header.interleave == "bsq":
276 self.file.seek(n * X * Y * type_bytes) 280 self.file.seek(n * X * Y * type_bytes)
277 self.file.readinto(band) 281 self.file.readinto(band)
278 - 282 + self.file.seek(prev_pos)
279 return band 283 return band
280 284
281 #create a set of feature/target pairs for classification 285 #create a set of feature/target pairs for classification
@@ -296,7 +300,33 @@ class envi: @@ -296,7 +300,33 @@ class envi:
296 T.append(t) 300 T.append(t)
297 301
298 return numpy.concatenate(F, 1).transpose(), numpy.concatenate(T) 302 return numpy.concatenate(F, 1).transpose(), numpy.concatenate(T)
299 - 303 +
  304 + #read a batch of data based on the mask
  305 + def loadbatch(self, npixels):
  306 + i = numpy.flatnonzero(self.mask) #get the indices of valid pixels
  307 + npixels = min(npixels, len(i) - self.idx - 1) #if there aren't enough pixels, change the batch size
  308 + B = self.header.bands
  309 +
  310 + batch = numpy.zeros((B, npixels), dtype=self.header.data_type) #allocate space for the batch
  311 + pixel = numpy.zeros((B), dtype=self.header.data_type) #allocate space for a single pixel
  312 + type_bytes = numpy.dtype(self.header.data_type).itemsize #calculate the size of a single value
  313 + if self.header.interleave == "bip":
  314 + for n in range(0, npixels): #for each pixel in the batch
  315 + self.file.seek(i[self.idx] * B * type_bytes) #seek to the current pixel in the file
  316 + self.file.readinto(pixel) #read a single pixel
  317 + batch[:, n] = pixel #save the pixel into the batch matrix
  318 + self.idx = self.idx + 1
  319 + #print("idx: " + str(self.idx))
  320 + #print("tell(): " + str(self.file.tell()))
  321 + return batch
  322 + elif self.header.interleave == "bsq":
  323 + print("ERROR: BSQ batch loading isn't implemented yet!")
  324 + elif self.header.interleave == "bil":
  325 + print("ERROR: BIL batch loading isn't implemented yet!")
  326 +
  327 + #returns the current batch index
  328 + def getidx(self):
  329 + return self.idx
300 330
301 def __del__(self): 331 def __del__(self):
302 self.file.close() 332 self.file.close()
303 \ No newline at end of file 333 \ No newline at end of file