Blame view

python/classify.py 6.15 KB
18368aa9   David Mayerich   added a new set o...
1
2
3
4
5
6
7
8
9
  # -*- coding: utf-8 -*-
  """
  Created on Sun Jul 23 16:04:33 2017
  
  @author: david
  """
  
  import numpy
  import colorsys
90c935e3   David Mayerich   updates
10
11
  import sklearn
  import sklearn.metrics
f275c01c   David Mayerich   finalized the pre...
12
13
14
  import scipy
  import scipy.misc
  import envi
9b3cbdda   David Mayerich   changed the name ...
15
  import hyperspectral
777fef6c   David Mayerich   added digital sta...
16
  import random
35534291   sberisha   envi.py: added fu...
17
  import pyprind
59f31ea4   David Mayerich   fixed bugs in env...
18
  import matplotlib.pyplot as plt
18368aa9   David Mayerich   added a new set o...
19
  
254e3bc8   David Mayerich   added a full clas...
20
21
22
23
24
25
26
27
28
29
30
31
32
33
  #generate N qualitative colors and return the value for color c
  def qualcolor(c, N):
      """Return color number ``c`` of an ``N``-entry qualitative palette.

      Hue advances linearly with c/N, while saturation and value oscillate
      (between 0.5 and 1.0) ceil(sqrt(N)) times across the palette, so that
      neighbouring hues still differ in brightness/saturation.

      Returns a length-3 numpy array of RGB components scaled to [0, 255]
      (floats, not yet rounded).
      """
      frac = c / N
      cycles = numpy.ceil(numpy.sqrt(N)).astype(numpy.int32)
      phase = frac * 2 * numpy.pi * cycles

      saturation = 0.75 + 0.25 * numpy.sin(phase + numpy.pi / 2)
      value = 0.75 + 0.25 * numpy.sin(phase)

      return 255 * numpy.array(colorsys.hsv_to_rgb(frac, saturation, value))
  
18368aa9   David Mayerich   added a new set o...
34
  #generate a 2D color class map using a stack of binary class images
f275c01c   David Mayerich   finalized the pre...
35
36
  #input: C is a C x Y x X binary image
  #output: an RGB color image with a unique color for each class
71ee1c23   David Mayerich   added additional ...
37
  def class2color(C):
      """Render a stack of binary class masks as an RGB image.

      Parameters
      ----------
      C : array, shape (nc, ...)
          One mask per class along axis 0 (C x Y x X in the usual case).
          Any nonzero entry marks class membership.

      Returns
      -------
      numpy.ndarray of uint8, shape C.shape[1:] + (3,)
          Pixels of class c get the color ``qualcolor(c, nc)``; pixels in no
          class stay black. Where masks overlap, the highest class index wins
          because it is written last.
      """
      C = numpy.asarray(C)
      nc = C.shape[0]

      # all-black RGB image with the spatial shape of a single mask
      RGB = numpy.zeros(C.shape[1:] + (3,), dtype=numpy.ubyte)

      for c in range(nc):
          # cast to bool: an integer 0/1 mask would otherwise be interpreted
          # as fancy (row) indices instead of a per-pixel selection
          RGB[C[c].astype(bool)] = qualcolor(c, nc)

      return RGB
  
  #load a set of class images as a stack of binary masks (one mask per image)
  #input: list of class image names
f275c01c   David Mayerich   finalized the pre...
57
  #output: C x Y x X binary image specifying class/pixel membership
538df2a2   David Mayerich   added function de...
58
  #example: filenames2class(("class_coll.bmp", "class_epith.bmp"))
59f31ea4   David Mayerich   fixed bugs in env...
59
  def filenames2class(masks):
      """Load a list of class-mask images as a stack of binary masks.

      Parameters
      ----------
      masks : sequence of str
          Image file names, one per class, e.g.
          ``filenames2class(("class_coll.bmp", "class_epith.bmp"))``.

      Returns
      -------
      numpy.ndarray of bool, shape (len(masks), Y, X), or None
          Entry [c, y, x] is True where pixel (y, x) is nonzero in image c.
          Pixels set in more than one image are ambiguous and are cleared in
          every class. Returns None (after printing a usage message) when
          ``masks`` is empty.
      """
      num_masks = len(masks)

      if num_masks == 0:
          print("ERROR: mask filenames not provided")
          print("Usage example: filenames2class(('class_coll.bmp', 'class_epith.bmp'))")
          return None

      classimages = []
      bar = pyprind.ProgBar(num_masks)
      for filename in masks:
          # NOTE(review): scipy.misc.imread was removed in SciPy 1.2; this call
          # needs an older SciPy (or a port to imageio/PIL) to run.
          img = scipy.misc.imread(filename, flatten=True).astype(bool)
          classimages.append(img)
          bar.update()

      result = numpy.stack(classimages)

      # a pixel claimed by more than one class is ambiguous: drop it from all
      overlap = numpy.sum(result, axis=0, dtype=numpy.uint32) > 1
      result[:, overlap] = False

      return result
6b2be991   sberisha   added utility fun...
85
  
6b2be991   sberisha   added utility fun...
86
  
f275c01c   David Mayerich   finalized the pre...
87
88
89
  #create a class mask stack from an C x Y x X probability image
  #input: C x Y x X image giving the probability P(c |x,y)
  #output: C x Y x X binary class image
6b2be991   sberisha   added utility fun...
90
  def prob2class(prob_image):
      """Convert a per-class probability image into a one-hot class-mask stack.

      Parameters
      ----------
      prob_image : array, shape (C, ...)
          Probabilities P(c | pixel) along axis 0 (C x Y x X in the usual case;
          any number of trailing spatial dimensions is accepted).

      Returns
      -------
      numpy.ndarray of bool, same shape as prob_image
          For every pixel whose probabilities do not sum to zero, exactly one
          entry along axis 0 is True: the most probable class (ties go to the
          lowest class index, as with numpy.argmax). Pixels whose
          probabilities sum to zero are left all-False (unclassified).
      """
      prob_image = numpy.asarray(prob_image)
      class_image = numpy.zeros(prob_image.shape, dtype=bool)

      # pixels carrying any probability mass; all-zero pixels stay unclassified
      pixels = numpy.nonzero(numpy.sum(prob_image, axis=0))

      # most probable class at each of those pixels (vectorized argmax)
      winners = numpy.argmax(prob_image, axis=0)[pixels]

      class_image[(winners,) + pixels] = True
      return class_image
f275c01c   David Mayerich   finalized the pre...
101
  
90c935e3   David Mayerich   updates
102
  #calculate an ROC curve given a probability image and mask of "True" values
f275c01c   David Mayerich   finalized the pre...
103
104
105
106
107
108
109
110
111
112
113
114
  #input:
  #       P is a Y x X probability image specifying P(c | x,y)
  #       t_vals is a Y x X binary image specifying points where x,y = c
  #       mask is a mask specifying all pixels to be considered (positives and negatives)
  #           use this mask to limit analysis to regions of the image that have been classified
  #output: fpr, tpr, thresholds
  #       fpr is the false-positive rate (x-axis of an ROC curve)
  #       tpr is the true-positive rate (y-axis of an ROC curve)
  #       thresholds stores the threshold associated with each point on the ROC curve
  #
  #note: the AUC can be calculated as auc = sklearn.metrics.auc(fpr, tpr)
  def prob2roc(P, t_vals, mask=None):
      """Compute an ROC curve from a probability image and a ground-truth mask.

      Parameters
      ----------
      P : array, shape (Y, X)
          Score/probability P(c | x, y) for a single class c.
      t_vals : array, same shape as P
          Binary ground truth; nonzero where pixel (x, y) truly belongs to c.
      mask : array, same shape as P, optional
          Nonzero for every pixel (positive or negative) that takes part in
          the analysis, e.g. to restrict the curve to annotated regions.
          ``None`` (or the legacy ``[]``) uses the whole image.

      Returns
      -------
      (fpr, tpr, thresholds)
          The output of ``sklearn.metrics.roc_curve``; the AUC can be computed
          as ``sklearn.metrics.auc(fpr, tpr)``. Returns None (after printing
          an error) when P and t_vals differ in shape.
      """
      P = numpy.asarray(P)
      t_vals = numpy.asarray(t_vals)

      if P.shape != t_vals.shape:
          print("ERROR: the probability and mask images must be the same shape")
          return None

      # `mask == []` compares elementwise once mask is an array, so test for
      # "no mask" explicitly; an empty list is still accepted for old callers
      if mask is None or numpy.size(mask) == 0:
          mask = numpy.ones(t_vals.shape, dtype=bool)
      else:
          mask = numpy.asarray(mask, dtype=bool)

      truth = t_vals.astype(bool)
      # logical ops rather than `mask - mask * t_vals`: numpy rejects `-` on
      # boolean arrays. Both classes are restricted to the analysis mask.
      positives = mask & truth
      negatives = mask & ~truth

      Pp = P[positives]
      Pn = P[negatives]

      scores = numpy.concatenate((Pp, Pn))
      labels = numpy.concatenate((numpy.ones(Pp.shape, dtype=bool),
                                  numpy.zeros(Pn.shape, dtype=bool)))

      return sklearn.metrics.roc_curve(labels, scores)
f275c01c   David Mayerich   finalized the pre...
142
  
71ee1c23   David Mayerich   added additional ...
143
  #convert a label image to a C x Y x X class image
777fef6c   David Mayerich   added digital sta...
144
  def label2class(L, background=None):
      """Convert a label image into a stack of binary class masks.

      Parameters
      ----------
      L : array
          Label image (Y x X in the usual case); each distinct value is a class.
      background : scalar, optional
          Label value to drop instead of giving it its own mask. ``None``
          (or the legacy ``[]``) keeps every label.

      Returns
      -------
      numpy.ndarray of bool, shape (n_labels,) + L.shape
          Mask i is True where L equals the i-th retained label, with labels
          in the sorted order returned by numpy.unique.
      """
      L = numpy.asarray(L)
      labels = numpy.unique(L)

      # `background == []` is unreliable for numpy scalars, so test for "unset"
      # explicitly; an empty list is still accepted for older callers
      has_background = not (background is None or
                            (isinstance(background, list) and len(background) == 0))
      if has_background:
          labels = labels[labels != background]

      # broadcast labels as (n, 1, ..., 1) against L to build all masks at once
      return labels.reshape((-1,) + (1,) * L.ndim) == L
  
777fef6c   David Mayerich   added digital sta...
156
157
158
159
160
161
162
163
  #randomizes a given mask to include a subset of n pixels in the original
  def random_mask(M, n):
      """Return a random subset of the nonzero pixels of mask M.

      Parameters
      ----------
      M : array
          Input mask; every nonzero element is a candidate.
      n : int
          Number of candidates to keep (assumed >= 0). If M has fewer than
          n candidates, all of them are kept.

      Returns
      -------
      numpy.ndarray of bool, same shape as M
          True at (up to) n positions drawn without replacement from the
          nonzero positions of M. Uses the global numpy.random state, so seed
          it for reproducible results.
      """
      M = numpy.asarray(M)
      candidates = numpy.random.permutation(numpy.flatnonzero(M))
      new_mask = numpy.zeros(M.shape, dtype=bool)
      new_mask[numpy.unravel_index(candidates[:n], M.shape)] = True
      return new_mask
  
eccf10ff   David Mayerich   added comments fo...
164
165
166
167
  #perform classification of an ENVI image using batch processing
  # input:    E is the ENVI object (file is assumed to be loaded)
  #           C is a classifier - anything in sklearn should work
  #           batch is the batch size
59f31ea4   David Mayerich   fixed bugs in env...
168
169
  def envi_batch_predict(E, C, batch=10000):
  
254e3bc8   David Mayerich   added a full clas...
170
      Fv = E.loadbatch(batch)
59f31ea4   David Mayerich   fixed bugs in env...
171
172
173
      i = 0
      Tv = []
      plt.ion()
35534291   sberisha   envi.py: added fu...
174
175
      #bar = progressbar.ProgressBar(max_value=numpy.count_nonzero(E.mask))
      bar = pyprind.ProgBar(numpy.count_nonzero(E.mask))
59f31ea4   David Mayerich   fixed bugs in env...
176
      while not Fv == []:
ed3d7d30   David Mayerich   updated NWT file ...
177
          Fv = numpy.nan_to_num(Fv)                                                     #remove infinite values        
59f31ea4   David Mayerich   fixed bugs in env...
178
          if i == 0:
254e3bc8   David Mayerich   added a full clas...
179
              Tv = C.predict(Fv.transpose())
59f31ea4   David Mayerich   fixed bugs in env...
180
          else:
254e3bc8   David Mayerich   added a full clas...
181
              Tv = numpy.concatenate((Tv, C.predict(Fv.transpose()).transpose()), 0)
59f31ea4   David Mayerich   fixed bugs in env...
182
          tempmask = E.batchmask()
9b3cbdda   David Mayerich   changed the name ...
183
          Lv = hyperspectral.unsift2(Tv, tempmask)
59f31ea4   David Mayerich   fixed bugs in env...
184
185
186
187
          Cv = label2class(Lv.squeeze(), background=0)
          RGB = class2color(Cv)
          plt.imshow(RGB)
          plt.pause(0.05)
254e3bc8   David Mayerich   added a full clas...
188
          Fv = E.loadbatch(batch)   
59f31ea4   David Mayerich   fixed bugs in env...
189
          i = i + 1
35534291   sberisha   envi.py: added fu...
190
191
          #bar.update(len(Tv))
          bar.update()