基于这些代码的架构,我们的 CNN 模型有三个卷积和一个池化层,其后是两个致密层,以及用于正则化的失活。让我们训练我们的模型。
import datetime -
logdir = os.path.join('/home/dipanzan_sarkar/projects/tensorboard_logs', datetime.datetime.now().strftime("%Y%m%d-%H%M%S")) tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1) reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=0.000001) callbacks = [reduce_lr, tensorboard_callback] -
history = model.fit(x=train_imgs_scaled, y=train_labels_enc, batch_size=BATCH_SIZE, epochs=EPOCHS, validation_data=(val_imgs_scaled, val_labels_enc), callbacks=callbacks, verbose=1) -
# Output Train on 17361 samples, validate on 1929 samples Epoch 1/25 17361/17361 [====] - 32s 2ms/sample - loss: 0.4373 - accuracy: 0.7814 - val_loss: 0.1834 - val_accuracy: 0.9393 Epoch 2/25 17361/17361 [====] - 30s 2ms/sample - loss: 0.1725 - accuracy: 0.9434 - val_loss: 0.1567 - val_accuracy: 0.9513 ... ... Epoch 24/25 17361/17361 [====] - 30s 2ms/sample - loss: 0.0036 - accuracy: 0.9993 - val_loss: 0.3693 - val_accuracy: 0.9565 Epoch 25/25 17361/17361 [====] - 30s 2ms/sample - loss: 0.0034 - accuracy: 0.9994 - val_loss: 0.3699 - val_accuracy: 0.9559
我们获得了 95.6% 的验证精确率,这很好,尽管我们的模型看起来有些过拟合(通过查看我们的训练精确度,是 99.9%)。通过绘制训练和验证的精度和损失曲线,我们可以清楚地看到这一点。
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) t = f.suptitle('Basic CNN Performance', fontsize=12) f.subplots_adjust(top=0.85, wspace=0.3) -
max_epoch = len(history.history['accuracy'])+1 epoch_list = list(range(1,max_epoch)) ax1.plot(epoch_list, history.history['accuracy'], label='Train Accuracy') ax1.plot(epoch_list, history.history['val_accuracy'], label='Validation Accuracy') ax1.set_xticks(np.arange(1, max_epoch, 5)) ax1.set_ylabel('Accuracy Value') ax1.set_xlabel('Epoch') ax1.set_title('Accuracy') l1 = ax1.legend(loc="best") -
ax2.plot(epoch_list, history.history['loss'], label='Train Loss') ax2.plot(epoch_list, history.history['val_loss'], label='Validation Loss') ax2.set_xticks(np.arange(1, max_epoch, 5)) ax2.set_ylabel('Loss Value') ax2.set_xlabel('Epoch') ax2.set_title('Loss') l2 = ax2.legend(loc="best")

基础 CNN 学习曲线
(编辑:PHP编程网 - 黄冈站长网)
【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!
|