Blame view

python/classify.py 6.02 KB
18368aa9   David Mayerich   added a new set o...
1
2
3
4
5
6
7
8
9
  # -*- coding: utf-8 -*-
  """
  Created on Sun Jul 23 16:04:33 2017
  
  @author: david
  """
  
  import numpy
  import colorsys
90c935e3   David Mayerich   updates
10
11
  import sklearn
  import sklearn.metrics
f275c01c   David Mayerich   finalized the pre...
12
13
14
  import scipy
  import scipy.misc
  import envi
9b3cbdda   David Mayerich   changed the name ...
15
  import hyperspectral
777fef6c   David Mayerich   added digital sta...
16
  import random
59f31ea4   David Mayerich   fixed bugs in env...
17
18
  import progressbar
  import matplotlib.pyplot as plt
18368aa9   David Mayerich   added a new set o...
19
  
254e3bc8   David Mayerich   added a full clas...
20
21
22
23
24
25
26
27
28
29
30
31
32
33
  #generate N qualitative colors and return the value for color c
  def qualcolor(c, N):
      """Return an RGB color (floats in 0-255) for class index c out of N classes.

      The hue steps evenly around the color wheel with c / N, while saturation
      and value oscillate between 0.5 and 1.0 (ceil(sqrt(N)) times across the
      palette) so that neighbouring classes stay distinguishable. The two
      oscillations are a quarter period apart (cosine vs. sine).
      """
      frac = c / N
      #number of saturation/value oscillations across the whole palette
      cycles = numpy.ceil(numpy.sqrt(N)).astype(numpy.int32)
      phase = frac * 2 * numpy.pi * cycles

      saturation = 0.75 + 0.25 * numpy.sin(phase + numpy.pi / 2)
      value = 0.75 + 0.25 * numpy.sin(phase)

      return 255 * numpy.array(colorsys.hsv_to_rgb(frac, saturation, value))
  
18368aa9   David Mayerich   added a new set o...
34
  #generate a 2D color class map using a stack of binary class images
  #input: C is a C x Y x X binary image
  #output: an RGB color image with a unique color for each class
  def class2color(C):
      """Render a stack of class masks as a single RGB image.

      Parameters
      ----------
      C : array_like, shape (nc, Y, X)
          One mask per class. Any nonzero entry marks membership, so masks may
          be boolean or 0/1 integers.

      Returns
      -------
      numpy.ndarray of uint8, shape (Y, X, 3)
          Pixel (y, x) takes the color qualcolor(c, nc) of class c wherever
          C[c, y, x] is set. Classes are painted in index order, so when a
          pixel belongs to several classes the highest index wins; pixels in
          no class stay black.
      """
      C = numpy.asarray(C)
      nc = C.shape[0]

      #generate an (empty) RGB image with the spatial shape of the masks
      RGB = numpy.zeros(C.shape[1:] + (3,), dtype=numpy.ubyte)

      for c in range(nc):
          #cast to bool so integer 0/1 masks select pixels instead of being
          #interpreted as fancy (row) indices
          RGB[C[c].astype(bool)] = qualcolor(c, nc)

      return RGB
  
  def _imread_mask(filename):
      """Load an image file as a 2D boolean mask (True wherever the pixel is nonzero)."""
      legacy_imread = getattr(getattr(scipy, "misc", None), "imread", None)
      if legacy_imread is not None:
          #SciPy < 1.2: flatten=True collapses the image to one grayscale channel
          return legacy_imread(filename, flatten=True).astype(bool)

      #scipy.misc.imread was removed in SciPy 1.2; fall back to matplotlib
      img = plt.imread(filename)
      if img.ndim == 3:
          #ignore an alpha channel (LA / RGBA images)
          if img.shape[-1] in (2, 4):
              img = img[..., :-1]
          #a pixel counts as set if any of its color channels is nonzero
          return numpy.any(img != 0, axis=-1)
      return img != 0


  #create a function that loads a set of class images as a stack of binary masks
  #input: list of class image names
  #output: C x Y x X binary image specifying class/pixel membership
  #example: filenames2class(("class_coll.bmp", "class_epith.bmp"))
  def filenames2class(masks):
      """Load one binary mask image per class and stack them.

      Parameters
      ----------
      masks : sequence of str
          Image file names, one per class. numpy.stack requires all images to
          have the same size.

      Returns
      -------
      numpy.ndarray of bool, shape (C, Y, X), or None
          result[c, y, x] is True when pixel (y, x) is nonzero in image c.
          Pixels marked in more than one image are ambiguous and are cleared in
          every class. Returns None (after printing an error) when no file
          names are given.
      """
      num_masks = len(masks)

      if num_masks == 0:
          print("ERROR: mask filenames not provided")
          print("Usage example: filenames2class(('class_coll.bmp', 'class_epith.bmp'))")
          return

      classimages = []
      bar = progressbar.ProgressBar(max_value=num_masks)
      for m, filename in enumerate(masks):
          classimages.append(_imread_mask(filename))
          bar.update(m + 1)
      bar.finish()

      result = numpy.stack(classimages)

      #identify and remove pixels claimed by more than one class
      overlap = result.sum(axis=0) > 1
      result[:, overlap] = False

      return result
6b2be991   sberisha   added utility fun...
83
  
6b2be991   sberisha   added utility fun...
84
  
f275c01c   David Mayerich   finalized the pre...
85
86
87
  #create a class mask stack from an C x Y x X probability image
  #input: C x Y x X image giving the probability P(c |x,y)
  #output: C x Y x X binary class image
  def prob2class(prob_image):
      """Assign every pixel to its most probable class.

      Parameters
      ----------
      prob_image : array_like, shape (C, Y, X)
          prob_image[c, y, x] is the probability (score) of class c at (y, x).

      Returns
      -------
      numpy.ndarray of bool, shape (C, Y, X)
          For each pixel whose scores do not sum to zero, exactly one entry is
          True: the class with the largest score (the lowest class index wins
          a tie). Pixels whose scores sum to zero are left unassigned.
      """
      prob_image = numpy.asarray(prob_image)

      #argmax is undefined over an empty class axis; nothing can be assigned
      if prob_image.shape[0] == 0:
          return numpy.zeros(prob_image.shape, dtype=bool)

      #pixels with no probability mass were never classified
      has_prob = numpy.sum(prob_image, axis=0) != 0

      #index of the most probable class at every pixel
      best = numpy.argmax(prob_image, axis=0)

      #one-hot encode the winning class, restricted to classified pixels
      classes = numpy.arange(prob_image.shape[0]).reshape((-1,) + (1,) * best.ndim)
      return (classes == best) & has_prob
f275c01c   David Mayerich   finalized the pre...
99
  
90c935e3   David Mayerich   updates
100
  #calculate an ROC curve given a probability image and mask of "True" values
  #input:
  #       P is a Y x X probability image specifying P(c | x,y)
  #       t_vals is a Y x X binary image specifying points where x,y = c
  #       mask is a mask specifying all pixels to be considered (positives and negatives)
  #           use this mask to limit analysis to regions of the image that have been classified
  #           (optional: when omitted, every pixel is considered)
  #output: fpr, tpr, thresholds
  #       fpr is the false-positive rate (x-axis of an ROC curve)
  #       tpr is the true-positive rate (y-axis of an ROC curve)
  #       thresholds stores the threshold associated with each point on the ROC curve
  #
  #note: the AUC can be calculated as auc = sklearn.metrics.auc(fpr, tpr)
  def prob2roc(P, t_vals, mask=None):
      """Compute an ROC curve for one class from a per-pixel probability image.

      Pixels outside `mask` are ignored entirely. Inside it, pixels where
      `t_vals` is nonzero are positives and all other pixels are negatives.
      Returns the (fpr, tpr, thresholds) tuple from sklearn.metrics.roc_curve,
      or None (after printing an error) when P and t_vals differ in shape.
      """
      P = numpy.asarray(P)
      t_vals = numpy.asarray(t_vals, dtype=bool)

      if not P.shape == t_vals.shape:
          print("ERROR: the probability and mask images must be the same shape")
          return

      #no mask given (None, or the legacy [] default): consider the entire image
      legacy_empty = isinstance(mask, (list, tuple)) and len(mask) == 0
      if mask is None or legacy_empty:
          mask = numpy.ones(t_vals.shape, dtype=bool)
      else:
          mask = numpy.asarray(mask, dtype=bool)

      #boolean logic instead of subtraction: numpy does not support '-' on bool arrays
      mask_p = mask & t_vals
      mask_n = mask & ~t_vals

      #scores for the positive and negative pixels (row-major order)
      Pp = P[mask_p]
      Pn = P[mask_n]

      scores = numpy.concatenate((Pp, Pn))
      labels = numpy.concatenate((numpy.ones(Pp.shape, dtype=bool),
                                  numpy.zeros(Pn.shape, dtype=bool)))

      return sklearn.metrics.roc_curve(labels, scores)
f275c01c   David Mayerich   finalized the pre...
140
  
71ee1c23   David Mayerich   added additional ...
141
  #convert a label image to a C x Y x X class image
  def label2class(L, background=None):
      """Split a label image into one boolean mask per distinct label.

      Parameters
      ----------
      L : array_like
          Label image (e.g. Y x X); each distinct value is one class.
      background : scalar, optional
          Label value to exclude, so that no mask is generated for it. None
          (or the legacy default []) keeps every label.

      Returns
      -------
      numpy.ndarray of bool, shape (C,) + L.shape
          Entry i is True wherever L equals the i-th remaining label. Labels
          are in ascending order, as returned by numpy.unique.
      """
      L = numpy.asarray(L)
      unique = numpy.unique(L)

      #None (or the legacy [] default) means "no background label"
      legacy_empty = isinstance(background, (list, tuple)) and len(background) == 0
      if background is not None and not legacy_empty:
          unique = unique[unique != background]

      #compare every label against the image at once: (C, 1, ..., 1) vs L.shape
      labels = unique.reshape((-1,) + (1,) * L.ndim)
      return labels == L
  
777fef6c   David Mayerich   added digital sta...
154
155
156
157
158
159
160
161
  #randomizes a given mask to include a subset of n pixels in the original
  def random_mask(M, n):
      """Return a random subset of the nonzero pixels of a mask.

      Parameters
      ----------
      M : array_like
          Input mask; every nonzero element is a candidate pixel.
      n : int
          Number of pixels to keep. If n exceeds the number of candidates, all
          candidates are kept; n <= 0 yields an empty mask.

      Returns
      -------
      numpy.ndarray of bool, same shape as M
          Mask with min(max(n, 0), count_nonzero(M)) True pixels, all inside M,
          drawn with numpy's global random state (numpy.random.permutation).
      """
      M = numpy.asarray(M)
      candidates = numpy.flatnonzero(M)
      #shuffle, then keep the first n; clamp so a negative n cannot slice from the end
      chosen = numpy.random.permutation(candidates)[:max(n, 0)]
      new_mask = numpy.zeros(M.shape, dtype=bool)
      new_mask[numpy.unravel_index(chosen, new_mask.shape)] = True
      return new_mask
  
eccf10ff   David Mayerich   added comments fo...
162
163
164
165
  #perform classification of an ENVI image using batch processing
  # input:    E is the ENVI object (file is assumed to be loaded)
  #           C is a classifier - anything in sklearn should work
  #           batch is the batch size
59f31ea4   David Mayerich   fixed bugs in env...
166
167
  def envi_batch_predict(E, C, batch=10000):
  
254e3bc8   David Mayerich   added a full clas...
168
      Fv = E.loadbatch(batch)
59f31ea4   David Mayerich   fixed bugs in env...
169
170
171
172
173
      i = 0
      Tv = []
      plt.ion()
      bar = progressbar.ProgressBar(max_value=numpy.count_nonzero(E.mask))
      while not Fv == []:
ed3d7d30   David Mayerich   updated NWT file ...
174
          Fv = numpy.nan_to_num(Fv)                                                     #remove infinite values        
59f31ea4   David Mayerich   fixed bugs in env...
175
          if i == 0:
254e3bc8   David Mayerich   added a full clas...
176
              Tv = C.predict(Fv.transpose())
59f31ea4   David Mayerich   fixed bugs in env...
177
          else:
254e3bc8   David Mayerich   added a full clas...
178
              Tv = numpy.concatenate((Tv, C.predict(Fv.transpose()).transpose()), 0)
59f31ea4   David Mayerich   fixed bugs in env...
179
          tempmask = E.batchmask()
9b3cbdda   David Mayerich   changed the name ...
180
          Lv = hyperspectral.unsift2(Tv, tempmask)
59f31ea4   David Mayerich   fixed bugs in env...
181
182
183
184
          Cv = label2class(Lv.squeeze(), background=0)
          RGB = class2color(Cv)
          plt.imshow(RGB)
          plt.pause(0.05)
254e3bc8   David Mayerich   added a full clas...
185
          Fv = E.loadbatch(batch)   
59f31ea4   David Mayerich   fixed bugs in env...
186
          i = i + 1
ed3d7d30   David Mayerich   updated NWT file ...
187
          bar.update(len(Tv))