diff --git a/python/classify.py b/python/classify.py index 96a2a6b..ab5fa97 100644 --- a/python/classify.py +++ b/python/classify.py @@ -7,6 +7,8 @@ Created on Sun Jul 23 16:04:33 2017 import numpy import colorsys +import sklearn +import sklearn.metrics #generate a 2D color class map using a stack of binary class images def classcolor2(C): @@ -39,6 +41,32 @@ def classcolor2(C): #input: X x Y x C image giving the probability P(c |x,y) #output: X x Y x C binary class image -#create an ROC curve calculator -#input: X x Y x C image giving the probability P(c | x,y) -#output: ROC curve \ No newline at end of file +#calculate an ROC curve given a probability image and mask of "True" values +def image2roc(P, t_vals, mask=[]): + + if not P.shape == t_vals.shape: + print("ERROR: the probability and mask images must be the same shape") + return + + #if a mask image isn't provided, create one for the entire image + if mask == []: + mask = numpy.ones(t_vals.shape, dtype=numpy.bool) + + #create masks for the positive and negative probability scores + mask_p = t_vals + mask_n = mask - mask * t_vals + + #calculate the indices for the positive and negative scores + idx_p = numpy.nonzero(mask_p) + idx_n = numpy.nonzero(mask_n) + + Pp = P[idx_p] + Pn = P[idx_n] + + Lp = numpy.ones((Pp.shape), dtype=numpy.bool) + Ln = numpy.zeros((Pn.shape), dtype=numpy.bool) + + scores = numpy.concatenate((Pp, Pn)) + labels = numpy.concatenate((Lp, Ln)) + + return sklearn.metrics.roc_curve(labels, scores) \ No newline at end of file diff --git a/python/envi.py b/python/envi.py index f013b5b..18ef32f 100644 --- a/python/envi.py +++ b/python/envi.py @@ -254,6 +254,21 @@ class envi: p = p + i.shape[0] bar.update(l+1) return M + + def loadband(self, n): + X = self.header.samples + Y = self.header.lines + B = self.header.bands + + band = numpy.zeros((Y, X), dtype=self.header.data_type) + type_bytes = numpy.dtype(self.header.data_type).itemsize + + if self.header.interleave == "bsq": + self.file.seek(n * X * Y * type_bytes) + self.file.readinto(band) + + return band + def __del__(self): self.file.close() \ No newline at end of file -- libgit2 0.21.4