根据keras documentation:
predict_on_batch(self, x)
Returns predictions for a single batch of samples.
但是,在批处理调用时,标准predict
方法似乎没有任何差别,无论它是使用一个还是多个元素。
model.predict_on_batch(np.zeros((n, d_in)))
与
相同model.predict(np.zeros((n, d_in)))
(形状numpy.ndarray
的{{1}})
答案 0 :(得分:13)
区别在于您传递的x
数据大于一个批次。
predict
将遍历所有数据,逐批,预测标签。
因此,它在内部分批进行分批并一次喂食一批。
predict_on_batch
假设您传入的数据恰好是一个批次,因此将其提供给网络。它不会尝试拆分它(根据您的设置,如果阵列非常大,可能会对您的GPU内存造成问题)
答案 1 :(得分:2)
我只想添加一些不适合评论的内容。似乎predict
仔细检查 输出形状:
class ExtractShape(keras.engine.topology.Layer):
def call(self, x):
return keras.backend.sum(x, axis=0)
def compute_output_shape(self, input_shape):
return input_shape
a = keras.layers.Input((None, None))
b = ExtractShape()(a)
m = keras.Model(a, b)
m.compile(optimizer=keras.optimizers.Adam(), loss='binary_crossentropy')
A = np.ones((5,4,3))
然后:
In [163]: m.predict_on_batch(A)
Out[163]:
array([[5., 5., 5.],
[5., 5., 5.],
[5., 5., 5.],
[5., 5., 5.]], dtype=float32)
In [164]: m.predict_on_batch(A).shape
Out[164]: (4, 3)
可是:
In [165]: m.predict(A)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-165-c5ba5fc88b6e> in <module>()
----> 1 m.predict(A)
~/miniconda3/envs/ccia/lib/python3.6/site-packages/keras/engine/training.py in predict(self, x, batch_size, verbose, steps)
1746 f = self.predict_function
1747 return self._predict_loop(f, ins, batch_size=batch_size,
-> 1748 verbose=verbose, steps=steps)
1749
1750 def train_on_batch(self, x, y,
~/miniconda3/envs/ccia/lib/python3.6/site-packages/keras/engine/training.py in _predict_loop(self, f, ins, batch_size, verbose, steps)
1306 outs.append(np.zeros(shape, dtype=batch_out.dtype))
1307 for i, batch_out in enumerate(batch_outs):
-> 1308 outs[i][batch_start:batch_end] = batch_out
1309 if verbose == 1:
1310 progbar.update(batch_end)
ValueError: could not broadcast input array from shape (4,3) into shape (5,3)
我不确定这是不是真的。
答案 2 :(得分:0)
与预测是否在单个批次上执行相比,predict_on_batch似乎要快得多。
总而言之,predict方法具有额外的操作以确保正确处理一批批次,而predict_on_batch是用于预测应在单个批次上使用的轻量级选择。