我的模型使用预处理的数据来预测客户是私人客户还是非私人客户。预处理步骤使用诸如feature_column.bucketized_column(...),feature_column.embedding_column(...)等步骤。
训练后,我试图保存模型,但是出现以下错误:
AssertionError:尝试导出引用未跟踪对象Tensor(“ 14867:0”,shape =(),dtype = resource)的函数。必须捕获由函数捕获的TensorFlow对象(例如tf.Variable)将它们分配给被跟踪对象的属性或直接分配给主对象的属性。
以下是相关代码:
(feature_columns, train_ds, val_ds, test_ds) = preprocessing.getPreProcessedDatasets(args.data, args.zip, args.batchSize)
feature_layer = tf.keras.layers.DenseFeatures(feature_columns, trainable=False)
model = tf.keras.models.Sequential([
feature_layer,
tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)
])
model.compile(optimizer='sgd',
loss='binary_crossentropy',
metrics=['accuracy'])
...
model.fit(train_ds,
validation_data=val_ds,
epochs=args.epoch,
callbacks=[tensorboard_callback])
model.summary()
if args.saveModel:
filepath = "./saved_models/logReg" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-e{}-b{}-a{}".format(args.epoch, args.batchSize, accuracy)
model.save(filepath)