Commit 43865a41e05e9650c893b4032341cc4f72f3dd94

Authored by Jiaming Guo
2 parents 0e0feff0 777fef6c

Merge branch 'master' of git.stim.ee.uh.edu:codebase/stimlib

python/classify.py 0 โ†’ 100644
  1 +# -*- coding: utf-8 -*-
  2 +"""
  3 +Created on Sun Jul 23 16:04:33 2017
  4 +
  5 +@author: david
  6 +"""
  7 +
  8 +import numpy
  9 +import colorsys
  10 +import sklearn
  11 +import sklearn.metrics
  12 +import scipy
  13 +import scipy.misc
  14 +import envi
  15 +import random
  16 +
  17 +#generate a 2D color class map using a stack of binary class images
  18 +#input: C is a C x Y x X binary image
  19 +#output: an RGB color image with a unique color for each class
  20 +def class2color(C):
  21 +
  22 + #determine the number of classes
  23 + nc = C.shape[0]
  24 +
  25 + s = C.shape[1:]
  26 + s = numpy.append(s, 3)
  27 +
  28 + #generate an RGB image
  29 + RGB = numpy.zeros(s, dtype=numpy.ubyte)
  30 +
  31 + #for each class
  32 + for c in range(0, nc):
  33 + hsv = (c * 1.0 / nc, 1, 1)
  34 + color = numpy.asarray(colorsys.hsv_to_rgb(hsv[0], hsv[1], hsv[2])) * 255
  35 + RGB[C[c, ...], :] = color
  36 +
  37 + return RGB
  38 +
  39 +#create a function that loads a set of class images as a stack of binary masks
  40 +#input: list of class image names
  41 +#output: C x Y x X binary image specifying class/pixel membership
  42 +#example: image2class(("class_coll.bmp", "class_epith.bmp"))
  43 +def image2class(masks):
  44 + #get num of mask file names
  45 + num_masks = len(masks)
  46 +
  47 + if num_masks == 0:
  48 + print("ERROR: mask filenames not provided")
  49 + print("Usage example: image2class(('class_coll.bmp', 'class_epith.bmp'))")
  50 + return
  51 +
  52 + classimages = []
  53 + for m in masks:
  54 + img = scipy.misc.imread(m, flatten=True).astype(numpy.bool)
  55 + classimages.append(img)
  56 +
  57 + result = numpy.stack(classimages)
  58 + sum_images = numpy.sum(result.astype(numpy.uint32), 0)
  59 +
  60 + #identify and remove redundant pixels
  61 + bad_idx = sum_images > 1
  62 + result[:, bad_idx] = 0
  63 +
  64 + return result
  65 +
  66 +
  67 +#create a class mask stack from an C x Y x X probability image
  68 +#input: C x Y x X image giving the probability P(c |x,y)
  69 +#output: C x Y x X binary class image
  70 +def prob2class(prob_image):
  71 + class_image = numpy.zeros(prob_image.shape, dtype=numpy.bool)
  72 + #get nonzero indices
  73 + nnz_idx = numpy.transpose(numpy.nonzero(numpy.sum(prob_image, axis=0)))
  74 +
  75 + #set pixel corresponding to max probability to 1
  76 + for idx in nnz_idx:
  77 + idx_max_prob = numpy.argmax(prob_image[:, idx[0], idx[1]])
  78 + class_image[idx_max_prob, idx[0], idx[1]] = 1
  79 +
  80 + return class_image
  81 +
  82 +#calculate an ROC curve given a probability image and mask of "True" values
  83 +#input:
  84 +# P is a Y x X probability image specifying P(c | x,y)
  85 +# t_vals is a Y x X binary image specifying points where x,y = c
  86 +# mask is a mask specifying all pixels to be considered (positives and negatives)
  87 +# use this mask to limit analysis to regions of the image that have been classified
  88 +#output: fpr, tpr, thresholds
  89 +# fpr is the false-positive rate (x-axis of an ROC curve)
  90 +# tpr is the true-positive rate (y-axis of an ROC curve)
  91 +# thresholds stores the threshold associated with each point on the ROC curve
  92 +#
  93 +#note: the AUC can be calculated as auc = sklearn.metrics.auc(fpr, tpr)
  94 +def prob2roc(P, t_vals, mask=[]):
  95 +
  96 + if not P.shape == t_vals.shape:
  97 + print("ERROR: the probability and mask images must be the same shape")
  98 + return
  99 +
  100 + #if a mask image isn't provided, create one for the entire image
  101 + if mask == []:
  102 + mask = numpy.ones(t_vals.shape, dtype=numpy.bool)
  103 +
  104 + #create masks for the positive and negative probability scores
  105 + mask_p = t_vals
  106 + mask_n = mask - mask * t_vals
  107 +
  108 + #calculate the indices for the positive and negative scores
  109 + idx_p = numpy.nonzero(mask_p)
  110 + idx_n = numpy.nonzero(mask_n)
  111 +
  112 + Pp = P[idx_p]
  113 + Pn = P[idx_n]
  114 +
  115 + Lp = numpy.ones((Pp.shape), dtype=numpy.bool)
  116 + Ln = numpy.zeros((Pn.shape), dtype=numpy.bool)
  117 +
  118 + scores = numpy.concatenate((Pp, Pn))
  119 + labels = numpy.concatenate((Lp, Ln))
  120 +
  121 + return sklearn.metrics.roc_curve(labels, scores)
  122 +
  123 +#convert a label image to a C x Y x X class image
  124 +def label2class(L, background=[]):
  125 + unique = numpy.unique(L)
  126 +
  127 + if not background == []: #if a background value is specified
  128 + unique = numpy.delete(unique, numpy.nonzero(unique == background)) #remove it from the label array
  129 + s = L.shape
  130 + s = numpy.append(numpy.array((len(unique))), s)
  131 + C = numpy.zeros(s, dtype=numpy.bool)
  132 + for i in range(0, len(unique)):
  133 + C[i, :, :] = L == unique[i]
  134 + return C
  135 +
  136 +#randomizes a given mask to include a subset of n pixels in the original
  137 +def random_mask(M, n):
  138 + idx = numpy.flatnonzero(M)
  139 + new_idx = numpy.random.permutation(idx)
  140 + new_mask = numpy.zeros(M.shape, dtype=numpy.bool)
  141 + new_mask[numpy.unravel_index(new_idx[0:n], new_mask.shape)] = True
  142 + return new_mask
  143 +
  144 +
  145 +#Function to convert a set of class labels to a matrix of neuron responses for an ANN
  146 +
  147 +#Function CNN extraction function
0 148 \ No newline at end of file
... ...
python/envi.py 0 โ†’ 100644
  1 +# -*- coding: utf-8 -*-
  2 +"""
  3 +Created on Fri Jul 21 20:18:01 2017
  4 +
  5 +@author: david
  6 +"""
  7 +
  8 +import os
  9 +import numpy
  10 +import scipy
  11 +import matplotlib.pyplot as plt
  12 +import progressbar
  13 +
  14 +class envi_header:
  15 + def __init__(self, filename = ""):
  16 + if filename != "":
  17 + self.load(filename)
  18 + else:
  19 + self.initialize()
  20 +
  21 + #initialization function
  22 + def initialize(self):
  23 + self.samples = int(0)
  24 + self.lines = int(0)
  25 + self.bands = int(0)
  26 + self.header_offset = int(0)
  27 + self.data_type = int(4)
  28 + self.interleave = "bsq"
  29 + self.sensor_type = ""
  30 + self.byte_order = int(0)
  31 + self.x_start = int(0)
  32 + self.y_start = int(0)
  33 + self.z_plot_titles = ""
  34 + self.pixel_size = [float(0), float(0)]
  35 + self.pixel_size_units = "Meters"
  36 + self.wavelength_units = "Wavenumber"
  37 + self.description = ""
  38 + self.band_names = []
  39 + self.wavelength = []
  40 +
  41 + #convert an ENVI data_type value to a numpy data type
  42 + def get_numpy_type(self, val):
  43 + if val == 1:
  44 + return numpy.byte
  45 + elif val == 2:
  46 + return numpy.int16
  47 + elif val == 3:
  48 + return numpy.int32
  49 + elif val == 4:
  50 + return numpy.float32
  51 + elif val == 5:
  52 + return numpy.float64
  53 + elif val == 6:
  54 + return numpy.complex64
  55 + elif val == 9:
  56 + return numpy.complex128
  57 + elif val == 12:
  58 + return numpy.uint16
  59 + elif val == 13:
  60 + return numpy.uint32
  61 + elif val == 14:
  62 + return numpy.int64
  63 + elif val == 15:
  64 + return numpy.uint64
  65 +
  66 + def get_envi_type(self, val):
  67 + if val == numpy.byte:
  68 + return 1
  69 + elif val == numpy.int16:
  70 + return 2
  71 + elif val == numpy.int32:
  72 + return 3
  73 + elif val == numpy.float32:
  74 + return 4
  75 + elif val == numpy.float64:
  76 + return 5
  77 + elif val == numpy.complex64:
  78 + return 6
  79 + elif val == numpy.complex128:
  80 + return 9
  81 + elif val == numpy.uint16:
  82 + return 12
  83 + elif val == numpy.uint32:
  84 + return 13
  85 + elif val == numpy.int64:
  86 + return 14
  87 + elif val == numpy.uint64:
  88 + return 15
  89 +
  90 + def load(self, fname):
  91 + f = open(fname)
  92 + l = f.readlines()
  93 + if l[0].strip() != "ENVI":
  94 + print("ERROR: not an ENVI file")
  95 + return
  96 + li = 1
  97 + while li < len(l):
  98 + #t = l[li].split() #split the line into tokens
  99 + #t = map(str.strip, t) #strip all of the tokens in the token list
  100 +
  101 + #handle the simple conditions
  102 + if l[li].startswith("file type"):
  103 + if not l[li].strip().endswith("ENVI Standard"):
  104 + print("ERROR: unsupported ENVI file format: " + l[li].strip())
  105 + return
  106 + elif l[li].startswith("samples"):
  107 + self.samples = int(l[li].split()[-1])
  108 + elif l[li].startswith("lines"):
  109 + self.lines = int(l[li].split()[-1])
  110 + elif l[li].startswith("bands"):
  111 + self.bands = int(l[li].split()[-1])
  112 + elif l[li].startswith("header offset"):
  113 + self.header_offset = int(l[li].split()[-1])
  114 + elif l[li].startswith("data type"):
  115 + self.data_type = self.get_numpy_type(int(l[li].split()[-1]))
  116 + elif l[li].startswith("interleave"):
  117 + self.interleave = l[li].split()[-1].strip()
  118 + elif l[li].startswith("sensor type"):
  119 + self.sensor_type = l[li].split()[-1].strip()
  120 + elif l[li].startswith("byte order"):
  121 + self.byte_order = int(l[li].split()[-1])
  122 + elif l[li].startswith("x start"):
  123 + self.x_start = int(l[li].split()[-1])
  124 + elif l[li].startswith("y start"):
  125 + self.y_start = int(l[li].split()[-1])
  126 + elif l[li].startswith("z plot titles"):
  127 + i0 = l[li].rindex('{')
  128 + i1 = l[li].rindex('}')
  129 + self.z_plot_titles = l[li][i0 + 1 : i1]
  130 + elif l[li].startswith("pixel size"):
  131 + i0 = l[li].rindex('{')
  132 + i1 = l[li].rindex('}')
  133 + s = l[li][i0 + 1 : i1].split(',')
  134 + self.pixel_size = [float(s[0]), float(s[1])]
  135 + self.pixel_size_units = s[2][s[2].rindex('=') + 1:].strip()
  136 + elif l[li].startswith("wavelength units"):
  137 + self.wavelength_units = l[li].split()[-1].strip()
  138 +
  139 + #handle the complicated conditions
  140 + elif l[li].startswith("description"):
  141 + desc = [l[li]]
  142 + '''
  143 + while l[li].strip()[-1] != '}': #will fail if l[li].strip() is empty
  144 + li += 1
  145 + desc.append(l[li])
  146 + '''
  147 + while True:
  148 + if l[li].strip():
  149 + if l[li].strip()[-1] == '}':
  150 + break
  151 + li += 1
  152 + desc.append(l[li])
  153 +
  154 + desc = ''.join(list(map(str.strip, desc))) #strip all white space from the string list
  155 + i0 = desc.rindex('{')
  156 + i1 = desc.rindex('}')
  157 + self.description = desc[i0 + 1 : i1]
  158 +
  159 + elif l[li].startswith("band names"):
  160 + names = [l[li]]
  161 + while l[li].strip()[-1] != '}':
  162 + li += 1
  163 + names.append(l[li])
  164 + names = ''.join(list(map(str.strip, names))) #strip all white space from the string list
  165 + i0 = names.rindex('{')
  166 + i1 = names.rindex('}')
  167 + names = names[i0 + 1 : i1]
  168 + self.band_names = list(map(str.strip, names.split(',')))
  169 + elif l[li].startswith("wavelength"):
  170 + waves = [l[li]]
  171 + while l[li].strip()[-1] != '}':
  172 + li += 1
  173 + waves.append(l[li])
  174 + waves = ''.join(list(map(str.strip, waves))) #strip all white space from the string list
  175 + i0 = waves.rindex('{')
  176 + i1 = waves.rindex('}')
  177 + waves = waves[i0 + 1 : i1]
  178 + self.wavelength = list(map(float, waves.split(',')))
  179 +
  180 + li += 1
  181 +
  182 + f.close()
  183 +
  184 +class envi:
  185 + def __init__(self, filename, headername = "", mask = []):
  186 + self.open(filename, headername)
  187 + if mask == []:
  188 + self.mask = numpy.ones((self.header.lines, self.header.samples), dtype=numpy.bool)
  189 + else:
  190 + self.mask = mask
  191 + self.idx = 0 #initialize the batch IDX to 0 for batch reading
  192 +
  193 + def open(self, filename, headername = ""):
  194 + if headername == "":
  195 + headername = filename + ".hdr"
  196 +
  197 + if not os.path.isfile(filename):
  198 + print("ERROR: " + filename + " not found")
  199 + return
  200 + if not os.path.isfile(headername):
  201 + print("ERROR: " + headername + " not found")
  202 + return
  203 +
  204 + #open the file
  205 + self.header = envi_header(headername)
  206 + self.file = open(filename, "rb")
  207 +
  208 + def loadall(self):
  209 + X = self.header.samples
  210 + Y = self.header.lines
  211 + B = self.header.bands
  212 +
  213 + #load the data
  214 + D = numpy.fromfile(self.file, dtype=self.header.data_type)
  215 +
  216 + if self.header.interleave == "bsq":
  217 + return numpy.reshape(D, (B, Y, X))
  218 + #return numpy.swapaxes(D, 0, 2)
  219 + elif self.header.interleave == "bip":
  220 + D = numpy.reshape(D, (Y, X, B))
  221 + return numpy.rollaxis(D, 2)
  222 + elif self.header.interleave == "bil":
  223 + D = numpy.reshape(D, (Y, B, X))
  224 + return numpy.rollaxis(D, 1)
  225 +
  226 + #loads all of the pixels where mask != 0 and returns them as a matrix
  227 + def loadmask(self, mask):
  228 + X = self.header.samples
  229 + Y = self.header.lines
  230 + B = self.header.bands
  231 +
  232 + P = numpy.count_nonzero(mask) #count the number of zeros in the mask file
  233 + M = numpy.zeros((B, P), dtype=self.header.data_type)
  234 + type_bytes = numpy.dtype(self.header.data_type).itemsize
  235 +
  236 + prev_pos = self.file.tell()
  237 + self.file.seek(0)
  238 + if self.header.interleave == "bip":
  239 + spectrum = numpy.zeros(B, dtype=self.header.data_type)
  240 + flatmask = numpy.reshape(mask, (X * Y))
  241 + i = numpy.flatnonzero(flatmask)
  242 + bar = progressbar.ProgressBar(max_value = P)
  243 + for p in range(0, P):
  244 + self.file.seek(i[p] * B * type_bytes)
  245 + self.file.readinto(spectrum)
  246 + M[:, p] = spectrum
  247 + bar.update(p+1)
  248 + if self.header.interleave == "bsq":
  249 + band = numpy.zeros(mask.shape, dtype=self.header.data_type)
  250 + i = numpy.nonzero(mask)
  251 + bar = progressbar.ProgressBar(max_value=B)
  252 + for b in range(0, B):
  253 + self.file.seek(b * X * Y * type_bytes)
  254 + self.file.readinto(band)
  255 + M[b, :] = band[i]
  256 + bar.update(b+1)
  257 + if self.header.interleave == "bil":
  258 + plane = numpy.zeros((B, X), dtype=self.header.data_type)
  259 + p = 0
  260 + bar = progressbar.ProgressBar(max_value=Y)
  261 + for l in range(0, Y):
  262 + i = numpy.flatnonzero(mask[l, :])
  263 + self.file.readinto(plane)
  264 + M[:, p:p+i.shape[0]] = plane[:, i]
  265 + p = p + i.shape[0]
  266 + bar.update(l+1)
  267 + self.file.seek(prev_pos)
  268 + return M
  269 +
  270 + def loadband(self, n):
  271 + X = self.header.samples
  272 + Y = self.header.lines
  273 + B = self.header.bands
  274 +
  275 + band = numpy.zeros((Y, X), dtype=self.header.data_type)
  276 + type_bytes = numpy.dtype(self.header.data_type).itemsize
  277 +
  278 + prev_pos = self.file.tell()
  279 + if self.header.interleave == "bsq":
  280 + self.file.seek(n * X * Y * type_bytes)
  281 + self.file.readinto(band)
  282 + self.file.seek(prev_pos)
  283 + return band
  284 +
  285 + #create a set of feature/target pairs for classification
  286 + #input: envi file object, stack of class masks C x Y x X
  287 + #output: feature matrix (features x pixels), target matrix (1 x pixels)
  288 + #example: generate_training(("class_coll.bmp", "class_epith.bmp"), (1, 2))
  289 + def loadtrain(self, classimages):
  290 +
  291 + # get number of classes
  292 + C = classimages.shape[0]
  293 +
  294 + F = []
  295 + T = []
  296 + for c in range(0, C):
  297 + f = self.loadmask(classimages[c, :, :]) #load the feature matrix for class c
  298 + t = numpy.ones((f.shape[1])) * (c+1) #generate a target array
  299 + F.append(f)
  300 + T.append(t)
  301 +
  302 + return numpy.concatenate(F, 1).transpose(), numpy.concatenate(T)
  303 +
  304 + #read a batch of data based on the mask
  305 + def loadbatch(self, npixels):
  306 + i = numpy.flatnonzero(self.mask) #get the indices of valid pixels
  307 + if len(i) == self.idx: #if all of the pixels have been read, return an empyt array
  308 + return []
  309 + npixels = min(npixels, len(i) - self.idx) #if there aren't enough pixels, change the batch size
  310 + B = self.header.bands
  311 +
  312 + batch = numpy.zeros((B, npixels), dtype=self.header.data_type) #allocate space for the batch
  313 + pixel = numpy.zeros((B), dtype=self.header.data_type) #allocate space for a single pixel
  314 + type_bytes = numpy.dtype(self.header.data_type).itemsize #calculate the size of a single value
  315 + if self.header.interleave == "bip":
  316 + for n in range(0, npixels): #for each pixel in the batch
  317 + self.file.seek(i[self.idx] * B * type_bytes) #seek to the current pixel in the file
  318 + self.file.readinto(pixel) #read a single pixel
  319 + batch[:, n] = pixel #save the pixel into the batch matrix
  320 + self.idx = self.idx + 1
  321 + return batch
  322 + elif self.header.interleave == "bsq":
  323 + print("ERROR: BSQ batch loading isn't implemented yet!")
  324 + elif self.header.interleave == "bil":
  325 + print("ERROR: BIL batch loading isn't implemented yet!")
  326 +
  327 + #returns the current batch index
  328 + def getidx(self):
  329 + return self.idx
  330 +
  331 + #returns an image of the pixels that have been read using batch loading
  332 + def batchmask(self):
  333 + #allocate a new mask
  334 + outmask = numpy.zeros(self.mask.shape, dtype=numpy.bool)
  335 +
  336 + #zero out any unclassified pixels
  337 + idx = self.getidx()
  338 + i = numpy.nonzero(self.mask)
  339 + outmask[i[0][0:idx], i[1][0:idx]] = self.mask[i[0][0:idx], i[1][0:idx]]
  340 + return outmask
  341 +
  342 +
  343 + def __del__(self):
  344 + self.file.close()
0 345 \ No newline at end of file
... ...
python/example.py 0 โ†’ 100644
  1 +import numpy
  2 +import classify
  3 +import matplotlib.pyplot as plt
  4 +from envi import envi
  5 +
  6 +mask_path = '/home/sberisha/data/masks/'
  7 +mask_stack = classify.image2class(mask_path + "class_blood.png", mask_path + "class_coll.png", mask_path + "class_epith.png",
  8 + mask_path + "class_lymph.png", mask_path + "class_necrosis.png")
  9 +
  10 +color_image = classify.classcolor2(mask_stack)
  11 +plt.imshow(color_image)
  12 +
  13 +data_path ='/home/sberisha/data/cnn/brc961-nfp8/envi/'
  14 +
  15 +feature_matrix, target_matrix = classify.generate_training(data_path + 'brc961-nfp8-project-br1003', mask_stack)
  16 +
  17 +prob_path = '/home/sberisha/data/'
  18 +prob_envi= envi(prob_path + "cnn-response")
  19 +prob_image = prob_envi.loadall()
  20 +
  21 +class_image = classify.prob2class(prob_image)
  22 +plt.imshow(class_image[4,:,:])
0 23 \ No newline at end of file
... ...
python/spectral.py 0 โ†’ 100644
  1 +# -*- coding: utf-8 -*-
  2 +"""
  3 +Created on Sun Jul 23 13:52:22 2017
  4 +
  5 +@author: david
  6 +"""
  7 +import numpy
  8 +
  9 +#sift a 2D hyperspectral image into a PxB matrix where P is the number of pixels and B is the number of bands
  10 +def sift2(I, mask = []):
  11 +
  12 + #get the shape of the input array
  13 + S = I.shape
  14 +
  15 + #convert that array into a 1D matrix
  16 + M = numpy.reshape(I, (S[0], S[1] * S[2]))
  17 +
  18 + #gif no mask is provided, just return all pixels
  19 + if mask == []:
  20 + return M
  21 +
  22 + #if a mask is provided, only return pixels corresponding to that mask
  23 + flatmask = numpy.reshape(mask, (S[1] * S[2]))
  24 + i = numpy.flatnonzero(flatmask) #get the nonzero indices
  25 + return M[:, i] #return pixels corresponding to the masked values
  26 +
  27 +def unsift2(M, mask):
  28 +
  29 + #get the size of the input matrix
  30 + S = M.shape
  31 +
  32 + #count the number of nonzero values in the mask
  33 + nnz = numpy.count_nonzero(mask)
  34 +
  35 + #the number of masked values should be the same as the number of pixels in the input matrix
  36 + if len(S) == 1:
  37 + if not S[0] == nnz:
  38 + print("ERROR: expected " + str(nnz) + " pixels based on the mask but there are " + str(S[0]) + " in the matrix.")
  39 + elif not S[1] == nnz:
  40 + print("ERROR: expected " + str(nnz) + " pixels based on the mask but there are " + str(S[1]) + " in the matrix.")
  41 +
  42 +
  43 + i = numpy.nonzero(mask)
  44 +
  45 + if len(S) == 1:
  46 + I = numpy.zeros((1, mask.shape[0], mask.shape[1]), dtype=M.dtype)
  47 + else:
  48 + I = numpy.zeros((M.shape[0], mask.shape[0], mask.shape[1]), dtype=M.dtype)
  49 + I[:, i[0], i[1]] = M
  50 + return I
  51 +
  52 +#create a function that sifts a color image
  53 +#input: image name, mask
0 54 \ No newline at end of file
... ...