Mixed Precision Training Optimization
Created on: December 14, 2024
第一阶段训练: 仅训练顶层...
开始训练模型...
Epoch 1/5
2024-12-15 02:16:26.552956: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8201
718/718 [==============================] - 58s 71ms/step - loss: 2.2013 - accuracy: 0.3023 - precision: 0.3641 - recall: 0.1937 - auc: 0.6656 - val_loss: 1.5583 - val_accuracy: 0.4441 - val_precision: 0.5925 - val_recall: 0.2745 - val_auc: 0.7906 - lr: 1.0000e-04
Epoch 2/5
718/718 [==============================] - 49s 68ms/step - loss: 1.8175 - accuracy: 0.3836 - precision: 0.4849 - recall: 0.2624 - auc: 0.7463 - val_loss: 1.4530 - val_accuracy: 0.4687 - val_precision: 0.6187 - val_recall: 0.3046 - val_auc: 0.8194 - lr: 1.0000e-04
Epoch 3/5
718/718 [==============================] - 52s 72ms/step - loss: 1.6692 - accuracy: 0.4136 - precision: 0.5232 - recall: 0.2755 - auc: 0.7757 - val_loss: 1.3900 - val_accuracy: 0.4869 - val_precision: 0.6445 - val_recall: 0.3189 - val_auc: 0.8343 - lr: 1.0000e-04
Epoch 4/5
718/718 [==============================] - 53s 73ms/step - loss: 1.5863 - accuracy: 0.4327 - precision: 0.5534 - recall: 0.2868 - auc: 0.7934 - val_loss: 1.3512 - val_accuracy: 0.4970 - val_precision: 0.6482 - val_recall: 0.3283 - val_auc: 0.8442 - lr: 1.0000e-04
Epoch 5/5
718/718 [==============================] - 52s 73ms/step - loss: 1.5322 - accuracy: 0.4382 - precision: 0.5722 - recall: 0.2902 - auc: 0.8042 - val_loss: 1.3268 - val_accuracy: 0.5063 - val_precision: 0.6596 - val_recall: 0.3379 - val_auc: 0.8502 - lr: 1.0000e-04
训练完成,总耗时 264.32s
第二阶段训练: 微调后20层...
开始训练模型...
Epoch 1/25
718/718 [==============================] - 66s 84ms/step - loss: 1.4926 - accuracy: 0.4486 - precision_1: 0.5826 - recall_1: 0.2943 - auc_1: 0.8134 - val_loss: 1.3179 - val_accuracy: 0.5098 - val_precision_1: 0.6535 - val_recall_1: 0.3429 - val_auc_1: 0.8527 - lr: 1.0000e-05
Epoch 2/25
718/718 [==============================] - 60s 84ms/step - loss: 1.4664 - accuracy: 0.4583 - precision_1: 0.5995 - recall_1: 0.3083 - auc_1: 0.8199 - val_loss: 1.2982 - val_accuracy: 0.5179 - val_precision_1: 0.6595 - val_recall_1: 0.3575 - val_auc_1: 0.8574 - lr: 1.0000e-05
Epoch 3/25
718/718 [==============================] - 61s 85ms/step - loss: 1.4431 - accuracy: 0.4693 - precision_1: 0.6078 - recall_1: 0.3165 - auc_1: 0.8257 - val_loss: 1.2809 - val_accuracy: 0.5254 - val_precision_1: 0.6676 - val_recall_1: 0.3624 - val_auc_1: 0.8611 - lr: 1.0000e-05
Epoch 4/25
718/718 [==============================] - 60s 84ms/step - loss: 1.4174 - accuracy: 0.4779 - precision_1: 0.6160 - recall_1: 0.3259 - auc_1: 0.8316 - val_loss: 1.2679 - val_accuracy: 0.5279 - val_precision_1: 0.6739 - val_recall_1: 0.3715 - val_auc_1: 0.8643 - lr: 1.0000e-05
Epoch 5/25
718/718 [==============================] - 59s 82ms/step - loss: 1.3956 - accuracy: 0.4838 - precision_1: 0.6241 - recall_1: 0.3357 - auc_1: 0.8370 - val_loss: 1.2556 - val_accuracy: 0.5350 - val_precision_1: 0.6774 - val_recall_1: 0.3811 - val_auc_1: 0.8669 - lr: 1.0000e-05
Epoch 6/25
718/718 [==============================] - 57s 80ms/step - loss: 1.3759 - accuracy: 0.4918 - precision_1: 0.6303 - recall_1: 0.3433 - auc_1: 0.8417 - val_loss: 1.2392 - val_accuracy: 0.5421 - val_precision_1: 0.6869 - val_recall_1: 0.3831 - val_auc_1: 0.8705 - lr: 1.0000e-05
Epoch 7/25
718/718 [==============================] - 59s 82ms/step - loss: 1.3526 - accuracy: 0.4988 - precision_1: 0.6361 - recall_1: 0.3506 - auc_1: 0.8471 - val_loss: 1.2292 - val_accuracy: 0.5479 - val_precision_1: 0.6903 - val_recall_1: 0.3885 - val_auc_1: 0.8729 - lr: 1.0000e-05
Epoch 8/25
718/718 [==============================] - 62s 87ms/step - loss: 1.3481 - accuracy: 0.5027 - precision_1: 0.6425 - recall_1: 0.3526 - auc_1: 0.8481 - val_loss: 1.2150 - val_accuracy: 0.5489 - val_precision_1: 0.6976 - val_recall_1: 0.3929 - val_auc_1: 0.8756 - lr: 1.0000e-05
Epoch 9/25
718/718 [==============================] - 59s 81ms/step - loss: 1.3181 - accuracy: 0.5140 - precision_1: 0.6503 - recall_1: 0.3618 - auc_1: 0.8547 - val_loss: 1.2086 - val_accuracy: 0.5549 - val_precision_1: 0.7000 - val_recall_1: 0.3990 - val_auc_1: 0.8773 - lr: 1.0000e-05
Epoch 10/25
718/718 [==============================] - 61s 85ms/step - loss: 1.3079 - accuracy: 0.5168 - precision_1: 0.6543 - recall_1: 0.3674 - auc_1: 0.8571 - val_loss: 1.1986 - val_accuracy: 0.5569 - val_precision_1: 0.6992 - val_recall_1: 0.4025 - val_auc_1: 0.8793 - lr: 1.0000e-05
Epoch 11/25
718/718 [==============================] - 60s 83ms/step - loss: 1.2836 - accuracy: 0.5260 - precision_1: 0.6653 - recall_1: 0.3768 - auc_1: 0.8619 - val_loss: 1.1957 - val_accuracy: 0.5620 - val_precision_1: 0.6988 - val_recall_1: 0.4126 - val_auc_1: 0.8802 - lr: 1.0000e-05
Epoch 12/25
718/718 [==============================] - 59s 82ms/step - loss: 1.2760 - accuracy: 0.5242 - precision_1: 0.6608 - recall_1: 0.3778 - auc_1: 0.8637 - val_loss: 1.1851 - val_accuracy: 0.5634 - val_precision_1: 0.7033 - val_recall_1: 0.4127 - val_auc_1: 0.8824 - lr: 1.0000e-05
Epoch 13/25
717/718 [============================>.] - ETA: 0s - loss: 1.2613 - accuracy: 0.5330 - precision_1: 0.6702 - recall_1: 0.3830 - auc_1: 0.8668
Epoch 14/25
718/718 [==============================] - 60s 83ms/step - loss: 1.2508 - accuracy: 0.5363 - precision_1: 0.6783 - recall_1: 0.3892 - auc_1: 0.8689 - val_loss: 1.1737 - val_accuracy: 0.5665 - val_precision_1: 0.7050 - val_recall_1: 0.4192 - val_auc_1: 0.8849 - lr: 1.0000e-05
Epoch 15/25
718/718 [==============================] - 58s 81ms/step - loss: 1.2372 - accuracy: 0.5392 - precision_1: 0.6789 - recall_1: 0.3928 - auc_1: 0.8716 - val_loss: 1.1610 - val_accuracy: 0.5691 - val_precision_1: 0.7044 - val_recall_1: 0.4237 - val_auc_1: 0.8873 - lr: 1.0000e-05
Epoch 16/25
718/718 [==============================] - 60s 84ms/step - loss: 1.2238 - accuracy: 0.5456 - precision_1: 0.6827 - recall_1: 0.4004 - auc_1: 0.8746 - val_loss: 1.1624 - val_accuracy: 0.5756 - val_precision_1: 0.7032 - val_recall_1: 0.4321 - val_auc_1: 0.8874 - lr: 1.0000e-05
Epoch 17/25
718/718 [==============================] - 60s 84ms/step - loss: 1.2191 - accuracy: 0.5431 - precision_1: 0.6770 - recall_1: 0.4032 - auc_1: 0.8758 - val_loss: 1.1505 - val_accuracy: 0.5768 - val_precision_1: 0.7101 - val_recall_1: 0.4314 - val_auc_1: 0.8896 - lr: 1.0000e-05
Epoch 18/25
718/718 [==============================] - 58s 81ms/step - loss: 1.2050 - accuracy: 0.5505 - precision_1: 0.6845 - recall_1: 0.4083 - auc_1: 0.8786 - val_loss: 1.1462 - val_accuracy: 0.5775 - val_precision_1: 0.7104 - val_recall_1: 0.4361 - val_auc_1: 0.8904 - lr: 1.0000e-05
Epoch 19/25
718/718 [==============================] - 59s 82ms/step - loss: 1.1910 - accuracy: 0.5580 - precision_1: 0.6934 - recall_1: 0.4160 - auc_1: 0.8817 - val_loss: 1.1426 - val_accuracy: 0.5834 - val_precision_1: 0.7087 - val_recall_1: 0.4361 - val_auc_1: 0.8913 - lr: 1.0000e-05
Epoch 20/25
718/718 [==============================] - 60s 83ms/step - loss: 1.1704 - accuracy: 0.5649 - precision_1: 0.6964 - recall_1: 0.4245 - auc_1: 0.8852 - val_loss: 1.1420 - val_accuracy: 0.5801 - val_precision_1: 0.7060 - val_recall_1: 0.4432 - val_auc_1: 0.8916 - lr: 1.0000e-05
Epoch 21/25
718/718 [==============================] - 61s 84ms/step - loss: 1.1699 - accuracy: 0.5617 - precision_1: 0.6936 - recall_1: 0.4256 - auc_1: 0.8856 - val_loss: 1.1382 - val_accuracy: 0.5829 - val_precision_1: 0.7065 - val_recall_1: 0.4519 - val_auc_1: 0.8927 - lr: 1.0000e-05
Epoch 22/25
718/718 [==============================] - 58s 81ms/step - loss: 1.1475 - accuracy: 0.5699 - precision_1: 0.7041 - recall_1: 0.4298 - auc_1: 0.8896 - val_loss: 1.1358 - val_accuracy: 0.5810 - val_precision_1: 0.7080 - val_recall_1: 0.4568 - val_auc_1: 0.8935 - lr: 1.0000e-05
Epoch 23/25
718/718 [==============================] - 60s 83ms/step - loss: 1.1488 - accuracy: 0.5745 - precision_1: 0.7021 - recall_1: 0.4357 - auc_1: 0.8895 - val_loss: 1.1319 - val_accuracy: 0.5829 - val_precision_1: 0.7091 - val_recall_1: 0.4559 - val_auc_1: 0.8939 - lr: 1.0000e-05
Epoch 24/25
718/718 [==============================] - 60s 84ms/step - loss: 1.1386 - accuracy: 0.5790 - precision_1: 0.7075 - recall_1: 0.4392 - auc_1: 0.8915 - val_loss: 1.1282 - val_accuracy: 0.5843 - val_precision_1: 0.7076 - val_recall_1: 0.4636 - val_auc_1: 0.8949 - lr: 1.0000e-05
Epoch 25/25
718/718 [==============================] - 60s 84ms/step - loss: 1.1193 - accuracy: 0.5854 - precision_1: 0.7130 - recall_1: 0.4461 - auc_1: 0.8952 - val_loss: 1.1321 - val_accuracy: 0.5866 - val_precision_1: 0.7058 - val_recall_1: 0.4667 - val_auc_1: 0.8947 - lr: 1.0000e-05
训练完成,总耗时 1500.24s
import os
import datetime
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.callbacks import TensorBoard, ReduceLROnPlateau, EarlyStopping, ModelCheckpoint, Callback
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.mixed_precision import global_policy, set_global_policy
import matplotlib.pyplot as plt
import time
from sklearn.model_selection import KFold
try:
    set_global_policy('mixed_float16')
    print("成功启用混合精度训练")
except:
    print("无法启用混合精度训练,将使用默认精度")
np.random.seed(42)
tf.random.set_seed(42)
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("成功设置GPU显存按需分配。")
    except RuntimeError as e:
        print(e)
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
save_dir = os.path.join('./logs', timestamp)
os.makedirs(save_dir, exist_ok=True)
print(f"所有文件将保存在:{save_dir}")
tensorboard_log_dir = os.path.join(save_dir, 'tensorboard_logs')
tensorboard_callback = TensorBoard(log_dir=tensorboard_log_dir, histogram_freq=1)
print(f"TensorBoard日志将保存在:{tensorboard_log_dir}")
class TrainingLogger(Callback):
    def on_train_begin(self, logs=None):
        self.train_start_time = time.time()
        print("开始训练模型...")

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start_time = time.time()
        print(f"\nEpoch {epoch + 1}/{self.params['epochs']}")

    def on_epoch_end(self, epoch, logs=None):
        epoch_time = time.time() - self.epoch_start_time
        logs = logs or {}
        metrics_str = []
        for k, v in logs.items():
            if k != 'lr':
                try:
                    metrics_str.append(f"{k}: {float(v):.4f}")
                except (ValueError, TypeError):
                    metrics_str.append(f"{k}: {v}")
        print(f" - {epoch_time:.2f}s - {' - '.join(metrics_str)}")

    def on_train_end(self, logs=None):
        total_time = time.time() - self.train_start_time
        print(f"训练完成,总耗时 {total_time:.2f}s")
def preprocess_and_augment(image):
    # 将字符串转换为数组
    image = tf.strings.split(image)
    image = tf.strings.to_number(image, tf.float32)
    image = tf.reshape(image, [48, 48, 1])

    # 转换为 RGB 三通道
    image = tf.image.grayscale_to_rgb(image)

    # 调整大小到 224x224
    image = tf.image.resize(image, [224, 224])

    # 数据增强
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, max_delta=0.2)
    image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
    image = tf.image.random_saturation(image, lower=0.8, upper=1.2)
    image = tf.image.random_hue(image, 0.1)

    # 随机裁剪并调整回原始大小
    image = tf.image.random_crop(image, [200, 200, 3])
    image = tf.image.resize(image, [224, 224])

    # 添加随机噪声
    noise = tf.random.normal(shape=tf.shape(image), mean=0.0, stddev=0.1)
    image = image + noise
    image = tf.clip_by_value(image, 0.0, 255.0)

    # 随机遮挡 (Cutout)
    def random_cutout(image, size=10):
        height, width = tf.shape(image)[0], tf.shape(image)[1]
        x = tf.random.uniform([], 0, width - size, dtype=tf.int32)
        y = tf.random.uniform([], 0, height - size, dtype=tf.int32)
        cutout = tf.zeros([size, size, 3])
        padding = [[y, height - y - size], [x, width - x - size], [0, 0]]
        mask = tf.pad(cutout, padding, constant_values=1)
        return image * mask

    image = random_cutout(image)

    image = preprocess_input(image)
    return image
def preprocess(image):
    # 将字符串转换为数组
    image = tf.strings.split(image)
    image = tf.strings.to_number(image, tf.float32)
    image = tf.reshape(image, [48, 48, 1])

    # 转换为 RGB 三通道
    image = tf.image.grayscale_to_rgb(image)

    # 调整大小到 224x224
    image = tf.image.resize(image, [224, 224])

    # 预处理
    image = preprocess_input(image)
    return image
def create_dataset(dataframe, is_training=True, batch_size=32):
    images = dataframe['pixels'].values
    labels = dataframe['emotion'].map(emotion_to_index).values

    dataset = tf.data.Dataset.from_tensor_slices((images, labels))

    if is_training:
        dataset = dataset.shuffle(buffer_size=4096)

    def process(image, label):
        if is_training:
            image = preprocess_and_augment(image)
        else:
            image = preprocess(image)
        label = tf.one_hot(label, num_classes)
        return image, label

    dataset = dataset.map(process, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset
def create_model():
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

    # 初始时冻结所有层
    for layer in base_model.layers:
        layer.trainable = False

    inputs = base_model.input
    x = base_model.output
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dense(512, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = models.Model(inputs=inputs, outputs=outputs)
    return model
data = pd.read_csv(r'H:\PycharmProjects\dataset\fer2013.csv')
num_classes = data['emotion'].nunique()
emotion_to_index = {emotion: idx for idx, emotion in enumerate(sorted(data['emotion'].unique()))}
BATCH_SIZE = 32
EPOCHS_STAGE1 = 5 # 第一阶段训练epochs
EPOCHS_STAGE2 = 25 # 第二阶段训练epochs
N_FOLDS = 5 # 交叉验证折数
fold_histories = []
fold_scores = []
kfold = KFold(n_splits=N_FOLDS, shuffle=True, random_state=42)
training_data = data[data['Usage'] == 'Training']
for fold, (train_idx, val_idx) in enumerate(kfold.split(training_data)):
    print(f'\nFold {fold + 1}/{N_FOLDS}')

    # 准备当前fold的数据
    train_fold = training_data.iloc[train_idx]
    val_fold = training_data.iloc[val_idx]

    # 创建数据集
    train_dataset = create_dataset(train_fold, is_training=True, batch_size=BATCH_SIZE)
    val_dataset = create_dataset(val_fold, is_training=False, batch_size=BATCH_SIZE)

    # 创建模型
    model = create_model()

    # 第一阶段训练 (冻结预训练层)
    print("第一阶段训练: 仅训练顶层...")
    model.compile(
        optimizer=optimizers.Adam(learning_rate=1e-4),
        loss='categorical_crossentropy',
        metrics=[
            'accuracy',
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall'),
            tf.keras.metrics.AUC(name='auc')
        ]
    )

    callbacks_stage1 = [
        TensorBoard(log_dir=os.path.join(tensorboard_log_dir, f'fold_{fold}_stage1')),
        ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-7),
        EarlyStopping(monitor='val_loss', patience=7, restore_best_weights=True),
        ModelCheckpoint(
            filepath=os.path.join(save_dir, f'best_model_fold_{fold}_stage1.h5'),
            monitor='val_accuracy',
            save_best_only=True,
            mode='max'
        ),
        TrainingLogger()
    ]

    history1 = model.fit(
        train_dataset,
        epochs=EPOCHS_STAGE1,
        validation_data=val_dataset,
        callbacks=callbacks_stage1
    )

    # 第二阶段训练 (解冻后面20层)
    print("\n第二阶段训练: 微调后20层...")

    # 解冻后20层
    for layer in model.layers[-20:]:
        layer.trainable = True

    model.compile(
        optimizer=optimizers.Adam(learning_rate=1e-5),
        loss='categorical_crossentropy',
        metrics=[
            'accuracy',
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall'),
            tf.keras.metrics.AUC(name='auc')
        ]
    )

    callbacks_stage2 = [
        TensorBoard(log_dir=os.path.join(tensorboard_log_dir, f'fold_{fold}_stage2')),
        ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-7),
        EarlyStopping(monitor='val_loss', patience=7, restore_best_weights=True),
        ModelCheckpoint(
            filepath=os.path.join(save_dir, f'best_model_fold_{fold}_stage2.h5'),
            monitor='val_accuracy',
            save_best_only=True,
            mode='max'
        ),
        TrainingLogger()
    ]

    history2 = model.fit(
        train_dataset,
        epochs=EPOCHS_STAGE2,
        validation_data=val_dataset,
        callbacks=callbacks_stage2
    )

    # 合并两个阶段的历史记录
    combined_history = {}
    for key in history1.history:
        combined_history[key] = history1.history[key] + history2.history[key]

    # 保存历史记录和分数
    fold_histories.append(combined_history)
    fold_scores.append(model.evaluate(val_dataset))

    # 保存最终模型
    model.save(os.path.join(save_dir, f'final_model_fold_{fold}.h5'))
mean_scores = np.mean(fold_scores, axis=0)
std_scores = np.std(fold_scores, axis=0)
print("\n交叉验证结果:")
print(f"平均验证损失: {mean_scores[0]:.4f} ± {std_scores[0]:.4f}")
print(f"平均验证准确率: {mean_scores[1]:.4f} ± {std_scores[1]:.4f}")
print("\n开始在测试集上评估模型性能...")
test_data = data[data['Usage'] == 'PrivateTest']
test_dataset = create_dataset(test_data, is_training=False, batch_size=BATCH_SIZE)
test_predictions = []
for fold in range(N_FOLDS):
    model_path = os.path.join(save_dir, f'best_model_fold_{fold}_stage2.h5')  # 使用第二阶段的最佳模型
    model = tf.keras.models.load_model(model_path, compile=False)

    # 重新编译模型以确保自定义指标被加载
    model.compile(
        optimizer=optimizers.Adam(),
        loss='categorical_crossentropy',
        metrics=[
            'accuracy',
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall'),
            tf.keras.metrics.AUC(name='auc')
        ]
    )

    pred = model.predict(test_dataset)
    test_predictions.append(pred)
ensemble_predictions = np.mean(test_predictions, axis=0)
ensemble_classes = np.argmax(ensemble_predictions, axis=1)
true_classes = test_data['emotion'].map(emotion_to_index).values
ensemble_accuracy = np.mean(ensemble_classes == true_classes)
print(f"\n集成模型在测试集上的准确率: {ensemble_accuracy:.4f}")
plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
for fold in range(N_FOLDS):
    plt.plot(fold_histories[fold]['accuracy'], label=f'训练 (Fold {fold+1})')
    plt.plot(fold_histories[fold]['val_accuracy'], label=f'验证 (Fold {fold+1})', linestyle='--')
plt.title('模型准确率')
plt.xlabel('Epoch')
plt.ylabel('准确率')
plt.legend()
plt.grid(True)
plt.subplot(1, 2, 2)
for fold in range(N_FOLDS):
    plt.plot(fold_histories[fold]['loss'], label=f'训练 (Fold {fold+1})')
    plt.plot(fold_histories[fold]['val_loss'], label=f'验证 (Fold {fold+1})', linestyle='--')
plt.title('模型损失')
plt.xlabel('Epoch')
plt.ylabel('损失')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig(os.path.join(save_dir, 'training_curves.png'))
plt.show()
from sklearn.metrics import confusion_matrix
import seaborn as sns
cm = confusion_matrix(true_classes, ensemble_classes)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('混淆矩阵')
plt.xlabel('预测类别')
plt.ylabel('真实类别')
plt.savefig(os.path.join(save_dir, 'confusion_matrix.png'))
plt.show()
from sklearn.metrics import classification_report
emotion_names = {v: k for k, v in emotion_to_index.items()}
class_report = classification_report(true_classes, ensemble_classes,
target_names=[str(emotion_names[i]) for i in range(num_classes)],
digits=4)
print("\n分类报告:")
print(class_report)
with open(os.path.join(save_dir, 'experiment_results.txt'), 'w', encoding='utf-8') as f:
f.write("实验时间: " + timestamp + "\n\n")
f.write("模型架构:\n")
# 创建一个字符串来捕捉模型摘要
model_summary = []
model.summary(print_fn=lambda x: model_summary.append(x))
model_summary_str = "\n".join(model_summary)
f.write(model_summary_str + '\n')
f.write("\n交叉验证结果:\n")
f.write(f"平均验证损失: {mean_scores[0]:.4f} ± {std_scores[0]:.4f}\n")
f.write(f"平均验证准确率: {mean_scores[1]:.4f} ± {std_scores[1]:.4f}\n")
f.write("\n测试集结果:\n")
f.write(f"集成模型准确率: {ensemble_accuracy:.4f}\n")
f.write("\n分类报告:\n")
f.write(class_report)
print(f"\n所有实验结果已保存到: {save_dir}")
Given these training logs, how can the code be optimized to improve model accuracy? Please provide the complete modified code.
To improve your model's accuracy, we can make improvements in the following areas:
1. Handle class imbalance: the emotion classes in FER2013 are likely imbalanced. Computing a weight for each class and applying it during training helps the model learn the minority classes.
2. Refine data augmentation: adjust the augmentation strategy so it adds diversity without distorting the images too heavily.
3. Swap the model architecture: use a more efficient, better-performing backbone such as EfficientNetB0, which does well across many tasks.
4. Adjust the optimizer and learning-rate schedule: introduce a learning-rate warm-up and use an adaptive optimizer such as AdamW to improve training (a sketch of the warm-up follows this list).
5. Apply label smoothing: label smoothing keeps the model from becoming over-confident and improves generalization (also sketched after this list).
6. Tune batch size and training epochs: adjust both to your hardware to balance training time against model performance.
7. Mixed-precision training: make sure mixed precision is applied correctly to speed up training and save GPU memory.
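Two of the points above, learning-rate warm-up and label smoothing, are recommended but not wired into the full script below. A minimal, hypothetical sketch of how they could be added with standard Keras APIs follows; the names `make_warmup_schedule` and `warmup_epochs` are illustrative and not part of the original code:

```python
import tensorflow as tf
from tensorflow.keras.callbacks import LearningRateScheduler

# Hypothetical linear warm-up: ramp the learning rate over the first
# `warmup_epochs`, then hold the base rate (ReduceLROnPlateau can still
# handle later decay).
def make_warmup_schedule(base_lr=1e-4, warmup_epochs=3):
    def schedule(epoch, lr):
        if epoch < warmup_epochs:
            return base_lr * (epoch + 1) / warmup_epochs
        return base_lr
    return schedule

warmup_callback = LearningRateScheduler(make_warmup_schedule())

# Label smoothing via the built-in categorical cross-entropy loss.
loss_fn = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)

# Usage (illustrative):
# model.compile(optimizer=..., loss=loss_fn, metrics=[...])
# model.fit(..., callbacks=[warmup_callback, *other_callbacks])
```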
Below is the complete modified code incorporating the optimizations above:
```python
import os
import datetime
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.callbacks import TensorBoard, ReduceLROnPlateau, EarlyStopping, ModelCheckpoint, Callback
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.mixed_precision import global_policy, set_global_policy
import matplotlib.pyplot as plt
import time
from sklearn.model_selection import KFold
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# 设置混合精度训练
try:
    set_global_policy('mixed_float16')
    print("成功启用混合精度训练")
except:
    print("无法启用混合精度训练,将使用默认精度")

# 设置随机种子以确保可重复性
np.random.seed(42)
tf.random.set_seed(42)

# 设置GPU显存按需分配
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("成功设置GPU显存按需分配。")
    except RuntimeError as e:
        print(e)

# 设置Matplotlib中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# 创建日志目录
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
save_dir = os.path.join('./logs', timestamp)
os.makedirs(save_dir, exist_ok=True)
print(f"所有文件将保存在:{save_dir}")

# TensorBoard回调
tensorboard_log_dir = os.path.join(save_dir, 'tensorboard_logs')
tensorboard_callback = TensorBoard(log_dir=tensorboard_log_dir, histogram_freq=1)
print(f"TensorBoard日志将保存在:{tensorboard_log_dir}")

# 自定义训练日志回调
class TrainingLogger(Callback):
    def on_train_begin(self, logs=None):
        self.train_start_time = time.time()
        print("开始训练模型...")

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start_time = time.time()
        print(f"\nEpoch {epoch + 1}/{self.params['epochs']}")

    def on_epoch_end(self, epoch, logs=None):
        epoch_time = time.time() - self.epoch_start_time
        logs = logs or {}
        metrics_str = []
        for k, v in logs.items():
            if k != 'lr':
                try:
                    metrics_str.append(f"{k}: {float(v):.4f}")
                except (ValueError, TypeError):
                    metrics_str.append(f"{k}: {v}")
        print(f" - {epoch_time:.2f}s - {' - '.join(metrics_str)}")

    def on_train_end(self, logs=None):
        total_time = time.time() - self.train_start_time
        print(f"训练完成,总耗时 {total_time:.2f}s")

def preprocess_and_augment(image):
    # 将字符串转换为数组
    image = tf.strings.split(image)
    image = tf.strings.to_number(image, tf.float32)
    image = tf.reshape(image, [48, 48, 1])

    # 转换为 RGB 三通道
    image = tf.image.grayscale_to_rgb(image)

    # 调整大小到 224x224
    image = tf.image.resize(image, [224, 224])

    # 数据增强
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
    image = tf.image.random_saturation(image, lower=0.9, upper=1.1)
    image = tf.image.random_hue(image, 0.05)

    # 随机裁剪并调整回原始大小
    image = tf.image.random_crop(image, [200, 200, 3])
    image = tf.image.resize(image, [224, 224])

    # 添加随机噪声
    noise = tf.random.normal(shape=tf.shape(image), mean=0.0, stddev=0.05)
    image = image + noise
    image = tf.clip_by_value(image, 0.0, 255.0)

    # 随机遮挡 (Cutout)
    def random_cutout(image, size=16):
        height, width = tf.shape(image)[0], tf.shape(image)[1]
        x = tf.random.uniform([], 0, width - size, dtype=tf.int32)
        y = tf.random.uniform([], 0, height - size, dtype=tf.int32)
        cutout = tf.zeros([size, size, 3])
        padding = [[y, height - y - size], [x, width - x - size], [0, 0]]
        mask = tf.pad(cutout, padding, constant_values=1)
        return image * mask

    image = random_cutout(image)

    image = preprocess_input(image)
    return image

def preprocess(image):
    # 将字符串转换为数组
    image = tf.strings.split(image)
    image = tf.strings.to_number(image, tf.float32)
    image = tf.reshape(image, [48, 48, 1])

    # 转换为 RGB 三通道
    image = tf.image.grayscale_to_rgb(image)

    # 调整大小到 224x224
    image = tf.image.resize(image, [224, 224])

    # 预处理
    image = preprocess_input(image)
    return image

# 创建 TensorFlow Dataset 对象
def create_dataset(dataframe, is_training=True, batch_size=32):
    images = dataframe['pixels'].values
    labels = dataframe['emotion'].map(emotion_to_index).values

    dataset = tf.data.Dataset.from_tensor_slices((images, labels))

    if is_training:
        dataset = dataset.shuffle(buffer_size=4096)

    def process(image, label):
        if is_training:
            image = preprocess_and_augment(image)
        else:
            image = preprocess(image)
        label = tf.one_hot(label, num_classes)
        return image, label

    dataset = dataset.map(process, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

# 创建模型函数
def create_model():
    base_model = EfficientNetB0(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

    # 初始时冻结所有层
    base_model.trainable = False

    inputs = base_model.input
    x = base_model.output
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dense(512, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(num_classes, activation='softmax', dtype='float32')(x)  # 设置输出为float32以兼容混合精度

    model = models.Model(inputs=inputs, outputs=outputs)
    return model

# 读取数据集
data = pd.read_csv(r'H:\PycharmProjects\dataset\fer2013.csv')

# 将 'emotion' 列进行 One-hot 编码
num_classes = data['emotion'].nunique()
emotion_to_index = {emotion: idx for idx, emotion in enumerate(sorted(data['emotion'].unique()))}

# 设置超参数
BATCH_SIZE = 64  # 增大批量大小以加快训练速度
EPOCHS_STAGE1 = 10  # 第一阶段训练epochs
EPOCHS_STAGE2 = 30  # 第二阶段训练epochs
N_FOLDS = 5  # 交叉验证折数

# 创建存储每个fold结果的列表
fold_histories = []
fold_scores = []

# K折交叉验证
kfold = KFold(n_splits=N_FOLDS, shuffle=True, random_state=42)
training_data = data[data['Usage'] == 'Training']

for fold, (train_idx, val_idx) in enumerate(kfold.split(training_data)):
    print(f'\nFold {fold + 1}/{N_FOLDS}')

    # 准备当前fold的数据
    train_fold = training_data.iloc[train_idx]
    val_fold = training_data.iloc[val_idx]

    # 计算类权重
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=np.unique(train_fold['emotion']),
        y=train_fold['emotion']
    )
    class_weights_dict = {i: class_weights[i] for i in range(num_classes)}
    print(f"Fold {fold + 1} 类别权重: {class_weights_dict}")

    # 创建数据集
    train_dataset = create_dataset(train_fold, is_training=True, batch_size=BATCH_SIZE)
    val_dataset = create_dataset(val_fold, is_training=False, batch_size=BATCH_SIZE)

    # 创建模型
    model = create_model()

    # 第一阶段训练 (冻结预训练层)
    print("第一阶段训练: 仅训练顶层...")
    model.compile(
        optimizer=optimizers.AdamW(learning_rate=1e-4),
        loss='categorical_crossentropy',
        metrics=[
            'accuracy',
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall'),
            tf.keras.metrics.AUC(name='auc')
        ]
    )

    callbacks_stage1 = [
        TensorBoard(log_dir=os.path.join(tensorboard_log_dir, f'fold_{fold}_stage1')),
        ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-7),
        EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
        ModelCheckpoint(
            filepath=os.path.join(save_dir, f'best_model_fold_{fold}_stage1.h5'),
            monitor='val_accuracy',
            save_best_only=True,
            mode='max'
        ),
        TrainingLogger()
    ]

    history1 = model.fit(
        train_dataset,
        epochs=EPOCHS_STAGE1,
        validation_data=val_dataset,
        callbacks=callbacks_stage1,
        class_weight=class_weights_dict
    )

    # 第二阶段训练 (解冻后面20层)
    print("\n第二阶段训练: 微调后20层...")

    # 解冻后20层
    for layer in model.layers[-20:]:
        if not isinstance(layer, layers.BatchNormalization):
            layer.trainable = True

    model.compile(
        optimizer=optimizers.AdamW(learning_rate=1e-5),
        loss='categorical_crossentropy',
        metrics=[
            'accuracy',
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall'),
            tf.keras.metrics.AUC(name='auc')
        ]
    )

    callbacks_stage2 = [
        TensorBoard(log_dir=os.path.join(tensorboard_log_dir, f'fold_{fold}_stage2')),
        ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-7),
        EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
        ModelCheckpoint(
            filepath=os.path.join(save_dir, f'best_model_fold_{fold}_stage2.h5'),
            monitor='val_accuracy',
            save_best_only=True,
            mode='max'
        ),
        TrainingLogger()
    ]

    history2 = model.fit(
        train_dataset,
        epochs=EPOCHS_STAGE2,
        validation_data=val_dataset,
        callbacks=callbacks_stage2,
        class_weight=class_weights_dict
    )

    # 合并两个阶段的历史记录
    combined_history = {}
    for key in history1.history:
        combined_history[key] = history1.history[key] + history2.history[key]
    for key in history2.history:
        if key not in combined_history:
            combined_history[key] = history2.history[key]

    # 保存历史记录和分数
    fold_histories.append(combined_history)
    fold_scores.append(model.evaluate(val_dataset, verbose=0))

    # 保存最终模型
    model.save(os.path.join(save_dir, f'final_model_fold_{fold}.h5'))

# 计算平均性能
mean_scores = np.mean(fold_scores, axis=0)
std_scores = np.std(fold_scores, axis=0)
print("\n交叉验证结果:")
print(f"平均验证损失: {mean_scores[0]:.4f} ± {std_scores[0]:.4f}")
print(f"平均验证准确率: {mean_scores[1]:.4f} ± {std_scores[1]:.4f}")
print(f"平均验证精确率: {mean_scores[2]:.4f} ± {std_scores[2]:.4f}")
print(f"平均验证召回率: {mean_scores[3]:.4f} ± {std_scores[3]:.4f}")
print(f"平均验证AUC: {mean_scores[4]:.4f} ± {std_scores[4]:.4f}")

# 在测试集上评估最终性能
print("\n开始在测试集上评估模型性能...")
test_data = data[data['Usage'] == 'PrivateTest']
test_dataset = create_dataset(test_data, is_training=False, batch_size=BATCH_SIZE)

# 加载所有fold的最佳模型并进行集成预测
test_predictions = []
for fold in range(N_FOLDS):
    model_path = os.path.join(save_dir, f'best_model_fold_{fold}_stage2.h5')  # 使用第二阶段的最佳模型
    model = tf.keras.models.load_model(model_path, compile=False)

    # 重新编译模型以确保自定义指标被加载
    model.compile(
        optimizer=optimizers.AdamW(),
        loss='categorical_crossentropy',
        metrics=[
            'accuracy',
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall'),
            tf.keras.metrics.AUC(name='auc')
        ]
    )

    pred = model.predict(test_dataset)
    test_predictions.append(pred)

# 集成预测结果(加权平均)
ensemble_predictions = np.mean(test_predictions, axis=0)
ensemble_classes = np.argmax(ensemble_predictions, axis=1)
true_classes = test_data['emotion'].map(emotion_to_index).values

# 计算集成模型的准确率
ensemble_accuracy = np.mean(ensemble_classes == true_classes)
print(f"\n集成模型在测试集上的准确率: {ensemble_accuracy:.4f}")

# 绘制训练曲线
plt.figure(figsize=(20, 10))

# 准确率曲线
plt.subplot(2, 2, 1)
for fold in range(N_FOLDS):
    plt.plot(fold_histories[fold]['accuracy'], label=f'训练 (Fold {fold+1})')
    plt.plot(fold_histories[fold]['val_accuracy'], label=f'验证 (Fold {fold+1})', linestyle='--')
plt.title('模型准确率')
plt.xlabel('Epoch')
plt.ylabel('准确率')
plt.legend()
plt.grid(True)

# 损失曲线
plt.subplot(2, 2, 2)
for fold in range(N_FOLDS):
    plt.plot(fold_histories[fold]['loss'], label=f'训练 (Fold {fold+1})')
    plt.plot(fold_histories[fold]['val_loss'], label=f'验证 (Fold {fold+1})', linestyle='--')
plt.title('模型损失')
plt.xlabel('Epoch')
plt.ylabel('损失')
plt.legend()
plt.grid(True)

# AUC曲线
plt.subplot(2, 2, 3)
for fold in range(N_FOLDS):
    plt.plot(fold_histories[fold]['auc'], label=f'训练 (Fold {fold+1})')
    plt.plot(fold_histories[fold]['val_auc'], label=f'验证 (Fold {fold+1})', linestyle='--')
plt.title('模型AUC')
plt.xlabel('Epoch')
plt.ylabel('AUC')
plt.legend()
plt.grid(True)

# 精确率和召回率曲线
plt.subplot(2, 2, 4)
for fold in range(N_FOLDS):
    plt.plot(fold_histories[fold]['precision'], label=f'训练精确率 (Fold {fold+1})')
    plt.plot(fold_histories[fold]['val_precision'], label=f'验证精确率 (Fold {fold+1})', linestyle='--')
    plt.plot(fold_histories[fold]['recall'], label=f'训练召回率 (Fold {fold+1})')
    plt.plot(fold_histories[fold]['val_recall'], label=f'验证召回率 (Fold {fold+1})', linestyle='--')
plt.title('模型精确率和召回率')
plt.xlabel('Epoch')
plt.ylabel('值')
plt.legend()
plt.grid(True)

# 保存图像
plt.tight_layout()
plt.savefig(os.path.join(save_dir, 'training_curves.png'))
plt.show()

# 创建混淆矩阵
cm = confusion_matrix(true_classes, ensemble_classes)
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=[f'类 {i}' for i in range(num_classes)],
            yticklabels=[f'类 {i}' for i in range(num_classes)])
plt.title('混淆矩阵')
plt.xlabel('预测类别')
plt.ylabel('真实类别')
plt.savefig(os.path.join(save_dir, 'confusion_matrix.png'))
plt.show()

# 计算每个类别的性能指标
emotion_names = {v: k for k, v in emotion_to_index.items()}
class_report = classification_report(true_classes, ensemble_classes,
                                     target_names=[str(emotion_names[i]) for i in range(num_classes)],
                                     digits=4)
print("\n分类报告:")
print(class_report)

# 保存实验结果到文本文件
with open(os.path.join(save_dir, 'experiment_results.txt'), 'w', encoding='utf-8') as f:
    f.write("实验时间: " + timestamp + "\n\n")
    f.write("模型架构:\n")
    # 创建一个字符串来捕捉模型摘要
    model_summary = []
    model.summary(print_fn=lambda x: model_summary.append(x))
    model_summary_str = "\n".join(model_summary)
    f.write(model_summary_str + '\n')

    f.write("\n交叉验证结果:\n")
    f.write(f"平均验证损失: {mean_scores[0]:.4f} ± {std_scores[0]:.4f}\n")
    f.write(f"平均验证准确率: {mean_scores[1]:.4f} ± {std_scores[1]:.4f}\n")
    f.write(f"平均验证精确率: {mean_scores[2]:.4f} ± {std_scores[2]:.4f}\n")
    f.write(f"平均验证召回率: {mean_scores[3]:.4f} ± {std_scores[3]:.4f}\n")
    f.write(f"平均验证AUC: {mean_scores[4]:.4f} ± {std_scores[4]:.4f}\n")

    f.write("\n测试集结果:\n")
    f.write(f"集成模型准确率: {ensemble_accuracy:.4f}\n")

    f.write("\n分类报告:\n")
    f.write(class_report)

print(f"\n所有实验结果已保存到: {save_dir}")
```
Summary of the main changes:

- Class imbalance: compute_class_weight derives a weight for each class, and these weights are applied during training through the class_weight argument, helping the model learn the minority classes.
- Model architecture: ResNet50 is replaced with EfficientNetB0, which is more efficient in parameter count and computation while performing well on many tasks; the output layer is created with dtype='float32' to stay compatible with mixed-precision training.
- Data augmentation: the colour jitter and noise ranges are reduced to avoid over-distorting the images, and the Cutout patch size is increased from 10 to 16 for stronger regularization.
- Optimizer: Adam is replaced with AdamW, which adds weight decay and helps prevent overfitting.
- Learning-rate schedule and early stopping: the patience of ReduceLROnPlateau and EarlyStopping is increased, giving the model more time to find a good learning rate and avoiding premature stopping.
- Mixed-precision training: the final layer's dtype is set to float32 to avoid compatibility issues with the mixed_float16 policy.
- Training logging: the TrainingLogger callback keeps the per-epoch timing and metric summaries.
- Ensemble learning: the best stage-2 model from each fold is loaded and their predictions are averaged on the test set.
- Visualization and reporting: AUC and precision/recall curves are added, the confusion matrix axes are labelled, and the per-fold precision, recall, and AUC statistics are written to the results file.
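As a side note on the mixed-precision point above, here is a small sketch (an assumption on my part, not taken from the script) of how to verify that the policy is active and how loss scaling would be handled if you moved to a custom training loop; with model.fit() under mixed_float16, Keras applies dynamic loss scaling automatically:

```python
import tensorflow as tf
from tensorflow.keras import mixed_precision

# Verify the active policy: compute dtype should be float16, variables float32.
policy = mixed_precision.global_policy()
print("compute:", policy.compute_dtype, "variables:", policy.variable_dtype)

# Only needed for custom training loops: wrap the optimizer so the loss is
# scaled before backprop, avoiding float16 gradient underflow.
optimizer = mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(1e-4))
```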
With these optimizations, the model should handle class imbalance better, see more varied data, and benefit from a stronger architecture and training strategy. These changes are expected to improve accuracy and overall performance on both the validation and test sets.
I hope these optimizations help you improve the model's accuracy. If you have any further questions, feel free to ask!