Commit 71ee1c23951161871058ee9f50e377eb2481ef70
1 parent
f275c01c
added additional functions to allow batch reading of ENVI files and conversion o…
…f label images to class images
Showing
2 changed files
with
48 additions
and
7 deletions
Show diff stats
python/classify.py
... | ... | @@ -16,7 +16,7 @@ import envi |
16 | 16 | #generate a 2D color class map using a stack of binary class images |
17 | 17 | #input: C is a C x Y x X binary image |
18 | 18 | #output: an RGB color image with a unique color for each class |
19 | -def classcolor2(C): | |
19 | +def class2color(C): | |
20 | 20 | |
21 | 21 | #determine the number of classes |
22 | 22 | nc = C.shape[0] |
... | ... | @@ -119,6 +119,17 @@ def prob2roc(P, t_vals, mask=[]): |
119 | 119 | |
120 | 120 | return sklearn.metrics.roc_curve(labels, scores) |
121 | 121 | |
122 | +#convert a label image to a C x Y x X class image | |
123 | +def label2class(L): | |
124 | + unique = numpy.unique(L) | |
125 | + s = L.shape | |
126 | + s = numpy.append(numpy.array((len(unique)-1)), s) | |
127 | + print(s) | |
128 | + C = numpy.zeros(s, dtype=numpy.bool) | |
129 | + for i in range(1, len(unique)): | |
130 | + C[i-1, :, :] = L == unique[i] | |
131 | + return C | |
132 | + | |
122 | 133 | #Function to convert a set of class labels to a matrix of neuron responses for an ANN |
123 | 134 | |
124 | 135 | #Function CNN extraction function |
125 | 136 | \ No newline at end of file | ... | ... |
python/envi.py
... | ... | @@ -182,12 +182,13 @@ class envi_header: |
182 | 182 | f.close() |
183 | 183 | |
184 | 184 | class envi: |
185 | - def __init__(self, filename, headername = "", maskname = ""): | |
185 | + def __init__(self, filename, headername = "", mask = []): | |
186 | 186 | self.open(filename, headername) |
187 | - if maskname == "": | |
187 | + if mask == []: | |
188 | 188 | self.mask = numpy.ones((self.header.samples, self.header.lines), dtype=numpy.bool) |
189 | 189 | else: |
190 | - self.mask = scipy.misc.imread(maskname, flatten=True).astype(numpy.bool) | |
190 | + self.mask = mask | |
191 | + self.idx = 0 #initialize the batch IDX to 0 for batch reading | |
191 | 192 | |
192 | 193 | def open(self, filename, headername = ""): |
193 | 194 | if headername == "": |
... | ... | @@ -232,6 +233,7 @@ class envi: |
232 | 233 | M = numpy.zeros((B, P), dtype=self.header.data_type) |
233 | 234 | type_bytes = numpy.dtype(self.header.data_type).itemsize |
234 | 235 | |
236 | + prev_pos = self.file.tell() | |
235 | 237 | self.file.seek(0) |
236 | 238 | if self.header.interleave == "bip": |
237 | 239 | spectrum = numpy.zeros(B, dtype=self.header.data_type) |
... | ... | @@ -262,6 +264,7 @@ class envi: |
262 | 264 | M[:, p:p+i.shape[0]] = plane[:, i] |
263 | 265 | p = p + i.shape[0] |
264 | 266 | bar.update(l+1) |
267 | + self.file.seek(prev_pos) | |
265 | 268 | return M |
266 | 269 | |
267 | 270 | def loadband(self, n): |
... | ... | @@ -271,11 +274,12 @@ class envi: |
271 | 274 | |
272 | 275 | band = numpy.zeros((Y, X), dtype=self.header.data_type) |
273 | 276 | type_bytes = numpy.dtype(self.header.data_type).itemsize |
274 | - | |
277 | + | |
278 | + prev_pos = self.file.tell() | |
275 | 279 | if self.header.interleave == "bsq": |
276 | 280 | self.file.seek(n * X * Y * type_bytes) |
277 | 281 | self.file.readinto(band) |
278 | - | |
282 | + self.file.seek(prev_pos) | |
279 | 283 | return band |
280 | 284 | |
281 | 285 | #create a set of feature/target pairs for classification |
... | ... | @@ -296,7 +300,33 @@ class envi: |
296 | 300 | T.append(t) |
297 | 301 | |
298 | 302 | return numpy.concatenate(F, 1).transpose(), numpy.concatenate(T) |
299 | - | |
303 | + | |
304 | + #read a batch of data based on the mask | |
305 | + def loadbatch(self, npixels): | |
306 | + i = numpy.flatnonzero(self.mask) #get the indices of valid pixels | |
307 | + npixels = min(npixels, len(i) - self.idx - 1) #if there aren't enough pixels, change the batch size | |
308 | + B = self.header.bands | |
309 | + | |
310 | + batch = numpy.zeros((B, npixels), dtype=self.header.data_type) #allocate space for the batch | |
311 | + pixel = numpy.zeros((B), dtype=self.header.data_type) #allocate space for a single pixel | |
312 | + type_bytes = numpy.dtype(self.header.data_type).itemsize #calculate the size of a single value | |
313 | + if self.header.interleave == "bip": | |
314 | + for n in range(0, npixels): #for each pixel in the batch | |
315 | + self.file.seek(i[self.idx] * B * type_bytes) #seek to the current pixel in the file | |
316 | + self.file.readinto(pixel) #read a single pixel | |
317 | + batch[:, n] = pixel #save the pixel into the batch matrix | |
318 | + self.idx = self.idx + 1 | |
319 | + #print("idx: " + str(self.idx)) | |
320 | + #print("tell(): " + str(self.file.tell())) | |
321 | + return batch | |
322 | + elif self.header.interleave == "bsq": | |
323 | + print("ERROR: BSQ batch loading isn't implemented yet!") | |
324 | + elif self.header.interleave == "bil": | |
325 | + print("ERROR: BIL batch loading isn't implemented yet!") | |
326 | + | |
327 | + #returns the current batch index | |
328 | + def getidx(self): | |
329 | + return self.idx | |
300 | 330 | |
301 | 331 | def __del__(self): |
302 | 332 | self.file.close() |
303 | 333 | \ No newline at end of file | ... | ... |