Blame view

rtsEnviRandomForest2C_train.m 1.98 KB
8be1ab93   David Mayerich   initial commit of...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
  function RF = rtsEnviRandomForest2C_train(EnviFileName, EnviHeaderName, RoiImages, reference)
  
  %Creates a 2-class random forest classifier for an Envi BIP file using the 
  %provided ROI images.  The resulting classifier uses regression to map
  %class A to 1 and class B to 0.  This allows the creation of ROC curves.
  %
  %   EnviFileName   - Name of an ENVI BIP file
  %   EnviHeaderName - Name of the ENVI header file passed to rtsEnviOpen
  %   RoiImages      - Cell array containing the names of exactly two ROI
  %                    images (masks).  Pixels that are nonzero in the first
  %                    channel of RoiImages{1} are labeled 1 (class A); pixels
  %                    that are nonzero in RoiImages{2} are labeled 0 (class B).
  %   reference      - (optional) band index.  If given (and non-empty), every
  %                    spectrum is divided by its value at this band, and the
  %                    band is then removed from the feature set.
  %
  %   RF             - TreeBagger ensemble trained with 'method' 'regression'.
  
  %default parameters
  maxPixels = 200000;   %maximum number of pixels passed to rtsEnviRead per class
  nTrees = 100;         %number of trees in the forest
  nThreads = 8;         %number of pool workers; 1 trains serially
  
  %make sure that there are exactly two classes (a regression target of 1/0
  %is only meaningful with two ROI masks)
  nClasses = numel(RoiImages);
  if nClasses ~= 2
      error('rtsEnviRandomForest2C_train:nClasses', ...
          'This classifier only supports 2 classes.');
  end
  
  %for each class, load the training data
  T = [];
  F = [];
  for c = 1:nClasses
      
      %load the class mask; transposed so it is indexed the same way as the
      %mask argument expected by rtsEnviOpen
      maskImage = imread(RoiImages{c});
      maskBinary = (maskImage(:, :, 1) > 0)';
      
      disp('------------------------------------------------------');
      %load the spectra for the pixels selected by this mask
      disp(['Loading Training Class ' num2str(c) ' pixels: ' EnviFileName]);
      tLoadTime = tic;
      fid = rtsEnviOpen(EnviFileName, EnviHeaderName, maskBinary);
      Fc = rtsEnviRead(fid, maxPixels);   %one spectrum per column
      rtsEnviClose(fid);
      
      %class A (first mask) -> 1, class B (second mask) -> 0
      if c == 1
          Tc = ones(size(Fc, 2), 1);
      else
          Tc = zeros(size(Fc, 2), 1);
      end
      
      disp(['Time: ' num2str(toc(tLoadTime)) 's']);
      
      %add features (one spectrum per row) and targets to the final vectors
      T = [T; Tc];
      F = [F; Fc'];
  end
  
  %apply the reference: normalize every spectrum by the reference band, then
  %drop that band (it is identically 1 after normalization)
  if nargin >= 4 && ~isempty(reference)
      F = bsxfun(@rdivide, F, F(:, reference));
      F(:, reference) = [];
  end
  
  %training arguments; 'Options' is only added when running in parallel so the
  %serial path never references an undefined options structure
  trainArgs = {'method', 'regression'};
  
  %parallelize if specified
  if nThreads > 1
      %only open (and later close) a pool if one is not already running; the
      %onCleanup object closes it when this function exits, even on error
      if matlabpool('size') == 0
          matlabpool('open', nThreads);
          closePool = onCleanup(@() matlabpool('close')); %#ok<NASGU>
      end
      paraoptions = statset('UseParallel', 'always');
      trainArgs = [trainArgs, {'Options', paraoptions}];
  end
  
  %train the classifier
  disp('Creating a Random Forest classifier...');
  tTrainTime = tic;
  RF = TreeBagger(nTrees, F, T, trainArgs{:});
  disp(['Time: ' num2str(toc(tTrainTime)) 's']);
  
  %any pool opened above is closed here by the onCleanup object going out of
  %scope when the function returns