Commit 90c935e3f1dae3308b3f8987866a9b8836e6321c
1 parent
538df2a2
updates
Showing
2 changed files
with
46 additions
and
3 deletions
Show diff stats
python/classify.py
... | ... | @@ -7,6 +7,8 @@ Created on Sun Jul 23 16:04:33 2017 |
7 | 7 | |
8 | 8 | import numpy |
9 | 9 | import colorsys |
10 | +import sklearn | |
11 | +import sklearn.metrics | |
10 | 12 | |
11 | 13 | #generate a 2D color class map using a stack of binary class images |
12 | 14 | def classcolor2(C): |
... | ... | @@ -39,6 +41,32 @@ def classcolor2(C): |
39 | 41 | #input: X x Y x C image giving the probability P(c |x,y) |
40 | 42 | #output: X x Y x C binary class image |
41 | 43 | |
42 | -#create an ROC curve calculator | |
43 | -#input: X x Y x C image giving the probability P(c | x,y) | |
44 | -#output: ROC curve | |
45 | 44 | \ No newline at end of file |
45 | +#calculate an ROC curve given a probability image and mask of "True" values | |
46 | +def image2roc(P, t_vals, mask=[]): | |
47 | + | |
48 | + if not P.shape == t_vals.shape: | |
49 | + print("ERROR: the probability and mask images must be the same shape") | |
50 | + return | |
51 | + | |
52 | + #if a mask image isn't provided, create one for the entire image | |
53 | + if mask == []: | |
54 | + mask = numpy.ones(t_vals.shape, dtype=numpy.bool) | |
55 | + | |
56 | + #create masks for the positive and negative probability scores | |
57 | + mask_p = t_vals | |
58 | + mask_n = mask - mask * t_vals | |
59 | + | |
60 | + #calculate the indices for the positive and negative scores | |
61 | + idx_p = numpy.nonzero(mask_p) | |
62 | + idx_n = numpy.nonzero(mask_n) | |
63 | + | |
64 | + Pp = P[idx_p] | |
65 | + Pn = P[idx_n] | |
66 | + | |
67 | + Lp = numpy.ones((Pp.shape), dtype=numpy.bool) | |
68 | + Ln = numpy.zeros((Pn.shape), dtype=numpy.bool) | |
69 | + | |
70 | + scores = numpy.concatenate((Pp, Pn)) | |
71 | + labels = numpy.concatenate((Lp, Ln)) | |
72 | + | |
73 | + return sklearn.metrics.roc_curve(labels, scores) | |
46 | 74 | \ No newline at end of file | ... | ... |
python/envi.py
... | ... | @@ -254,6 +254,21 @@ class envi: |
254 | 254 | p = p + i.shape[0] |
255 | 255 | bar.update(l+1) |
256 | 256 | return M |
257 | + | |
258 | + def loadband(self, n): | |
259 | + X = self.header.samples | |
260 | + Y = self.header.lines | |
261 | + B = self.header.bands | |
262 | + | |
263 | + band = numpy.zeros((Y, X), dtype=self.header.data_type) | |
264 | + type_bytes = numpy.dtype(self.header.data_type).itemsize | |
265 | + | |
266 | + if self.header.interleave == "bsq": | |
267 | + self.file.seek(n * X * Y * type_bytes) | |
268 | + self.file.readinto(band) | |
269 | + | |
270 | + return band | |
271 | + | |
257 | 272 | |
258 | 273 | def __del__(self): |
259 | 274 | self.file.close() |
260 | 275 | \ No newline at end of file | ... | ... |