我正在omniglot
数据集上构建暹罗模型。但是要避免使用Numpy,因为我希望模型尽可能通用。并且可以Git Repository。但是,我得到标题中描述的错误。当我尝试拟合模型时发生错误。下面的代码和下面的完整堆栈跟踪:
这是脚本代码
' Load Omniglot dataset '
ds, ds_info = tfds.load(name='Omniglot', with_info=True, as_supervised=True)
' Split the dataset into testing and training '
ds_train, ds_test = ds["train"], ds["test"]
'Create a CNN encoder'
model = conv_net (ds_info)
optimizer = keras.optimizers.Adam()
' Compile the model with the contrastive loss function '
cont_loss_model = compile_cnn(model, contrastive_loss, optimizer)
' Fit the results '
cont_loss_model.fit(
ds_train,
validation_data=ds_test,
epochs=epochs,
verbose=1
)
这里是使用的功能
@tf.function
def contrastive_loss(label, embedding, margin = 0.4):
'''
contrastive_loss function
@param p: Positive vector
@param n: Negetive vector
@returns (float): y - Integer value representing distance
'''
# Assign the label
y = label
# Assign the embeddings
p1 = embedding[0]
p2 = embedding[1]
# Get the euclean distance
d = tf.norm(p1 - p2, axis=-1)
if y == 0:
return (1/2) * tf.math.sqrt(d)
else:
return (1/2) * tf.math.sqrt(tf.math.maximum(0.0, (margin-d)))
#- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def conv_net (ds_info, batch_size = 128, epochs = 12):
'''
Create a cnn
@param x
@param y
@param ds_info
@optional loss
@optional batch
@optional opochs
'''
# Get the input shape
image_shape = ds_info.features['image'].shape
model = keras.Sequential(
[
Conv2D(32, 3, activation='relu', input_shape=image_shape),
MaxPooling2D(),
Conv2D(32, 3, activation='relu'),
MaxPooling2D(),
Flatten(),
Dense(
128,
activation='relu',
kernel_regularizer=regularizers.l2(0.01),
bias_regularizer=regularizers.l1(0.01)
),
Dense(ds_info.features['label'].num_classes, activation='softmax')
])
model.summary()
return model
这是错误
Input 0 of layer sequential_17 is incompatible with the layer: : expected min_ndim=4, found ndim=3. Full shape received: [105, 105, 3]
完整堆栈跟踪
C:\Users\User\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py:806 train_function *
return step_function(self, iterator)
C:\Users\User\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py:796 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
C:\Users\User\anaconda3\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1211 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
C:\Users\User\anaconda3\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2585 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
C:\Users\User\anaconda3\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2945 _call_for_each_replica
return fn(*args, **kwargs)
C:\Users\User\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py:789 run_step **
outputs = model.train_step(data)
C:\Users\User\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py:747 train_step
y_pred = self(x, training=True)
C:\Users\User\anaconda3\lib\site-packages\tensorflow\python\keras\engine\base_layer.py:975 __call__
input_spec.assert_input_compatibility(self.input_spec, inputs,
C:\Users\User\anaconda3\lib\site-packages\tensorflow\python\keras\engine\input_spec.py:191 assert_input_compatibility
raise ValueError('Input ' + str(input_index) + ' of layer ' +