# -*- coding: utf-8 -*-
"""
Created on Fri Jul 21 20:18:01 2017

@author: david

Reader for ENVI images: ``envi_header`` parses the text header (``.hdr``) and
``envi`` reads pixel data from the accompanying raw binary file.
"""

import os

import numpy
import scipy

try:
    import matplotlib.pyplot as plt
except ImportError:     # plotting is not required for reading ENVI files
    plt = None

try:
    import progressbar
except ImportError:     # progress display is cosmetic; loaders work without it
    progressbar = None


class _NullProgress:
    """Silent stand-in used when the ``progressbar`` package is unavailable."""

    def update(self, value):
        pass


def _progress(max_value):
    """Return a progress reporter that counts up to ``max_value``."""
    if progressbar is None:
        return _NullProgress()
    return progressbar.ProgressBar(max_value=max_value)


class envi_header:
    """Contents of an ENVI header file.

    Constructing the object with a file name parses that file; constructing it
    without one yields a header filled with the defaults from ``initialize``.
    """

    # (ENVI data type code, numpy type) pairs.  Code 1 is an *unsigned* byte.
    _TYPE_TABLE = (
        (1, numpy.uint8),
        (2, numpy.int16),
        (3, numpy.int32),
        (4, numpy.float32),
        (5, numpy.float64),
        (6, numpy.complex64),
        (9, numpy.complex128),
        (12, numpy.uint16),
        (13, numpy.uint32),
        (14, numpy.int64),
        (15, numpy.uint64),
    )

    # header keys holding a single integer -> attribute that stores the value
    _INT_FIELDS = {
        "samples": "samples",
        "lines": "lines",
        "bands": "bands",
        "header offset": "header_offset",
        "byte order": "byte_order",
        "x start": "x_start",
        "y start": "y_start",
    }

    def __init__(self, filename=""):
        if filename != "":
            self.load(filename)
        else:
            self.initialize()

    #initialization function
    def initialize(self):
        """Reset every header field to its default value."""
        self.samples = 0
        self.lines = 0
        self.bands = 0
        self.header_offset = 0
        # NOTE: this default is the ENVI type *code* (4 = 32-bit float);
        # load() stores the matching numpy type instead.
        self.data_type = 4
        self.interleave = "bsq"
        self.sensor_type = ""
        self.byte_order = 0
        self.x_start = 0
        self.y_start = 0
        self.z_plot_titles = ""
        self.pixel_size = [0.0, 0.0]
        self.pixel_size_units = "Meters"
        self.wavelength_units = "Wavenumber"
        self.description = ""
        self.band_names = []
        self.wavelength = []

    #convert an ENVI data_type value to a numpy data type
    def get_numpy_type(self, val):
        """Return the numpy type for ENVI type code ``val`` (None if unknown)."""
        for code, np_type in self._TYPE_TABLE:
            if val == code:
                return np_type
        return None

    def get_envi_type(self, val):
        """Return the ENVI type code for numpy type ``val`` (None if unknown)."""
        for code, np_type in self._TYPE_TABLE:
            if val == np_type:
                return code
        # ENVI has no signed 8-bit type; int8 keeps its historical mapping to
        # code 1 so existing callers are unaffected.
        if val == numpy.byte:
            return 1
        return None

    @staticmethod
    def _brace_content(text):
        """Return the text between the last '{' and the last '}' of ``text``.

        A missing brace is tolerated: the content then extends to the
        corresponding end of the string.
        """
        start = text.rfind('{') + 1
        stop = text.rfind('}')
        if stop < 0:
            stop = len(text)
        return text[start:stop]

    def load(self, fname):
        """Parse the ENVI header file ``fname`` into this object.

        Fields absent from the file keep the defaults of ``initialize``,
        except ``data_type``, which is always stored as a numpy type
        (float32 when the header does not specify one).  Problems with the
        file are reported with an "ERROR: ..." message on stdout and parsing
        stops; fields read up to that point are kept.
        """
        self.initialize()
        self.data_type = self.get_numpy_type(self.data_type)

        with open(fname) as f:
            lines = f.readlines()

        if not lines or lines[0].strip() != "ENVI":
            print("ERROR: not an ENVI file")
            return

        li = 1
        while li < len(lines):
            raw = lines[li].strip()
            li += 1

            key, sep, value = raw.partition('=')
            if not sep:
                continue                    #blank line or stray text
            key = ' '.join(key.split()).lower()
            value = value.strip()

            #a braced value may continue over several lines; the pieces are
            #stripped and joined without a separator
            if '{' in value:
                while '}' not in value and li < len(lines):
                    value += lines[li].strip()
                    li += 1

            #handle the simple conditions
            if key == "file type":
                if not value.endswith("ENVI Standard"):
                    print("ERROR: unsupported ENVI file format: " + raw)
                    return
            elif key in self._INT_FIELDS:
                setattr(self, self._INT_FIELDS[key], int(value))
            elif key == "data type":
                np_type = self.get_numpy_type(int(value))
                if np_type is None:
                    print("ERROR: unsupported ENVI data type: " + value)
                    return
                self.data_type = np_type
            elif key == "interleave":
                self.interleave = value.lower()
            elif key == "sensor type":
                self.sensor_type = value
            elif key == "wavelength units":
                self.wavelength_units = value
            elif key == "z plot titles":
                self.z_plot_titles = self._brace_content(value)
            elif key == "pixel size":
                parts = self._brace_content(value).split(',')
                self.pixel_size = [float(parts[0]), float(parts[1])]
                if len(parts) > 2:          #units are optional
                    self.pixel_size_units = parts[2].rpartition('=')[2].strip()

            #handle the complicated conditions
            elif key == "description":
                self.description = self._brace_content(value)
            elif key == "band names":
                names = self._brace_content(value)
                if names.strip():
                    self.band_names = [s.strip() for s in names.split(',')]
            elif key == "wavelength":
                waves = self._brace_content(value)
                if waves.strip():
                    self.wavelength = [float(w) for w in waves.split(',')]


class envi:
    """Reader for the binary data of an ENVI file described by an envi_header.

    Array shapes use B = bands, Y = lines and X = samples.  Masks are indexed
    (Y, X).

    NOTE: the header's ``byte order`` field is not applied; values are read in
    the byte order of the numpy type stored in ``header.data_type``.
    """

    def __init__(self, filename, headername="", mask=None):
        """Open ``filename`` (header defaults to ``filename + ".hdr"``).

        ``mask`` selects the pixels visited by ``loadbatch``; None or an empty
        sequence selects every pixel.  Raises FileNotFoundError when the data
        file or the header file does not exist.
        """
        self.header = None
        self.file = None
        self.open(filename, headername)
        if self.file is None:
            raise FileNotFoundError("could not open ENVI file: " + filename)

        if mask is None or numpy.size(mask) == 0:
            self.mask = numpy.ones((self.header.lines, self.header.samples), dtype=bool)
        else:
            self.mask = mask
        self.idx = 0                        #initialize the batch IDX to 0 for batch reading

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def open(self, filename, headername=""):
        """Load the header and open the data file for binary reading.

        When either file is missing an "ERROR: ..." message is printed and the
        object is left unchanged.
        """
        if headername == "":
            headername = filename + ".hdr"

        if not os.path.isfile(filename):
            print("ERROR: " + filename + " not found")
            return
        if not os.path.isfile(headername):
            print("ERROR: " + headername + " not found")
            return

        self.close()                        #release a previously opened file

        #open the file
        self.header = envi_header(headername)
        self.file = open(filename, "rb")

    def close(self):
        """Close the data file if it is open."""
        f = getattr(self, "file", None)
        if f is not None:
            f.close()
            self.file = None

    def loadall(self):
        """Return the whole image as a (B, Y, X) array.

        Returns None when the interleave is not one of bsq, bip or bil.  The
        file position is restored afterwards.
        """
        X = self.header.samples
        Y = self.header.lines
        B = self.header.bands

        #load the data, starting after the embedded header (if any)
        prev_pos = self.file.tell()
        self.file.seek(self.header.header_offset)
        D = numpy.fromfile(self.file, dtype=self.header.data_type, count=X * Y * B)
        self.file.seek(prev_pos)

        if self.header.interleave == "bsq":
            return numpy.reshape(D, (B, Y, X))
        elif self.header.interleave == "bip":
            return numpy.moveaxis(numpy.reshape(D, (Y, X, B)), 2, 0)
        elif self.header.interleave == "bil":
            return numpy.moveaxis(numpy.reshape(D, (Y, B, X)), 1, 0)
        return None

    #loads all of the pixels where mask != 0 and returns them as a matrix
    def loadmask(self, mask):
        """Return a (B, P) matrix holding the spectra of the P masked pixels.

        ``mask`` must contain X * Y elements, laid out as (Y, X); pixels are
        returned in row-major order.  Raises ValueError for a mask of the
        wrong size.  The file position is restored afterwards.
        """
        X = self.header.samples
        Y = self.header.lines
        B = self.header.bands

        mask = numpy.asarray(mask)
        if mask.size != X * Y:
            raise ValueError("mask must contain samples * lines elements")

        P = numpy.count_nonzero(mask)       #count the number of nonzero (selected) pixels
        M = numpy.zeros((B, P), dtype=self.header.data_type)
        type_bytes = numpy.dtype(self.header.data_type).itemsize
        offset = self.header.header_offset
        flat_idx = numpy.flatnonzero(mask)  #row-major indices of the selected pixels

        prev_pos = self.file.tell()
        try:
            if self.header.interleave == "bip":
                spectrum = numpy.zeros(B, dtype=self.header.data_type)
                bar = _progress(P)
                for p in range(P):
                    self.file.seek(offset + int(flat_idx[p]) * B * type_bytes)
                    self.file.readinto(spectrum)
                    M[:, p] = spectrum
                    bar.update(p + 1)

            elif self.header.interleave == "bsq":
                band = numpy.zeros(X * Y, dtype=self.header.data_type)
                bar = _progress(B)
                for b in range(B):
                    self.file.seek(offset + b * X * Y * type_bytes)
                    self.file.readinto(band)
                    M[b, :] = band[flat_idx]
                    bar.update(b + 1)

            elif self.header.interleave == "bil":
                mask2d = mask.reshape((Y, X))
                plane = numpy.zeros((B, X), dtype=self.header.data_type)
                p = 0
                bar = _progress(Y)
                for l in range(Y):
                    cols = numpy.flatnonzero(mask2d[l, :])
                    if cols.size:           #lines without selected pixels are not read
                        self.file.seek(offset + l * B * X * type_bytes)
                        self.file.readinto(plane)
                        M[:, p:p + cols.size] = plane[:, cols]
                        p += cols.size
                    bar.update(l + 1)
        finally:
            self.file.seek(prev_pos)
        return M

    def loadband(self, n):
        """Return band ``n`` (zero-based) as a (Y, X) array.

        Raises IndexError when ``n`` is not a valid band index.  The file
        position is restored afterwards.
        """
        X = self.header.samples
        Y = self.header.lines
        B = self.header.bands

        if not 0 <= n < B:
            raise IndexError("band index out of range: " + str(n))

        band = numpy.zeros((Y, X), dtype=self.header.data_type)
        type_bytes = numpy.dtype(self.header.data_type).itemsize
        offset = self.header.header_offset

        prev_pos = self.file.tell()
        try:
            if self.header.interleave == "bsq":
                #the band is one contiguous Y * X run
                self.file.seek(offset + n * X * Y * type_bytes)
                self.file.readinto(band)
            elif self.header.interleave == "bil":
                #each line stores B rows of X values; pick row n of every line
                row = numpy.zeros(X, dtype=self.header.data_type)
                for y in range(Y):
                    self.file.seek(offset + (y * B + n) * X * type_bytes)
                    self.file.readinto(row)
                    band[y, :] = row
            elif self.header.interleave == "bip":
                #each line stores X spectra of B values; pick value n of each
                line = numpy.zeros((X, B), dtype=self.header.data_type)
                for y in range(Y):
                    self.file.seek(offset + y * X * B * type_bytes)
                    self.file.readinto(line)
                    band[y, :] = line[:, n]
        finally:
            self.file.seek(prev_pos)
        return band

    #create a set of feature/target pairs for classification
    #input: stack of class masks C x Y x X
    #output: feature matrix (pixels x features), target array (pixels)
    def loadtrain(self, classimages):
        """Build training data from a stack of class masks.

        Class ``c`` (zero-based position in the stack) is given the target
        value ``c + 1``.
        """
        C = classimages.shape[0]            #number of classes

        F = []
        T = []
        for c in range(C):
            f = self.loadmask(classimages[c, :, :])     #feature matrix for class c
            F.append(f)
            T.append(numpy.full(f.shape[1], float(c + 1)))  #matching target array

        return numpy.concatenate(F, 1).transpose(), numpy.concatenate(T)

    #read a batch of data based on the mask
    def loadbatch(self, npixels):
        """Return the next ``npixels`` masked pixels as a (B, npixels) array.

        Fewer columns are returned when fewer pixels remain (zero columns once
        the mask is exhausted).  Only BIP files are supported; for BSQ and BIL
        an error is printed and None is returned.
        """
        i = numpy.flatnonzero(self.mask)    #get the indices of valid pixels
        #if there aren't enough pixels, change the batch size
        npixels = max(0, min(npixels, len(i) - self.idx))

        B = self.header.bands
        batch = numpy.zeros((B, npixels), dtype=self.header.data_type)  #space for the batch
        pixel = numpy.zeros((B), dtype=self.header.data_type)           #space for a single pixel
        type_bytes = numpy.dtype(self.header.data_type).itemsize        #size of a single value
        offset = self.header.header_offset

        if self.header.interleave == "bip":
            for n in range(npixels):
                self.file.seek(offset + int(i[self.idx]) * B * type_bytes)
                self.file.readinto(pixel)
                batch[:, n] = pixel
                self.idx = self.idx + 1
            return batch
        elif self.header.interleave == "bsq":
            print("ERROR: BSQ batch loading isn't implemented yet!")
        elif self.header.interleave == "bil":
            print("ERROR: BIL batch loading isn't implemented yet!")

    #returns the current batch index
    def getidx(self):
        return self.idx

    def __del__(self):
        #the file attribute may be missing if construction failed early
        self.close()