Keras / Tensorflow中的input_shape参数

时间:2020-06-22 06:50:46

标签: python tensorflow keras neural-network

我使用以下代码在Tensorflow中进行机器学习的教程:

import tensorflow as tf
import numpy as np
from tensorflow import keras
 
model = tf.keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])
model.summary()
 
model.compile(optimizer='sgd', loss='mean_squared_error')
 
xs = np.array([-1.0,  0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)
ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float)
 
model.fit(xs, ys, epochs=50)
 
print(model.predict([10.0]))

工作正常,但我难以理解每次运行时将哪些数据作为输入。输入数据是两个数字数组,model.summary()调用显示模型需要两个输入,但我不知道该输入到底是什么-例如。 -1.0和-3.0是第一个纪元还是两个完整的数组都放入了该层的1个神经元?

1 个答案:

答案 0 :(得分:0)

它的工作方式是您提供数据和输出(在您的情况下为xs和ys)。网络将分批处理它。如果批处理大小为1,则将首先使用xs [0]和ys [0],然后反向传播,然后是下一个。如果批处理大小大于1,则根据批处理的数组将变为数据,例如,您的批处理大小为4,则首先xs [:4]和ys [:4]将通过网络,然后反向传播发生。