1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
import os

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras import Model
from tensorflow.keras.layers import (
    Activation,
    BatchNormalization,
    Conv2D,
    Dense,
    Dropout,
    Flatten,
    MaxPool2D,
)
# Disable NumPy's "..." summarisation so that arrays converted with str()
# are printed in full.
np.set_printoptions(threshold=np.inf)

# Load the CIFAR-10 train/test splits and divide the images by 255.0
# (presumably mapping 8-bit pixel intensities into [0, 1]).
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0
class AlexNet8(Model):
    """AlexNet-style CNN with five convolutional and three dense layers.

    Layout:
        * two Conv -> BatchNorm -> ReLU -> MaxPool stages (96, then 256 filters),
        * three 'same'-padded 3x3 convolutions (384, 384, 256 filters) with
          built-in ReLU, followed by one max-pool,
        * Flatten, two 2048-unit ReLU dense layers each followed by 50 %
          dropout, and a 10-unit softmax output layer.
    """

    def __init__(self):
        super().__init__()

        # Stage 1: conv -> batch norm -> ReLU -> overlapping max-pool.
        self.c1 = Conv2D(96, (3, 3))
        self.b1 = BatchNormalization()
        self.a1 = Activation('relu')
        self.p1 = MaxPool2D(pool_size=(3, 3), strides=2)

        # Stage 2: same pattern with more filters.
        self.c2 = Conv2D(256, (3, 3))
        self.b2 = BatchNormalization()
        self.a2 = Activation('relu')
        self.p2 = MaxPool2D(pool_size=(3, 3), strides=2)

        # Stage 3: three stacked convolutions (ReLU fused in) and one pool.
        self.c3 = Conv2D(384, (3, 3), padding='same', activation='relu')
        self.c4 = Conv2D(384, (3, 3), padding='same', activation='relu')
        self.c5 = Conv2D(256, (3, 3), padding='same', activation='relu')
        self.p3 = MaxPool2D(pool_size=(3, 3), strides=2)

        # Classifier head.
        self.flatten = Flatten()
        self.f1 = Dense(units=2048, activation='relu')
        self.d1 = Dropout(rate=0.5)
        self.f2 = Dense(units=2048, activation='relu')
        self.d2 = Dropout(rate=0.5)
        self.f3 = Dense(units=10, activation='softmax')

    def call(self, x):
        """Run the forward pass.

        Args:
            x: Input image batch (presumably CIFAR-10 images shaped
                (batch, 32, 32, 3) -- TODO confirm against the caller).

        Returns:
            The output of the final 10-unit softmax layer.
        """
        # The layers are applied strictly in sequence, so walk them in order.
        pipeline = (
            self.c1, self.b1, self.a1, self.p1,
            self.c2, self.b2, self.a2, self.p2,
            self.c3, self.c4, self.c5, self.p3,
            self.flatten,
            self.f1, self.d1,
            self.f2, self.d2,
            self.f3,
        )
        for layer in pipeline:
            x = layer(x)
        return x
# Build the network and configure training. The model already ends in a
# softmax layer, hence from_logits=False on the loss.
model = AlexNet8()
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['sparse_categorical_accuracy'],
)

# Resume from a previous run when a checkpoint index file is present.
checkpoint_save_path = "./checkpoint/AlexNet8.ckpt"
checkpoint_index_path = checkpoint_save_path + '.index'
if os.path.exists(checkpoint_index_path):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

# Save weights only, and only when the monitored metric improves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

# Train for 5 epochs, validating on the test split after every epoch.
history = model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=5,
    validation_data=(x_test, y_test),
    validation_freq=1,
    callbacks=[cp_callback],
)
model.summary()
# Dump every trainable variable (name, shape, values) to a text file.
# The output directory is created first so that a missing folder cannot
# raise FileNotFoundError after training has already completed, and the
# file is opened in a `with` block so it is closed even if a write fails.
weights_txt_path = './weight_txt/weights.txt'
os.makedirs(os.path.dirname(weights_txt_path), exist_ok=True)
with open(weights_txt_path, 'w') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        # str() of the array honours the NumPy print options in effect.
        file.write(str(v.numpy()) + '\n')
############################################### show ###############################################
# Plot the accuracy and loss curves for the training and validation sets.
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

# One row, two panels: accuracy on the left, loss on the right.
# Each entry: (subplot index, (train curve, label), (val curve, label), title).
panels = (
    (
        1,
        (acc, 'Training Accuracy'),
        (val_acc, 'Validation Accuracy'),
        'Training and Validation Accuracy',
    ),
    (
        2,
        (loss, 'Training Loss'),
        (val_loss, 'Validation Loss'),
        'Training and Validation Loss',
    ),
)
for panel_index, train_series, val_series, panel_title in panels:
    plt.subplot(1, 2, panel_index)
    plt.plot(train_series[0], label=train_series[1])
    plt.plot(val_series[0], label=val_series[1])
    plt.title(panel_title)
    plt.legend()
plt.show()
|