-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmainProgramSRNet.m
69 lines (61 loc) · 2.6 KB
/
mainProgramSRNet.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
clear
clc
% --- Training data -------------------------------------------------------
% Input images (PNG) for training.
% Please update '...' below to point at the root folder of the dataset.
% fullfile joins the parts with the current platform's file separator, so
% the script is not tied to Windows-style backslashes (on Windows the
% resulting path is identical to '...\data\trainingSet\input').
imds = imageDatastore(fullfile('...','data','trainingSet','input'), ...
    'IncludeSubfolders',false,'FileExtensions','.png');
% Define three classes. pixelLabelID(k) is the gray level in the label
% PNGs that encodes classNames(k): 127 -> "NR", 255 -> "LTE", 0 -> "Noise".
classNames = ["NR" "LTE" "Noise"];
pixelLabelID = [127 255 0];
% Ground-truth label images for training (same dataset root as above).
pxdsTruth = pixelLabelDatastore(fullfile('...','data','trainingSet','label'), ...
    classNames,pixelLabelID,'IncludeSubfolders',false,'FileExtensions','.png');
% --- Dataset statistics --------------------------------------------------
% Count the labelled pixels of each class and plot every class's share of
% all labelled pixels, which shows how imbalanced the training labels are.
tbl = countEachLabel(pxdsTruth);
frequency = tbl.PixelCount ./ sum(tbl.PixelCount);
figure
bar(1:numel(classNames), frequency)
grid on
% One tick per class, labelled with the class name and rotated 45 degrees.
set(gca, 'XTick', 1:numel(classNames), ...
         'XTickLabel', tbl.Name, ...
         'XTickLabelRotation', 45)
ylabel('Frequency')
% --- Training / validation split -----------------------------------------
% helperSpecSensePartitionData is defined outside this file; [80 20] is
% presumably the training/validation percentage split — confirm there.
[imdsTrain,pxdsTrain,imdsVal,pxdsVal] = helperSpecSensePartitionData(imds,pxdsTruth,[80 20]);
% Pair every image with its label image, then map each pair through
% preprocessTrainingData (defined outside this file), which resizes the
% image and pixel label data to imageSize.
imageSize = [256 256];
cdsTrain = transform(combine(imdsTrain,pxdsTrain), @(data)preprocessTrainingData(data,imageSize));
cdsVal   = transform(combine(imdsVal,pxdsVal),     @(data)preprocessTrainingData(data,imageSize));
% --- Network architecture and class weighting ----------------------------
% Load only the layer graph from SRNet.mat. A bare load() would drop every
% variable stored in the MAT-file into the workspace, where it could
% silently overwrite variables defined above (imds, tbl, classNames, ...).
modelData = load('SRNet.mat','lgraph');
lgraph = modelData.lgraph;
% Balance classes with median-frequency weighting. A class's frequency is
% its pixel count divided by the total pixel count of the images that
% contain it; its weight is median(frequency)/frequency, so rarer classes
% get larger weights.
isAbsent = tbl.PixelCount == 0;
if any(isAbsent)
    % An absent class gives 0/0 = NaN, and median() would then turn every
    % class weight into NaN; stop here with a clear message instead.
    error('mainProgramSRNet:missingClass', ...
        'Class(es) %s do not occur in the training labels; cannot compute class weights.', ...
        strjoin(string(tbl.Name(isAbsent)), ', '));
end
imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount;
classWeights = median(imageFreq) ./ imageFreq;
% Swap the network's existing "labels" output layer for a weighted pixel
% classification layer over the same classes.
pxLayer = pixelClassificationLayer('Name','labels','Classes',tbl.Name,'ClassWeights',classWeights);
lgraph = replaceLayer(lgraph,"labels",pxLayer);
% --- Training ------------------------------------------------------------
% SGD with momentum and a piecewise learning-rate schedule: start at 0.02
% and multiply by 0.1 every 10 epochs. Validation runs on cdsVal every 200
% iterations, and the network with the lowest validation loss is returned.
% NOTE(review): with these values the learning rate is at most 2e-5 from
% epoch 31 onward and 2e-11 in epochs 91-100, so the later epochs may add
% little — confirm this schedule is intended.
opts = trainingOptions("sgdm", ...
    'MiniBatchSize',                40, ...
    'MaxEpochs',                    100, ...
    'Shuffle',                      "every-epoch", ...
    'InitialLearnRate',             0.02, ...
    'LearnRateSchedule',            "piecewise", ...
    'LearnRateDropPeriod',          10, ...
    'LearnRateDropFactor',          0.1, ...
    'ValidationData',               cdsVal, ...
    'ValidationFrequency',          200, ...
    'ValidationPatience',           inf, ...
    'OutputNetwork',                "best-validation-loss", ...
    'BatchNormalizationStatistics', "moving", ...
    'Plots',                        'training-progress');
[net,trainInfo] = trainNetwork(cdsTrain,lgraph,opts);
% --- Performance evaluation on test data ---------------------------------
% Please update '...' below to point at the root folder of the dataset.
imds = imageDatastore(fullfile('...','data','testSet','input'), ...
    'IncludeSubfolders',false,'FileExtensions','.png');
% semanticseg writes the segmented label images into outputDir and needs
% that folder to exist, so create it first if necessary.
% Change outputDir to save the segmented outputs elsewhere.
outputDir = fullfile('...','data','testSet','output');
if ~isfolder(outputDir)
    mkdir(outputDir);
end
% NOTE(review): unlike the training/validation data, the test images are
% not passed through preprocessTrainingData; confirm they already match the
% size the network was trained on (imageSize = [256 256]).
pxdsResults = semanticseg(imds,net,"WriteLocation",outputDir);
% Ground-truth labels for the test set, encoded like the training labels.
pxdsTruth = pixelLabelDatastore(fullfile('...','data','testSet','label'), ...
    classNames,pixelLabelID,'IncludeSubfolders',false,'FileExtensions','.png');
% Measure the accuracy, IoU, and the other metrics reported by
% evaluateSemanticSegmentation.
metrics = evaluateSemanticSegmentation(pxdsResults,pxdsTruth);