Commit fad1a72c9bb82576dcd0ff7eaa661d1c3f586662
merged sebastian and my work
Showing
2 changed files
with
46 additions
and
5 deletions
Show diff stats
python/classify.py
... | ... | @@ -7,6 +7,8 @@ Created on Sun Jul 23 16:04:33 2017 |
7 | 7 | |
8 | 8 | import numpy |
9 | 9 | import colorsys |
10 | +import sklearn | |
11 | +import sklearn.metrics | |
10 | 12 | from scipy import misc |
11 | 13 | from envi import envi |
12 | 14 | |
... | ... | @@ -89,15 +91,39 @@ def prob2class(prob_image): |
89 | 91 | class_image = numpy.zeros_like(prob_image) |
90 | 92 | #get nonzero indices |
91 | 93 | nnz_idx = numpy.transpose(numpy.nonzero(numpy.sum(prob_image, axis=0))) |
92 | - | |
94 | + | |
93 | 95 | #set pixel corresponding to max probability to 1 |
94 | 96 | for idx in nnz_idx: |
95 | 97 | idx_max_prob = numpy.argmax(prob_image[:, idx[0], idx[1]]) |
96 | 98 | class_image[idx_max_prob, idx[0], idx[1]] = 1 |
97 | 99 | |
98 | 100 | return class_image |
101 | +#calculate an ROC curve given a probability image and mask of "True" values | |
102 | +def image2roc(P, t_vals, mask=[]): | |
103 | + | |
104 | + if not P.shape == t_vals.shape: | |
105 | + print("ERROR: the probability and mask images must be the same shape") | |
106 | + return | |
107 | + | |
108 | + #if a mask image isn't provided, create one for the entire image | |
109 | + if mask == []: | |
110 | + mask = numpy.ones(t_vals.shape, dtype=numpy.bool) | |
111 | + | |
112 | + #create masks for the positive and negative probability scores | |
113 | + mask_p = t_vals | |
114 | + mask_n = mask - mask * t_vals | |
115 | + | |
116 | + #calculate the indices for the positive and negative scores | |
117 | + idx_p = numpy.nonzero(mask_p) | |
118 | + idx_n = numpy.nonzero(mask_n) | |
119 | + | |
120 | + Pp = P[idx_p] | |
121 | + Pn = P[idx_n] | |
99 | 122 | |
100 | - | |
101 | -#create an ROC curve calculator | |
102 | -#input: X x Y x C image giving the probability P(c | x,y) | |
103 | -#output: ROC curve | |
104 | 123 | \ No newline at end of file |
124 | + Lp = numpy.ones((Pp.shape), dtype=numpy.bool) | |
125 | + Ln = numpy.zeros((Pn.shape), dtype=numpy.bool) | |
126 | + | |
127 | + scores = numpy.concatenate((Pp, Pn)) | |
128 | + labels = numpy.concatenate((Lp, Ln)) | |
129 | + | |
130 | + return sklearn.metrics.roc_curve(labels, scores) | ... | ... |
python/envi.py
... | ... | @@ -263,6 +263,21 @@ class envi: |
263 | 263 | p = p + i.shape[0] |
264 | 264 | bar.update(l+1) |
265 | 265 | return M |
266 | + | |
267 | + def loadband(self, n): | |
268 | + X = self.header.samples | |
269 | + Y = self.header.lines | |
270 | + B = self.header.bands | |
271 | + | |
272 | + band = numpy.zeros((Y, X), dtype=self.header.data_type) | |
273 | + type_bytes = numpy.dtype(self.header.data_type).itemsize | |
274 | + | |
275 | + if self.header.interleave == "bsq": | |
276 | + self.file.seek(n * X * Y * type_bytes) | |
277 | + self.file.readinto(band) | |
278 | + | |
279 | + return band | |
280 | + | |
266 | 281 | |
267 | 282 | def __del__(self): |
268 | 283 | self.file.close() |
269 | 284 | \ No newline at end of file | ... | ... |