
Improving the accuracy of our network

% Load the training and test datasets

[trainX, trainY, testX, testY] = load_dataset();

xTest: [28x28x10000 uint8]
xTrain: [28x28x60000 uint8]
yTest: [10000x1 categorical]
yTrain: [60000x1 categorical]
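The load_dataset helper is not included in this document. A minimal sketch, assuming the Fashion-MNIST arrays are stored in a MAT-file (hypothetical name) holding the variables listed above:

% Minimal sketch of the load_dataset helper (not shown in the original).
% It assumes a MAT-file named 'fashion_mnist.mat' (hypothetical) containing
% the variables xTrain, yTrain, xTest, yTest as listed above.
function [trainX, trainY, testX, testY] = load_dataset()
    data = load('fashion_mnist.mat');
    trainX = data.xTrain;   % 28x28x60000 uint8
    trainY = data.yTrain;   % 60000x1 categorical
    testX = data.xTest;     % 28x28x10000 uint8
    testY = data.yTest;     % 10000x1 categorical
end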

We based our improvement on the same network, but made the following changes to boost its performance:

1. Reshaping and normalizing the data:

The input data is reshaped to 28x28x1 images, then normalized to the [0,1] range.

% Reshape the input data to 28x28x1 images
trainX = reshape(trainX, [28, 28, 1, size(trainX, 3)]);
testX = reshape(testX, [28, 28, 1, size(testX, 3)]);

% Normalize the input data to the [0,1] range
X_train = double(trainX) / 255.0;
X_test = double(testX) / 255.0;

% Convert the labels to categorical if they are not already
y_train = categorical(trainY);
y_test = categorical(testY);
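As a quick sanity check (not part of the original script), the value range can be verified after normalization:

% Verify the normalization: both sets should lie in the [0,1] range.
fprintf('X_train range: [%g, %g]\n', min(X_train(:)), max(X_train(:)));
fprintf('X_test range:  [%g, %g]\n', min(X_test(:)), max(X_test(:)));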

2. Data augmentation:

A data augmentation step is added using imageDataAugmenter to increase the variability of the training data.

% Configure data augmentation
imageAugmenter = imageDataAugmenter( ...
    'RandRotation', [-10, 10], ...    % random rotation from -10 to 10 degrees
    'RandXTranslation', [-3, 3], ...  % horizontal translation
    'RandYTranslation', [-3, 3]);     % vertical translation

augmentedTrainData = augmentedImageDatastore([28, 28, 1], X_train, y_train, ...
    'DataAugmentation', imageAugmenter);
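To confirm the augmenter behaves as intended, one mini-batch can be pulled and displayed. This is an added check, relying on the documented behavior of read on an augmentedImageDatastore (the returned table's input column holds the transformed images):

% Sanity check: pull one augmented mini-batch and display it.
reset(augmentedTrainData);
batch = read(augmentedTrainData);   % table with 'input' and 'response' columns
figure;
montage(batch.input);
title('Sample of augmented training images');
reset(augmentedTrainData);          % rewind so training starts from the beginning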

3. Modifying the network architecture:

We increase the network's complexity by adding more convolutional layers, raising the number of neurons in the fully connected layer, and adjusting the dropout rates.

% Define the network architecture with the modifications
layers = [
    imageInputLayer([28 28 1], 'Name', 'input', 'Normalization', 'none')

    % First block: more convolutional layers
    convolution2dLayer(3, 32, 'Padding', 'same', 'Name', 'conv1_1')
    batchNormalizationLayer('Name', 'bn1_1')
    reluLayer('Name', 'relu1_1')
    convolution2dLayer(3, 32, 'Padding', 'same', 'Name', 'conv1_2')
    batchNormalizationLayer('Name', 'bn1_2')
    reluLayer('Name', 'relu1_2')
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool1')
    dropoutLayer(0.3, 'Name', 'dropout1')    % increased dropout rate

    % Repeat the pattern for a second block of layers
    convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'conv2_1')
    batchNormalizationLayer('Name', 'bn2_1')
    reluLayer('Name', 'relu2_1')
    convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'conv2_2')
    batchNormalizationLayer('Name', 'bn2_2')
    reluLayer('Name', 'relu2_2')
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'maxpool2')
    dropoutLayer(0.3, 'Name', 'dropout2')

    flattenLayer('Name', 'flatten')
    fullyConnectedLayer(256, 'Name', 'fc1')  % increased number of neurons
    reluLayer('Name', 'relu3')
    dropoutLayer(0.4, 'Name', 'dropout3')    % further increased dropout rate
    fullyConnectedLayer(10, 'Name', 'fc2')
    softmaxLayer('Name', 'softmax')
    classificationLayer('Name', 'output')];
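As an optional step not in the original script, analyzeNetwork can be used to inspect the layer graph and verify the activation sizes before training:

% Optional: visualize the layer graph and check activation sizes.
analyzeNetwork(layers);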

4. Adjusting the training options:

The training options are adjusted with a lower initial learning rate, a different mini-batch size, and an increased number of epochs.

% Configure the training options with the modifications
options = trainingOptions('adam', ...
    'InitialLearnRate', 0.001, ...    % adjusted learning rate
    'MaxEpochs', 3, ...               % increased number of epochs
    'MiniBatchSize', 64, ...          % adjusted mini-batch size
    'Shuffle', 'every-epoch', ...
    'ValidationData', {X_test, y_test}, ...
    'ValidationFrequency', 30, ...
    'Verbose', true, ...
    'Plots', 'training-progress');

% Train the network on the augmented data
net = trainNetwork(augmentedTrainData, layers, options);

Training on single CPU.


|=========================================================================================================
| Epoch | Iteration | Time Elapsed | Mini-batch | Validation | Mini-batch | Validation |
| | | (hh:mm:ss) | Accuracy | Accuracy | Loss | Loss |
|=========================================================================================================
| 1 | 1 | 00:00:11 | 7.81% | 12.94% | 3.9545 | 4.1806 |
| 1 | 30 | 00:00:20 | 56.25% | 65.81% | 1.1529 | 0.9242 |
| 1 | 50 | 00:00:22 | 59.38% | | 0.9782 | |
| 1 | 60 | 00:00:28 | 62.50% | 70.83% | 1.0884 | 0.7799 |
| 1 | 90 | 00:00:36 | 62.50% | 72.19% | 1.1269 | 0.7197 |
| 1 | 100 | 00:00:37 | 78.12% | | 0.7573 | |
| 1 | 120 | 00:00:44 | 50.00% | 74.35% | 1.0395 | 0.6664 |
| 1 | 150 | 00:00:52 | 59.38% | 75.85% | 1.0115 | 0.6296 |
| 1 | 180 | 00:01:01 | 70.31% | 75.40% | 0.8423 | 0.6347 |
| 1 | 200 | 00:01:04 | 65.62% | | 0.7715 | |
| 1 | 210 | 00:01:10 | 67.19% | 76.95% | 0.8871 | 0.6081 |
| 1 | 240 | 00:01:18 | 65.62% | 77.54% | 0.7744 | 0.5746 |
| 1 | 250 | 00:01:19 | 71.88% | | 0.7609 | |
| 1 | 270 | 00:01:26 | 67.19% | 76.50% | 0.8979 | 0.5960 |
| 1 | 300 | 00:01:34 | 82.81% | 78.91% | 0.6081 | 0.5623 |
| 1 | 330 | 00:01:43 | 68.75% | 77.66% | 0.6919 | 0.5639 |
| 1 | 350 | 00:01:45 | 71.88% | | 0.6573 | |
| 1 | 360 | 00:01:51 | 75.00% | 78.16% | 0.5777 | 0.5495 |
| 1 | 390 | 00:02:00 | 79.69% | 78.76% | 0.5079 | 0.5414 |
| 1 | 400 | 00:02:01 | 75.00% | | 0.6444 | |
| 1 | 420 | 00:02:08 | 70.31% | 80.21% | 0.7104 | 0.5246 |
| 1 | 450 | 00:02:16 | 73.44% | 79.36% | 0.6870 | 0.5144 |
| 1 | 480 | 00:02:24 | 68.75% | 79.24% | 0.7484 | 0.5257 |
| 1 | 500 | 00:02:26 | 78.12% | | 0.6421 | |
| 1 | 510 | 00:02:33 | 73.44% | 79.89% | 0.7010 | 0.5007 |
| 1 | 540 | 00:02:41 | 71.88% | 81.00% | 0.6588 | 0.4737 |
| 1 | 550 | 00:02:42 | 84.38% | | 0.5129 | |
| 1 | 570 | 00:02:49 | 85.94% | 80.68% | 0.4962 | 0.4963 |
| 1 | 600 | 00:02:57 | 71.88% | 81.18% | 0.6397 | 0.4863 |
| 1 | 630 | 00:03:05 | 70.31% | 81.39% | 0.5886 | 0.4812 |
| 1 | 650 | 00:03:07 | 82.81% | | 0.5680 | |
| 1 | 660 | 00:03:14 | 67.19% | 82.11% | 0.6421 | 0.4692 |
| 1 | 690 | 00:03:22 | 70.31% | 81.57% | 0.6934 | 0.4574 |
| 1 | 700 | 00:03:23 | 81.25% | | 0.4477 | |
| 1 | 720 | 00:03:30 | 76.56% | 79.01% | 0.6089 | 0.5054 |
| 1 | 750 | 00:03:38 | 79.69% | 81.68% | 0.5258 | 0.4653 |
| 1 | 780 | 00:03:47 | 76.56% | 81.31% | 0.6362 | 0.4733 |
| 1 | 800 | 00:03:49 | 71.88% | | 0.6353 | |
| 1 | 810 | 00:03:55 | 65.62% | 79.67% | 0.8002 | 0.4991 |
| 1 | 840 | 00:04:03 | 73.44% | 82.54% | 0.8231 | 0.4661 |
| 1 | 850 | 00:04:04 | 76.56% | | 0.4977 | |
| 1 | 870 | 00:04:12 | 79.69% | 80.72% | 0.4187 | 0.4736 |
| 1 | 900 | 00:04:20 | 81.25% | 80.89% | 0.5981 | 0.4856 |
| 1 | 930 | 00:04:28 | 76.56% | 82.52% | 0.7118 | 0.4620 |
| 2 | 950 | 00:04:30 | 78.12% | | 0.6582 | |
| 2 | 960 | 00:04:37 | 87.50% | 82.30% | 0.4274 | 0.4497 |
| 2 | 990 | 00:04:45 | 70.31% | 82.39% | 0.5184 | 0.4571 |
| 2 | 1000 | 00:04:46 | 81.25% | | 0.4953 | |
| 2 | 1020 | 00:04:54 | 68.75% | 80.76% | 0.6403 | 0.4685 |
| 2 | 1050 | 00:05:02 | 73.44% | 82.06% | 1.0129 | 0.4466 |
| 2 | 1080 | 00:05:11 | 84.38% | 83.38% | 0.4175 | 0.4234 |
| 2 | 1100 | 00:05:12 | 85.94% | | 0.4127 | |
| 2 | 1110 | 00:05:19 | 73.44% | 82.02% | 0.5520 | 0.4477 |
| 2 | 1140 | 00:05:27 | 78.12% | 83.28% | 0.5909 | 0.4558 |
| 2 | 1150 | 00:05:28 | 81.25% | | 0.4765 | |
| 2 | 1170 | 00:05:35 | 78.12% | 82.92% | 0.5781 | 0.4588 |
| 2 | 1200 | 00:05:43 | 89.06% | 82.61% | 0.3641 | 0.4401 |
| 2 | 1230 | 00:05:51 | 82.81% | 83.07% | 0.4099 | 0.4563 |
| 2 | 1250 | 00:05:53 | 73.44% | | 0.6258 | |
| 2 | 1260 | 00:05:59 | 73.44% | 83.32% | 0.6693 | 0.4463 |
| 2 | 1290 | 00:06:08 | 79.69% | 82.99% | 0.6237 | 0.4370 |
| 2 | 1300 | 00:06:09 | 82.81% | | 0.5231 | |
| 2 | 1320 | 00:06:16 | 84.38% | 83.91% | 0.4574 | 0.4349 |
| 2 | 1350 | 00:06:25 | 84.38% | 84.08% | 0.5135 | 0.4395 |
| 2 | 1380 | 00:06:33 | 75.00% | 84.94% | 0.5709 | 0.4041 |
| 2 | 1400 | 00:06:35 | 75.00% | | 0.7250 | |
| 2 | 1410 | 00:06:41 | 73.44% | 83.76% | 0.7005 | 0.4229 |
| 2 | 1440 | 00:06:50 | 85.94% | 84.32% | 0.3974 | 0.4041 |
| 2 | 1450 | 00:06:51 | 82.81% | | 0.4679 | |
| 2 | 1470 | 00:06:58 | 76.56% | 84.38% | 0.8149 | 0.4103 |
| 2 | 1500 | 00:07:06 | 81.25% | 84.74% | 0.4401 | 0.3971 |
| 2 | 1530 | 00:07:15 | 78.12% | 85.66% | 0.6573 | 0.3884 |
| 2 | 1550 | 00:07:17 | 81.25% | | 0.4551 | |
| 2 | 1560 | 00:07:23 | 81.25% | 84.58% | 0.5312 | 0.3993 |
| 2 | 1590 | 00:07:32 | 68.75% | 85.93% | 0.7362 | 0.3858 |
| 2 | 1600 | 00:07:33 | 71.88% | | 0.5834 | |
| 2 | 1620 | 00:07:40 | 84.38% | 84.99% | 0.3614 | 0.3965 |
| 2 | 1650 | 00:07:49 | 84.38% | 85.11% | 0.3413 | 0.4007 |
| 2 | 1680 | 00:07:58 | 79.69% | 85.36% | 0.5911 | 0.3874 |
| 2 | 1700 | 00:08:00 | 87.50% | | 0.3413 | |
| 2 | 1710 | 00:08:07 | 78.12% | 85.25% | 0.6069 | 0.4068 |
| 2 | 1740 | 00:08:16 | 81.25% | 85.21% | 0.5232 | 0.3901 |
| 2 | 1750 | 00:08:17 | 82.81% | | 0.4178 | |
| 2 | 1770 | 00:08:25 | 82.81% | 84.78% | 0.4800 | 0.3956 |
| 2 | 1800 | 00:08:33 | 81.25% | 85.77% | 0.4321 | 0.3720 |
| 2 | 1830 | 00:08:41 | 75.00% | 85.69% | 0.5560 | 0.3750 |
| 2 | 1850 | 00:08:43 | 81.25% | | 0.5573 | |
| 2 | 1860 | 00:08:50 | 84.38% | 86.03% | 0.3298 | 0.3695 |
| 3 | 1890 | 00:08:58 | 75.00% | 85.51% | 0.6790 | 0.3917 |
| 3 | 1900 | 00:08:59 | 78.12% | | 0.4963 | |
| 3 | 1920 | 00:09:07 | 81.25% | 85.69% | 0.4163 | 0.3819 |
| 3 | 1950 | 00:09:15 | 79.69% | 85.52% | 0.4666 | 0.3815 |
| 3 | 1980 | 00:09:24 | 85.94% | 86.02% | 0.3937 | 0.3671 |
| 3 | 2000 | 00:09:26 | 81.25% | | 0.6078 | |
| 3 | 2010 | 00:09:33 | 82.81% | 86.19% | 0.4861 | 0.3636 |
| 3 | 2040 | 00:09:42 | 68.75% | 85.34% | 0.8095 | 0.3721 |
| 3 | 2050 | 00:09:43 | 81.25% | | 0.4318 | |
| 3 | 2070 | 00:09:51 | 84.38% | 84.03% | 0.5038 | 0.4106 |
| 3 | 2100 | 00:10:10 | 81.25% | 84.45% | 0.5300 | 0.4027 |
| 3 | 2130 | 00:10:30 | 76.56% | 86.49% | 0.6111 | 0.3549 |
| 3 | 2150 | 00:10:34 | 81.25% | | 0.4532 | |
| 3 | 2160 | 00:10:50 | 82.81% | 86.48% | 0.4063 | 0.3679 |
| 3 | 2190 | 00:11:09 | 78.12% | 86.60% | 0.5903 | 0.3607 |
| 3 | 2200 | 00:11:12 | 92.19% | | 0.3065 | |
| 3 | 2220 | 00:11:28 | 85.94% | 86.73% | 0.4335 | 0.3666 |
| 3 | 2250 | 00:11:47 | 85.94% | 85.80% | 0.4184 | 0.4003 |
| 3 | 2280 | 00:12:05 | 71.88% | 85.90% | 0.5946 | 0.3668 |
| 3 | 2300 | 00:12:10 | 81.25% | | 0.5759 | |
| 3 | 2310 | 00:12:24 | 84.38% | 86.60% | 0.4741 | 0.3557 |
| 3 | 2340 | 00:12:43 | 81.25% | 86.08% | 0.4817 | 0.3807 |
| 3 | 2350 | 00:12:45 | 78.12% | | 0.5004 | |
| 3 | 2370 | 00:13:04 | 79.69% | 86.66% | 0.5872 | 0.3592 |
| 3 | 2400 | 00:13:25 | 82.81% | 86.62% | 0.4866 | 0.3582 |
| 3 | 2430 | 00:13:45 | 85.94% | 86.18% | 0.3289 | 0.3573 |
| 3 | 2450 | 00:13:49 | 79.69% | | 0.4082 | |
| 3 | 2460 | 00:14:04 | 82.81% | 87.61% | 0.3620 | 0.3339 |
| 3 | 2490 | 00:14:23 | 82.81% | 87.00% | 0.4730 | 0.3430 |
| 3 | 2500 | 00:14:25 | 82.81% | | 0.4116 | |
| 3 | 2520 | 00:14:42 | 76.56% | 87.15% | 0.5256 | 0.3466 |
| 3 | 2550 | 00:15:01 | 85.94% | 86.21% | 0.4611 | 0.3543 |
| 3 | 2580 | 00:15:20 | 76.56% | 87.50% | 0.6258 | 0.3283 |
| 3 | 2600 | 00:15:24 | 84.38% | | 0.4969 | |
| 3 | 2610 | 00:15:39 | 81.25% | 87.05% | 0.4683 | 0.3464 |
| 3 | 2640 | 00:15:59 | 81.25% | 87.40% | 0.5859 | 0.3371 |
| 3 | 2650 | 00:16:01 | 84.38% | | 0.4476 | |
| 3 | 2670 | 00:16:19 | 85.94% | 87.90% | 0.5178 | 0.3263 |
| 3 | 2700 | 00:16:39 | 81.25% | 87.12% | 0.4942 | 0.3440 |
| 3 | 2730 | 00:16:59 | 87.50% | 87.08% | 0.4957 | 0.3352 |
| 3 | 2750 | 00:17:04 | 81.25% | | 0.4751 | |
| 3 | 2760 | 00:17:18 | 81.25% | 87.46% | 0.3276 | 0.3415 |
| 3 | 2790 | 00:17:36 | 79.69% | 87.67% | 0.5889 | 0.3247 |
| 3 | 2800 | 00:17:38 | 81.25% | | 0.5031 | |
| 3 | 2811 | 00:17:52 | 92.19% | 87.83% | 0.3225 | 0.3266 |
|=========================================================================================================
Training finished: Max epochs completed.
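A practical addition once training completes: save the network so it can be reloaded later without retraining (the file name here is arbitrary):

% Save the trained network for later reuse.
save('fashion_net.mat', 'net');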

Making predictions on the FashionMNIST test set

YPred = classify(net, X_test);

% Convert the test labels to numeric for later comparison
y_test_numeric = grp2idx(y_test);
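The script never reports an overall test accuracy; it can be computed directly from the predictions (an added check; the value should be close to the final validation accuracy of about 87.8% logged above, since the validation data is the test set):

% Overall test accuracy: YPred and y_test are both categorical,
% so they can be compared element-wise.
accuracy = mean(YPred == y_test);
fprintf('Test accuracy: %.2f%%\n', accuracy * 100);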

% Set up the visualization grid
L = 5;
W = 5;
figure;
set(gcf, 'Units', 'Inches', 'Position', [1, 1, 12, 12]);

for i = 1:L*W
subplot(L, W, i);

% Display the test image
imshow(X_test(:,:,:,i));

% Get the predicted class and the true class
predictedClass = YPred(i);
trueClass = y_test(i);

% Title with the predicted class and the true class
title(sprintf('Predicted: %s\nTrue: %s', string(predictedClass), ...
    string(trueClass)));

% Remove the axes for a cleaner visualization
axis off;
end

% Add an overall title to the grid
sgtitle('Predictions on the Test Set');

Testing our improved network on other images

% Read the image
img = imread('bag.jpg');

% Convert to grayscale if necessary (rgb2gray errors on a grayscale input)
if size(img, 3) == 3
    img = rgb2gray(img);
end

% Resize the image to match the network input
img = imresize(img, [28, 28]);

% Normalize the image to the [0,1] range
img = double(img) / 255.0;

% Reshape the image to match the network input format
img = reshape(img, [28, 28, 1, 1]);

% Make a prediction with the trained network
predictedClass = classify(net, img);

% Display the predicted class
disp(['Predicted class: ', char(predictedClass)]);

Predicted class: Bag
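As an optional addition, classify can also return the softmax scores, which indicate how confident the network is in its label:

% Optional: retrieve the softmax scores alongside the predicted label to
% gauge the network's confidence in this prediction.
[label, scores] = classify(net, img);
fprintf('%s with %.1f%% confidence\n', char(label), max(scores) * 100);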
