rtsEnviRandomForest2C_train.m
1.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
function RF = rtsEnviRandomForest2C_train(EnviFileName, EnviHeaderName, RoiImages, reference)
%Creates a 2-class random forest classifier for an ENVI BIP file using the
%provided ROI images. The resulting classifier uses regression to map
%class A (first ROI image) to 1 and class B (second ROI image) to 0. This
%allows the creation of ROC curves.
%
% EnviFileName   - Name of an ENVI BIP file
% EnviHeaderName - Name of the ENVI header file passed to rtsEnviOpen
% RoiImages      - Cell array containing exactly two ROI image file names
%                  (masks). A pixel is used for training when the first
%                  channel of its mask value is > 0.
% reference      - (optional) band index used for normalization: every band
%                  is divided by this band, and the band itself is then
%                  removed from the features. Omit it or pass [] to skip.
%
% RF             - TreeBagger ensemble trained in 'regression' mode
%
%Raises an error unless exactly two ROI images are supplied.

%default parameters
maxPixels = 200000;   %pixel limit passed to rtsEnviRead (presumably per class - TODO confirm)
nTrees = 100;         %number of trees in the forest
nThreads = 8;         %parallel workers; 1 disables the parallel pool

%make sure that there are exactly two classes (error instead of returning
%with RF unassigned, which would only surface as a confusing caller error)
nClasses = numel(RoiImages);
if nClasses ~= 2
    error('This classifier only supports 2 classes.');
end

%for each class, load the training data
T = [];
F = [];
for c = 1:nClasses
    %load the class mask. The mask is transposed before being handed to
    %rtsEnviOpen (presumably to match the ENVI sample/line order - TODO confirm)
    maskImage = imread(RoiImages{c});
    maskBinary = (maskImage(:, :, 1) > 0)';

    disp('------------------------------------------------------');
    %load the spectra of the masked pixels for this class
    disp(['Loading Training Class ' num2str(c) ' pixels: ' EnviFileName]);
    tLoadTime = tic;
    fid = rtsEnviOpen(EnviFileName, EnviHeaderName, maskBinary);
    %guarantee the file is closed even if rtsEnviRead throws
    enviGuard = onCleanup(@() rtsEnviClose(fid));
    Fc = rtsEnviRead(fid, maxPixels);
    clear enviGuard;   %closes the file now, before the next class is opened

    %one target per column of Fc: class 1 -> 1, class 2 -> 0
    if c == 1
        Tc = ones(size(Fc, 2), 1);
    else
        Tc = zeros(size(Fc, 2), 1);
    end
    disp(['Time: ' num2str(toc(tLoadTime)) 's']);

    %add features and targets to the final vectors (Fc is transposed so
    %that each row of F lines up with one entry of T)
    T = [T; Tc]; %#ok<AGROW>
    F = [F; Fc']; %#ok<AGROW>
end

%apply the reference: divide every band by the reference band, then drop it.
%bsxfun avoids materializing a full-size copy of the reference column.
%NOTE(review): a zero in the reference band yields Inf/NaN features.
if nargin >= 4 && ~isempty(reference)
    F = bsxfun(@rdivide, F, F(:, reference));
    F(:, reference) = [];
end

%build the TreeBagger options; parallel options are only added when a pool
%is requested, so the serial path never touches an undefined variable
treeArgs = {'method', 'regression'};
poolGuard = []; %#ok<NASGU>
if nThreads > 1
    %only open (and later close) a pool if one is not already running
    if matlabpool('size') == 0
        matlabpool('open', nThreads);
        poolGuard = onCleanup(@() matlabpool('close')); %#ok<NASGU>
    end
    treeArgs = [treeArgs, {'Options', statset('UseParallel', 'always')}];
end

%train the classifier (the pool, if we opened it, is closed when poolGuard
%goes out of scope at function exit - including on error)
disp('Creating a Random Forest classifier...');
tTrainTime = tic;
RF = TreeBagger(nTrees, F, T, treeArgs{:});
disp(['Time: ' num2str(toc(tTrainTime)) 's']);