Commit 90c935e3f1dae3308b3f8987866a9b8836e6321c
1 parent
538df2a2
updates
Showing
2 changed files
with
46 additions
and
3 deletions
Show diff stats
python/classify.py
@@ -7,6 +7,8 @@ Created on Sun Jul 23 16:04:33 2017 | @@ -7,6 +7,8 @@ Created on Sun Jul 23 16:04:33 2017 | ||
7 | 7 | ||
8 | import numpy | 8 | import numpy |
9 | import colorsys | 9 | import colorsys |
10 | +import sklearn | ||
11 | +import sklearn.metrics | ||
10 | 12 | ||
11 | #generate a 2D color class map using a stack of binary class images | 13 | #generate a 2D color class map using a stack of binary class images |
12 | def classcolor2(C): | 14 | def classcolor2(C): |
@@ -39,6 +41,32 @@ def classcolor2(C): | @@ -39,6 +41,32 @@ def classcolor2(C): | ||
39 | #input: X x Y x C image giving the probability P(c |x,y) | 41 | #input: X x Y x C image giving the probability P(c |x,y) |
40 | #output: X x Y x C binary class image | 42 | #output: X x Y x C binary class image |
41 | 43 | ||
42 | -#create an ROC curve calculator | ||
43 | -#input: X x Y x C image giving the probability P(c | x,y) | ||
44 | -#output: ROC curve | ||
45 | \ No newline at end of file | 44 | \ No newline at end of file |
45 | +#calculate an ROC curve given a probability image and mask of "True" values | ||
46 | +def image2roc(P, t_vals, mask=[]): | ||
47 | + | ||
48 | + if not P.shape == t_vals.shape: | ||
49 | + print("ERROR: the probability and mask images must be the same shape") | ||
50 | + return | ||
51 | + | ||
52 | + #if a mask image isn't provided, create one for the entire image | ||
53 | + if mask == []: | ||
54 | + mask = numpy.ones(t_vals.shape, dtype=numpy.bool) | ||
55 | + | ||
56 | + #create masks for the positive and negative probability scores | ||
57 | + mask_p = t_vals | ||
58 | + mask_n = mask - mask * t_vals | ||
59 | + | ||
60 | + #calculate the indices for the positive and negative scores | ||
61 | + idx_p = numpy.nonzero(mask_p) | ||
62 | + idx_n = numpy.nonzero(mask_n) | ||
63 | + | ||
64 | + Pp = P[idx_p] | ||
65 | + Pn = P[idx_n] | ||
66 | + | ||
67 | + Lp = numpy.ones((Pp.shape), dtype=numpy.bool) | ||
68 | + Ln = numpy.zeros((Pn.shape), dtype=numpy.bool) | ||
69 | + | ||
70 | + scores = numpy.concatenate((Pp, Pn)) | ||
71 | + labels = numpy.concatenate((Lp, Ln)) | ||
72 | + | ||
73 | + return sklearn.metrics.roc_curve(labels, scores) | ||
46 | \ No newline at end of file | 74 | \ No newline at end of file |
python/envi.py
@@ -254,6 +254,21 @@ class envi: | @@ -254,6 +254,21 @@ class envi: | ||
254 | p = p + i.shape[0] | 254 | p = p + i.shape[0] |
255 | bar.update(l+1) | 255 | bar.update(l+1) |
256 | return M | 256 | return M |
257 | + | ||
258 | + def loadband(self, n): | ||
259 | + X = self.header.samples | ||
260 | + Y = self.header.lines | ||
261 | + B = self.header.bands | ||
262 | + | ||
263 | + band = numpy.zeros((Y, X), dtype=self.header.data_type) | ||
264 | + type_bytes = numpy.dtype(self.header.data_type).itemsize | ||
265 | + | ||
266 | + if self.header.interleave == "bsq": | ||
267 | + self.file.seek(n * X * Y * type_bytes) | ||
268 | + self.file.readinto(band) | ||
269 | + | ||
270 | + return band | ||
271 | + | ||
257 | 272 | ||
258 | def __del__(self): | 273 | def __del__(self): |
259 | self.file.close() | 274 | self.file.close() |
260 | \ No newline at end of file | 275 | \ No newline at end of file |