# envi.py 9.76 KB
# -*- coding: utf-8 -*-
"""
Created on Fri Jul 21 20:18:01 2017

@author: david
"""

import os
import numpy
import scipy
import matplotlib.pyplot as plt
import progressbar

class envi_header:
    """Fields of an ENVI ``.hdr`` header file.

    ``envi_header()`` yields a header filled with default values;
    ``envi_header(filename)`` parses *filename*.  ``data_type`` always holds a
    numpy type/dtype (never the integer ENVI code) so it can be handed
    straight to numpy; a header with ``byte order = 1`` gets a big-endian
    dtype.
    """

    # ENVI "data type" code -> numpy scalar type.  ENVI type 1 is an
    # unsigned 8-bit byte, so it maps to uint8 (numpy.byte is signed int8).
    _ENVI_TO_NUMPY = {
        1: numpy.uint8,
        2: numpy.int16,
        3: numpy.int32,
        4: numpy.float32,
        5: numpy.float64,
        6: numpy.complex64,
        9: numpy.complex128,
        12: numpy.uint16,
        13: numpy.uint32,
        14: numpy.int64,
        15: numpy.uint64,
    }

    # (dtype.kind, dtype.itemsize) -> ENVI code.  Keying on kind/size makes
    # aliases (numpy.byte/int8, numpy.intc, ...) and byte-swapped dtypes map
    # to the same code.  ENVI has no signed 8-bit type; int8 is reported as
    # type 1 for backward compatibility.
    _NUMPY_TO_ENVI = {
        ("u", 1): 1,
        ("i", 1): 1,
        ("i", 2): 2,
        ("i", 4): 3,
        ("f", 4): 4,
        ("f", 8): 5,
        ("c", 8): 6,
        ("c", 16): 9,
        ("u", 2): 12,
        ("u", 4): 13,
        ("i", 8): 14,
        ("u", 8): 15,
    }

    def __init__(self, filename=""):
        if filename != "":
            self.load(filename)
        else:
            self.initialize()

    def initialize(self):
        """Reset every field to its default value."""
        self.samples = 0
        self.lines = 0
        self.bands = 0
        self.header_offset = 0
        self.data_type = numpy.float32      # ENVI data type 4
        self.interleave = "bsq"
        self.sensor_type = ""
        self.byte_order = 0                 # 1 selects big-endian in load()
        self.x_start = 0
        self.y_start = 0
        self.z_plot_titles = ""
        self.pixel_size = [0.0, 0.0]
        self.pixel_size_units = "Meters"
        self.wavelength_units = "Wavenumber"
        self.description = ""
        self.band_names = []
        self.wavelength = []

    def get_numpy_type(self, val):
        """Return the numpy scalar type for ENVI data-type code *val*.

        Raises:
            ValueError: *val* is not a supported ENVI code.  (Returning None
                would make numpy silently fall back to float64.)
        """
        try:
            return self._ENVI_TO_NUMPY[val]
        except KeyError:
            raise ValueError("unsupported ENVI data type: " + str(val)) from None

    def get_envi_type(self, val):
        """Return the ENVI data-type code for a numpy type or dtype *val*.

        Byte order is ignored, so big- and little-endian dtypes give the
        same code.

        Raises:
            ValueError: *val* has no ENVI equivalent.
        """
        dt = numpy.dtype(val)
        try:
            return self._NUMPY_TO_ENVI[(dt.kind, dt.itemsize)]
        except KeyError:
            raise ValueError("unsupported numpy data type for ENVI: " + str(dt)) from None

    @staticmethod
    def _braced(value):
        """Return the text between the first '{' and the last '}', stripped.

        Missing braces are tolerated: without '{' the text starts at the
        beginning of *value*; without a closing '}' it runs to the end.
        """
        start = value.find("{")
        end = value.rfind("}")
        if end <= start:
            end = len(value)
        return value[start + 1:end].strip()

    def load(self, fname):
        """Parse the ENVI header file *fname* into this object.

        All fields are reset to their defaults first, so keys missing from
        the file, or a parse aborted by an error, leave well-defined values.
        Format errors are reported with print() and stop the parse.
        """
        self.initialize()
        with open(fname) as f:
            lines = f.readlines()

        if not lines or lines[0].strip() != "ENVI":
            print("ERROR: not an ENVI file")
            return

        li = 1
        while li < len(lines):
            line = lines[li]
            value = line.partition("=")[2].strip()

            # A braced value may span several lines (possibly with blank
            # ones); gather them until the closing brace or end of file.
            if value.startswith("{"):
                parts = [value]
                while not parts[-1].endswith("}") and li + 1 < len(lines):
                    li += 1
                    parts.append(lines[li].strip())
                value = " ".join(part for part in parts if part)

            if line.startswith("file type"):
                if not line.strip().endswith("ENVI Standard"):
                    print("ERROR: unsupported ENVI file format: " + line.strip())
                    return
            elif line.startswith("samples"):
                self.samples = int(value)
            elif line.startswith("lines"):
                self.lines = int(value)
            elif line.startswith("bands"):
                self.bands = int(value)
            elif line.startswith("header offset"):
                self.header_offset = int(value)
            elif line.startswith("data type"):
                self.data_type = self.get_numpy_type(int(value))
            elif line.startswith("interleave"):
                self.interleave = value.lower()
            elif line.startswith("sensor type"):
                self.sensor_type = value
            elif line.startswith("byte order"):
                self.byte_order = int(value)
            elif line.startswith("x start"):
                self.x_start = int(value)
            elif line.startswith("y start"):
                self.y_start = int(value)
            elif line.startswith("z plot titles"):
                self.z_plot_titles = self._braced(value)
            elif line.startswith("pixel size"):
                fields = self._braced(value).split(",")
                self.pixel_size = [float(fields[0]), float(fields[1])]
                if len(fields) > 2:             # optional "units=..." entry
                    self.pixel_size_units = fields[2].rpartition("=")[2].strip()
            # must precede the plain "wavelength" test, which it would match
            elif line.startswith("wavelength units"):
                self.wavelength_units = value
            elif line.startswith("description"):
                self.description = self._braced(value)
            elif line.startswith("band names"):
                self.band_names = [name.strip() for name in self._braced(value).split(",")]
            elif line.startswith("wavelength"):
                # skip empty entries left by a trailing comma
                self.wavelength = [float(w) for w in self._braced(value).split(",") if w.strip()]

            li += 1

        # byte order 1 marks big-endian data; encode it in the dtype so every
        # numpy read of this file interprets the bytes correctly
        if self.byte_order == 1:
            self.data_type = numpy.dtype(self.data_type).newbyteorder(">")
        
class envi:
    """Read-only access to an ENVI raster (raw data file plus ``.hdr`` header).

    Image planes use the (lines, samples) == (Y, X) layout.  Whole cubes are
    returned as (bands, lines, samples) == (B, Y, X) whatever the on-disk
    interleave ("bsq", "bip" or "bil").  All reads honour the header's
    ``header_offset``.

    Attributes:
        header: the parsed envi_header.
        file:   the data file, opened in binary mode.
        mask:   boolean (Y, X) array; all True unless a mask image was given.
    """

    def __init__(self, filename, headername = "", maskname = ""):
        """Open *filename*; the header defaults to filename + ".hdr".

        If *maskname* is given, that image is read as the pixel mask
        (non-zero pixels are selected); otherwise every pixel is selected.
        """
        self.open(filename, headername)
        if maskname == "":
            # (lines, samples): the layout loadmask() and loadband() use
            self.mask = numpy.ones((self.header.lines, self.header.samples), dtype=bool)
        else:
            self.mask = self._read_mask(maskname)

    @staticmethod
    def _read_mask(maskname):
        """Read image *maskname* and return a boolean mask, True where non-zero.

        For colour images a pixel is selected when any colour channel is
        non-zero; an alpha channel is ignored.
        """
        # scipy.misc.imread was removed in SciPy 1.2; matplotlib, already a
        # dependency of this module, provides an equivalent reader.
        from matplotlib.image import imread
        img = numpy.asarray(imread(maskname))
        if img.ndim == 3:
            return (img[..., :3] != 0).any(axis=-1)
        return img != 0

    def open(self, filename, headername = ""):
        """Open data file *filename* and parse its header.

        *headername* defaults to filename + ".hdr".  A file previously opened
        by this object is closed first.

        Raises:
            FileNotFoundError: the data or header file does not exist.
        """
        if headername == "":
            headername = filename + ".hdr"

        for name in (filename, headername):
            if not os.path.isfile(name):
                raise FileNotFoundError(name + " not found")

        self.close()
        self.header = envi_header(headername)
        self.file = open(filename, "rb")

    def close(self):
        """Close the data file if one is open; safe to call repeatedly."""
        f = getattr(self, "file", None)
        if f is not None:
            f.close()

    def _itemsize(self):
        """Size in bytes of one stored value."""
        return numpy.dtype(self.header.data_type).itemsize

    def _seek(self, element):
        """Position the file at value index *element* of the raster data."""
        self.file.seek(int(self.header.header_offset + element * self._itemsize()))

    def _readinto(self, buffer):
        """Fill contiguous numpy array *buffer* from the current file position.

        Raises:
            EOFError: the file ends before *buffer* is full.
        """
        if self.file.readinto(buffer) != buffer.nbytes:
            raise EOFError("ENVI data file is shorter than its header describes")

    def _check_interleave(self):
        """Return the header's interleave, raising ValueError if unsupported."""
        interleave = self.header.interleave
        if interleave not in ("bsq", "bip", "bil"):
            raise ValueError("unsupported interleave: " + str(interleave))
        return interleave

    def loadall(self):
        """Read the whole cube and return it as a (bands, lines, samples) array.

        Raises:
            ValueError: unsupported interleave.
            EOFError: the data file is shorter than the header describes.
        """
        X = self.header.samples
        Y = self.header.lines
        B = self.header.bands
        interleave = self._check_interleave()

        # read exactly X*Y*B values after the header offset, independent of
        # where earlier reads left the file position
        self._seek(0)
        D = numpy.fromfile(self.file, dtype=self.header.data_type, count=X * Y * B)
        if D.size != X * Y * B:
            raise EOFError("ENVI data file is shorter than its header describes")

        if interleave == "bsq":
            return D.reshape(B, Y, X)
        if interleave == "bip":
            return numpy.moveaxis(D.reshape(Y, X, B), 2, 0)
        return numpy.moveaxis(D.reshape(Y, B, X), 1, 0)

    def loadmask(self, mask=None):
        """Return the spectra of all pixels where *mask* is non-zero.

        Args:
            mask: (lines, samples) array; defaults to ``self.mask``.

        Returns:
            (bands, P) array, P = number of selected pixels, in row-major
            pixel order.

        Raises:
            ValueError: mask shape mismatch or unsupported interleave.
            EOFError: the data file is shorter than the header describes.
        """
        X = self.header.samples
        Y = self.header.lines
        B = self.header.bands
        dtype = self.header.data_type
        interleave = self._check_interleave()

        if mask is None:
            mask = self.mask
        mask = numpy.asarray(mask)
        if mask.shape != (Y, X):
            raise ValueError("mask shape " + str(mask.shape) + " does not match (lines, samples) " + str((Y, X)))

        P = numpy.count_nonzero(mask)           # number of selected pixels
        M = numpy.zeros((B, P), dtype=dtype)

        if interleave == "bip":
            # each pixel's spectrum is B consecutive values
            spectrum = numpy.zeros(B, dtype=dtype)
            pixels = numpy.flatnonzero(mask)
            bar = progressbar.ProgressBar(max_value = P)
            for p, pixel in enumerate(pixels):
                self._seek(pixel * B)
                self._readinto(spectrum)
                M[:, p] = spectrum
                bar.update(p + 1)
        elif interleave == "bsq":
            # each band is one contiguous (Y, X) plane
            band = numpy.zeros((Y, X), dtype=dtype)
            selected = numpy.nonzero(mask)
            bar = progressbar.ProgressBar(max_value=B)
            for b in range(B):
                self._seek(b * X * Y)
                self._readinto(band)
                M[b, :] = band[selected]
                bar.update(b + 1)
        else:
            # bil: each line is a contiguous (B, X) plane, read sequentially
            plane = numpy.zeros((B, X), dtype=dtype)
            self._seek(0)
            p = 0
            bar = progressbar.ProgressBar(max_value=Y)
            for y in range(Y):
                cols = numpy.flatnonzero(mask[y, :])
                self._readinto(plane)
                M[:, p:p + cols.size] = plane[:, cols]
                p += cols.size
                bar.update(y + 1)
        return M

    def loadband(self, n):
        """Return band *n* as a (lines, samples) array.

        Raises:
            IndexError: *n* is not in range(bands).
            ValueError: unsupported interleave.
            EOFError: the data file is shorter than the header describes.
        """
        X = self.header.samples
        Y = self.header.lines
        B = self.header.bands
        dtype = self.header.data_type
        interleave = self._check_interleave()
        if not 0 <= n < B:
            raise IndexError("band index " + str(n) + " out of range for " + str(B) + " bands")

        band = numpy.zeros((Y, X), dtype=dtype)
        if interleave == "bsq":
            self._seek(n * X * Y)
            self._readinto(band)
        elif interleave == "bil":
            # line y holds B runs of X values; band n is run n of each line
            for y in range(Y):
                self._seek((y * B + n) * X)
                self._readinto(band[y])
        else:
            # bip: line y holds X spectra of B values; take element n of each
            line = numpy.zeros((X, B), dtype=dtype)
            self._seek(0)
            for y in range(Y):
                self._readinto(line)
                band[y, :] = line[:, n]
        return band

    def __del__(self):
        self.close()