我正在尝试使用python SDk和tensorflow进行sagemaker上的测试分类。我可以修改此https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/tensorflow_abalone_age_predictor_using_keras/abalone.py并运行它,但是当我更改拱形以包含嵌入图层时,我收到错误
&#34; Fetch参数不能解释为Tensor。 (Tensor Tensor(&#34;第一层/嵌入:0&#34;,shape =(*,),dtype = float32_ref)不是此图的元素。&#34; < / p>
当我将其作为独立模型运行时,它运行完美。 这里是独立模型的拱门
model = Sequential()
model.add(Embedding(len(word_index) + 1,
EMBEDDING_DIM,
weights=[embedding_matrix],
input_length=MAX_SEQUENCE_LENGTH,
trainable=False))
model.add(Conv1D(64, kernel_size=10, padding='same', activation='relu'))
model.add(Conv1D(64, kernel_size=15, padding='same', activation='selu'))
model.add(Conv1D(128, kernel_size=15, padding='same', activation='relu'))
model.add(Conv1D(64, kernel_size=25, padding='same', activation='softmax'))
model.add(Conv1D(128, kernel_size=15, padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(Flatten())
model.add(Dense(2, activation='softmax'))
这是sagemaker的model_fn:
embedding = tf.keras.layers.Embedding(len(word_index) + 1,
EMBEDDING_DIM,
weights=[embedding_matrix],
input_length=MAX_SEQUENCE_LENGTH,
trainable=False, name='first-layer')(features[INPUT_TENSOR_NAME])
first = tf.keras.layers.Conv1D(64, kernel_size=10, padding='same', activation='relu')(embedding)
second = tf.keras.layers.Conv1D(64, kernel_size=15, padding='same', activation='relu')(first)
third = tf.keras.layers.Conv1D(128, kernel_size=15, padding='same', activation='relu')(second)
fourth = tf.keras.layers.Conv1D(64, kernel_size=25, padding='same', activation='softmax')(third)
fifth = tf.keras.layers.Conv1D(128, kernel_size=15, padding='same', activation='relu')(fourth)
sixth = tf.keras.layers.BatchNormalization()(fifth)
output = tf.keras.layers.Flatten()(sixth)
output_layer = tf.keras.layers.Dense(2, activation='softmax'))(output)
输入维度或值没有问题,如果我只用简单的密集层拱形替换这个拱门,代码就可以完美运行。
我已经尝试过解决方案了 TensorFlow: The tensor is not the element of this graph但我收到了新错误
输入图和图层图不一样:Tensor(&#34; random_shuffle_queue_DequeueMany:1&#34;,shape =(128,200),dtype = float32,device = / device:CPU:0)不是来自传入的图表。*