I want to pickle the History object that Keras returns from `fit` (running on TensorFlow), but I get an error.
import gzip
import numpy as np
import os
import pickle
import tensorflow as tf
from tensorflow import keras
with gzip.open('mnist.pkl.gz', 'rb') as f:
    train_set, test_set = pickle.load(f, encoding='latin1')
X_train = np.asarray(train_set[0])
y_train = np.asarray(train_set[1])
X_test = np.asarray(test_set[0])
y_test = np.asarray(test_set[1])
X_valid, X_train = X_train[:5000]/255.0, X_train[5000:]/255.0
y_valid, y_train = y_train[:5000], y_train[5000:]
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot']
model = keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=[28,28]))
model.add(keras.layers.Dense(300, activation='relu'))
model.add(keras.layers.Dense(100, activation='relu'))
model.add(keras.layers.Dense(10, activation='softmax'))
model.summary()
model.compile(loss='sparse_categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])
history = model.fit(X_train, y_train, epochs=1,
                    validation_data=(X_valid, y_valid))
if not os.path.isdir('models'):
    os.mkdir('models')
model.save('models/basic.h5')
with open('models/basic_history.pickle', 'wb') as f:
    pickle.dump(history, f)
It gives me the following error:
Traceback (most recent call last):
  File "main.py", line 69, in <module>
    pickle.dump(history, f)
TypeError: can't pickle _thread._local objects
PS: To run the code, download the fashion_mnist data: https://s3.amazonaws.com/img-datasets/mnist.pkl.g
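The error can be reproduced without TensorFlow. The sketch below is a hypothetical stand-in, not Keras code: `FakeHistory` and `can_pickle` are illustrative names, and a bare `threading.local()` plays the role of the model that the real History callback keeps a reference to. Pickling walks every attribute of an object, so a single unpicklable attribute is enough to make the whole object fail, while the metrics dict on its own pickles fine.

```python
import pickle
import threading


class FakeHistory:
    """Hypothetical stand-in for keras.callbacks.History, not the real class."""

    def __init__(self):
        # Per-epoch metrics, shaped like the real History.history dict
        # (the values are made up for illustration).
        self.history = {"loss": [0.9, 0.5], "accuracy": [0.7, 0.85]}
        # Stand-in for the attached Keras model: thread-local state
        # reachable from the object is enough to break pickling.
        self.model = threading.local()


def can_pickle(obj):
    """Return True if pickle.dumps(obj) succeeds, False on a pickling error."""
    try:
        pickle.dumps(obj)
    except (TypeError, pickle.PicklingError):
        return False
    return True


h = FakeHistory()
print(can_pickle(h))          # False: pickle recurses into h.model
print(can_pickle(h.history))  # True: a plain dict of lists of floats
```

On Python 3.8 and later the underlying message reads `cannot pickle '_thread._local' object`; older versions print the `can't pickle _thread._local objects` wording seen in the traceback above.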
Answer (score: 0)
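As Karl suggested, the History object itself can't be pickled: besides the metrics, it keeps a reference to the model, and the model holds TensorFlow state (such as the `_thread._local` object named in the traceback) that pickle can't serialize. Its `history` attribute, however, is a plain dict mapping each metric name to a list of per-epoch values, and that can be pickled: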
with open('models/basic_history.pickle', 'wb') as f:
    pickle.dump(history.history, f)
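To read the history back later (for example, to plot learning curves without retraining), open the file in binary read mode and call `pickle.load`. This is a minimal round-trip sketch under stated assumptions: `save_history` and `load_history` are hypothetical helper names, a temporary directory replaces `models/`, and `history_dict` only imitates the shape of `history.history` with made-up numbers. In real code you would pass `history.history` instead.

```python
import pickle
import tempfile
from pathlib import Path

# Placeholder values shaped like history.history (metric name -> one value
# per epoch); these are illustrative numbers, not real training results.
history_dict = {
    "loss": [0.72, 0.41],
    "accuracy": [0.78, 0.88],
    "val_loss": [0.45, 0.38],
    "val_accuracy": [0.86, 0.89],
}


def save_history(hist, path):
    """Pickle a plain history dict to `path` (hypothetical helper)."""
    with open(path, "wb") as f:
        pickle.dump(hist, f)


def load_history(path):
    """Load a history dict previously written by save_history."""
    with open(path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "basic_history.pickle"
    save_history(history_dict, path)
    restored = load_history(path)

print(restored == history_dict)  # True

# Zero-based index of the epoch with the lowest validation loss.
best_epoch_index = restored["val_loss"].index(min(restored["val_loss"]))
print(best_epoch_index)  # 1, i.e. the second epoch
```

Because only built-in types are stored, the file can be loaded later without TensorFlow or the model being available.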