Commit 18368aa95c6e8cd961dd3316cf9eecb14ea45815

Authored by David Mayerich
1 parent e0e01dfb

added a new set of python tools for working with ENVI files, as well as basic cl…

…assification and spectral manipulation
Showing 3 changed files with 338 additions and 0 deletions   Show diff stats
python/classify.py 0 → 100644
  1 +# -*- coding: utf-8 -*-
  2 +"""
  3 +Created on Sun Jul 23 16:04:33 2017
  4 +
  5 +@author: david
  6 +"""
  7 +
  8 +import numpy
  9 +import colorsys
  10 +
  11 +#generate a 2D color class map using a stack of binary class images
  12 +def classcolor2(C):
  13 +
  14 + #determine the number of classes
  15 + nc = C.shape[-1]
  16 +
  17 + #generate an RGB image
  18 + RGB = numpy.zeros((C.shape[0], C.shape[1], 3), dtype=numpy.ubyte)
  19 +
  20 + #for each class
  21 + for c in range(0, nc):
  22 + hsv = (c * 1.0 / nc, 1, 1)
  23 + color = numpy.asarray(colorsys.hsv_to_rgb(hsv[0], hsv[1], hsv[2])) * 255
  24 + RGB[C[:, :, c], :] = color
  25 +
  26 + return RGB
0 27 \ No newline at end of file
... ...
python/envi.py 0 → 100644
  1 +# -*- coding: utf-8 -*-
  2 +"""
  3 +Created on Fri Jul 21 20:18:01 2017
  4 +
  5 +@author: david
  6 +"""
  7 +
  8 +import os
  9 +import numpy
  10 +import scipy
  11 +import matplotlib.pyplot as plt
  12 +import progressbar
  13 +
  14 +class envi_header:
  15 + def __init__(self, filename = ""):
  16 + if filename != "":
  17 + self.load(filename)
  18 + else:
  19 + self.initialize()
  20 +
  21 + #initialization function
  22 + def initialize(self):
  23 + self.samples = int(0)
  24 + self.lines = int(0)
  25 + self.bands = int(0)
  26 + self.header_offset = int(0)
  27 + self.data_type = int(4)
  28 + self.interleave = "bsq"
  29 + self.sensor_type = ""
  30 + self.byte_order = int(0)
  31 + self.x_start = int(0)
  32 + self.y_start = int(0)
  33 + self.z_plot_titles = ""
  34 + self.pixel_size = [float(0), float(0)]
  35 + self.pixel_size_units = "Meters"
  36 + self.wavelength_units = "Wavenumber"
  37 + self.description = ""
  38 + self.band_names = []
  39 + self.wavelength = []
  40 +
  41 + #convert an ENVI data_type value to a numpy data type
  42 + def get_numpy_type(self, val):
  43 + if val == 1:
  44 + return numpy.byte
  45 + elif val == 2:
  46 + return numpy.int16
  47 + elif val == 3:
  48 + return numpy.int32
  49 + elif val == 4:
  50 + return numpy.float32
  51 + elif val == 5:
  52 + return numpy.float64
  53 + elif val == 6:
  54 + return numpy.complex64
  55 + elif val == 9:
  56 + return numpy.complex128
  57 + elif val == 12:
  58 + return numpy.uint16
  59 + elif val == 13:
  60 + return numpy.uint32
  61 + elif val == 14:
  62 + return numpy.int64
  63 + elif val == 15:
  64 + return numpy.uint64
  65 +
  66 + def get_envi_type(self, val):
  67 + if val == numpy.byte:
  68 + return 1
  69 + elif val == numpy.int16:
  70 + return 2
  71 + elif val == numpy.int32:
  72 + return 3
  73 + elif val == numpy.float32:
  74 + return 4
  75 + elif val == numpy.float64:
  76 + return 5
  77 + elif val == numpy.complex64:
  78 + return 6
  79 + elif val == numpy.complex128:
  80 + return 9
  81 + elif val == numpy.uint16:
  82 + return 12
  83 + elif val == numpy.uint32:
  84 + return 13
  85 + elif val == numpy.int64:
  86 + return 14
  87 + elif val == numpy.uint64:
  88 + return 15
  89 +
  90 + def load(self, fname):
  91 + f = open(fname)
  92 + l = f.readlines()
  93 + if l[0].strip() != "ENVI":
  94 + print("ERROR: not an ENVI file")
  95 + return
  96 + li = 1
  97 + while li < len(l):
  98 + #t = l[li].split() #split the line into tokens
  99 + #t = map(str.strip, t) #strip all of the tokens in the token list
  100 +
  101 + #handle the simple conditions
  102 + if l[li].startswith("file type"):
  103 + if not l[li].strip().endswith("ENVI Standard"):
  104 + print("ERROR: unsupported ENVI file format: " + l[li].strip())
  105 + return
  106 + elif l[li].startswith("samples"):
  107 + self.samples = int(l[li].split()[-1])
  108 + elif l[li].startswith("lines"):
  109 + self.lines = int(l[li].split()[-1])
  110 + elif l[li].startswith("bands"):
  111 + self.bands = int(l[li].split()[-1])
  112 + elif l[li].startswith("header offset"):
  113 + self.header_offset = int(l[li].split()[-1])
  114 + elif l[li].startswith("data type"):
  115 + self.data_type = self.get_numpy_type(int(l[li].split()[-1]))
  116 + elif l[li].startswith("interleave"):
  117 + self.interleave = l[li].split()[-1].strip()
  118 + elif l[li].startswith("sensor type"):
  119 + self.sensor_type = l[li].split()[-1].strip()
  120 + elif l[li].startswith("byte order"):
  121 + self.byte_order = int(l[li].split()[-1])
  122 + elif l[li].startswith("x start"):
  123 + self.x_start = int(l[li].split()[-1])
  124 + elif l[li].startswith("y start"):
  125 + self.y_start = int(l[li].split()[-1])
  126 + elif l[li].startswith("z plot titles"):
  127 + i0 = l[li].rindex('{')
  128 + i1 = l[li].rindex('}')
  129 + self.z_plot_titles = l[li][i0 + 1 : i1]
  130 + elif l[li].startswith("pixel size"):
  131 + i0 = l[li].rindex('{')
  132 + i1 = l[li].rindex('}')
  133 + s = l[li][i0 + 1 : i1].split(',')
  134 + self.pixel_size = [float(s[0]), float(s[1])]
  135 + self.pixel_size_units = s[2][s[2].rindex('=') + 1:].strip()
  136 + elif l[li].startswith("wavelength units"):
  137 + self.wavelength_units = l[li].split()[-1].strip()
  138 +
  139 + #handle the complicated conditions
  140 + elif l[li].startswith("description"):
  141 + desc = [l[li]]
  142 + while l[li].strip()[-1] != '}':
  143 + li += 1
  144 + desc.append(l[li])
  145 + desc = ''.join(list(map(str.strip, desc))) #strip all white space from the string list
  146 + i0 = desc.rindex('{')
  147 + i1 = desc.rindex('}')
  148 + self.description = desc[i0 + 1 : i1]
  149 +
  150 + elif l[li].startswith("band names"):
  151 + names = [l[li]]
  152 + while l[li].strip()[-1] != '}':
  153 + li += 1
  154 + names.append(l[li])
  155 + names = ''.join(list(map(str.strip, names))) #strip all white space from the string list
  156 + i0 = names.rindex('{')
  157 + i1 = names.rindex('}')
  158 + names = names[i0 + 1 : i1]
  159 + self.band_names = list(map(str.strip, names.split(',')))
  160 + elif l[li].startswith("wavelength"):
  161 + waves = [l[li]]
  162 + while l[li].strip()[-1] != '}':
  163 + li += 1
  164 + waves.append(l[li])
  165 + waves = ''.join(list(map(str.strip, waves))) #strip all white space from the string list
  166 + i0 = waves.rindex('{')
  167 + i1 = waves.rindex('}')
  168 + waves = waves[i0 + 1 : i1]
  169 + self.wavelength = list(map(float, waves.split(',')))
  170 +
  171 + li += 1
  172 +
  173 + f.close()
  174 +
  175 +class envi:
  176 + def __init__(self, filename, headername = "", maskname = ""):
  177 + self.open(filename, headername)
  178 + if maskname == "":
  179 + self.mask = numpy.ones((self.header.samples, self.header.lines), dtype=numpy.bool)
  180 + else:
  181 + self.mask = scipy.misc.imread(maskname, flatten=True).astype(numpy.bool)
  182 +
  183 +
  184 + def open(self, filename, headername = ""):
  185 + if headername == "":
  186 + headername = filename + ".hdr"
  187 +
  188 + if not os.path.isfile(filename):
  189 + print("ERROR: " + filename + " not found")
  190 + return
  191 + if not os.path.isfile(headername):
  192 + print("ERROR: " + headername + " not found")
  193 + return
  194 +
  195 + #open the file
  196 + self.header = envi_header(headername)
  197 + self.file = open(filename, "rb")
  198 +
  199 + def loadall(self):
  200 + X = self.header.samples
  201 + Y = self.header.lines
  202 + B = self.header.bands
  203 +
  204 + #load the data
  205 + D = numpy.fromfile(self.file, dtype=self.header.data_type)
  206 +
  207 + if self.header.interleave == "bsq":
  208 + return numpy.reshape(D, (B, Y, X))
  209 + #return numpy.swapaxes(D, 0, 2)
  210 + elif self.header.interleave == "bip":
  211 + D = numpy.reshape(D, (Y, X, B))
  212 + return numpy.rollaxis(D, 2)
  213 + elif self.header.interleave == "bil":
  214 + D = numpy.reshape(D, (Y, B, X))
  215 + return numpy.rollaxis(D, 1)
  216 +
  217 + #loads all of the pixels where mask != 0 and returns them as a matrix
  218 + def loadmask(self, mask):
  219 + X = self.header.samples
  220 + Y = self.header.lines
  221 + B = self.header.bands
  222 +
  223 + P = numpy.count_nonzero(mask) #count the number of zeros in the mask file
  224 + M = numpy.zeros((B, P), dtype=self.header.data_type)
  225 + type_bytes = numpy.dtype(self.header.data_type).itemsize
  226 +
  227 + self.file.seek(0)
  228 + if self.header.interleave == "bip":
  229 + spectrum = numpy.zeros(B, dtype=self.header.data_type)
  230 + flatmask = numpy.reshape(mask, (X * Y))
  231 + i = numpy.flatnonzero(flatmask)
  232 + bar = progressbar.ProgressBar(max_value = P)
  233 + for p in range(0, P):
  234 + self.file.seek(i[p] * B * type_bytes)
  235 + self.file.readinto(spectrum)
  236 + M[:, p] = spectrum
  237 + bar.update(p+1)
  238 + if self.header.interleave == "bsq":
  239 + band = numpy.zeros(mask.shape, dtype=self.header.data_type)
  240 + i = numpy.nonzero(mask)
  241 + bar = progressbar.ProgressBar(max_value=B)
  242 + for b in range(0, B):
  243 + self.file.seek(b * X * Y * type_bytes)
  244 + self.file.readinto(band)
  245 + M[b, :] = band[i]
  246 + bar.update(b+1)
  247 + if self.header.interleave == "bil":
  248 + plane = numpy.zeros((B, X), dtype=self.header.data_type)
  249 + p = 0
  250 + bar = progressbar.ProgressBar(max_value=Y)
  251 + for l in range(0, Y):
  252 + i = numpy.flatnonzero(mask[l, :])
  253 + self.file.readinto(plane)
  254 + M[:, p:p+i.shape[0]] = plane[:, i]
  255 + p = p + i.shape[0]
  256 + bar.update(l+1)
  257 + return M
  258 +
  259 +
  260 +
  261 + def __del__(self):
  262 + self.file.close()
0 263 \ No newline at end of file
... ...
python/spectral.py 0 → 100644
  1 +# -*- coding: utf-8 -*-
  2 +"""
  3 +Created on Sun Jul 23 13:52:22 2017
  4 +
  5 +@author: david
  6 +"""
  7 +import numpy
  8 +
  9 +#sift a 2D hyperspectral image into a PxB matrix where P is the number of pixels and B is the number of bands
  10 +def sift2(I, mask = []):
  11 +
  12 + #get the shape of the input array
  13 + S = I.shape
  14 +
  15 + #convert that array into a 1D matrix
  16 + M = numpy.reshape(I, (S[2], S[0] * S[1]))
  17 +
  18 + #gif no mask is provided, just return all pixels
  19 + if mask == []:
  20 + return M
  21 +
  22 + #if a mask is provided, only return pixels corresponding to that mask
  23 + flatmask = numpy.reshape(mask, (S[0] * S[1]))
  24 + i = numpy.flatnonzero(flatmask) #get the nonzero indices
  25 + return M[:, i] #return pixels corresponding to the masked values
  26 +
  27 +def unsift2(M, mask):
  28 +
  29 + #get the size of the input matrix
  30 + S = M.shape
  31 +
  32 + #count the number of nonzero values in the mask
  33 + nnz = numpy.count_nonzero(mask)
  34 +
  35 + #the number of masked values should be the same as the number of pixels in the input matrix
  36 + if len(S) == 1:
  37 + if not S[0] == nnz:
  38 + print("ERROR: expected " + str(nnz) + " pixels based on the mask but there are " + str(S[0]) + " in the matrix.")
  39 + elif not S[1] == nnz:
  40 + print("ERROR: expected " + str(nnz) + " pixels based on the mask but there are " + str(S[1]) + " in the matrix.")
  41 +
  42 +
  43 + i = numpy.nonzero(mask)
  44 +
  45 + if len(S) == 1:
  46 + I = numpy.zeros((1, mask.shape[0], mask.shape[1]), dtype=M.dtype)
  47 + else:
  48 + I = numpy.zeros((M.shape[0], mask.shape[0], mask.shape[1]), dtype=M.dtype)
  49 + I[:, i[0], i[1]] = M
  50 + return I
0 51 \ No newline at end of file
... ...