我正在学习第一次机器学习练习。
这是月温的预测系统。
train_t
具有温度,train_x
具有每个数据的权重。
但是我有一个问题,即初始化train_x
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from pprint import pprint
x = tf.placeholder(tf.float32,[None,5])
w = tf.Variable(tf.zeros([5,1]))
y = tf.matmul(x,w)
t = tf.placeholder(tf.float32,[None,1])
loss = tf.reduce_sum(tf.square(y-t))
train_step = tf.train.AdamOptimizer().minimize(loss)
sess = tf.Session()
sess.run(tf.initialize_all_variables())
train_t = np.array([5.2,5.7,8.6,14.9,18.2,20.4,25.5,26.4,22.8,17.5,11.1,6.6]) #montly temperature
train_t = train_t.reshape([12,1])
train_x = np.zeros([12,5])
for row, month in enumerate(range(1,13)):
for col, n in enumerate(range(0,5)):
train_x[row][col] = month**n ## why initialize like this??
i = 0
for _ in range(10000):
i += 1
sess.run(train_step,feed_dict={x:train_x,t:train_t})
if i % 1000 == 0:
loss_val = sess.run(loss,feed_dict={x:train_x,t:train_t})
print('step : %d,Loss: %f' % (i,loss_val))
w_val = sess.run(w)
pprint(w_val)
def predict(x):
result = 0.0
for n in range(0,5):
result += w_val[n][0] * x**n
return result
fig = plt.figure()
subplot = fig.add_subplot(1,1,1)
subplot.set_xlim(1,12)
subplot.scatter(range(1,13),train_t)
linex = np.linspace(1,12,100)
liney = predict(linex)
subplot.plot(linex, liney)
但是我不明白这里
for row, month in enumerate(range(1,13)): #
for col, n in enumerate(range(0,5)): #
train_x[row][col] = month**n ## why initialize like this??
这是什么意思? 我的书中没有对此有任何评论? 为什么train_x在这里初始化?
答案 0 :(得分:1)
事实上,这段代码:
train_t = np.array([5.2,5.7,8.6,14.9,18.2,20.4,25.5,26.4,22.8,17.5,11.1,6.6]) #montly temperature
train_t = train_t.reshape([12,1])
train_x = np.zeros([12,5])
for row, month in enumerate(range(1,13)):
for col, n in enumerate(range(0,5)):
train_x[row][col] = month**n
是否生成了您的数据。它会初始化train_t
和train_x
,这些数据将被注入placeholders
x
和t
train_t
是一个温度的张量
train_x
是每种温度的一种重量。
它们构成了数据集。
答案 1 :(得分:0)
train_x
和train_t
都是包含训练数据的数组。在数组train_t
中,您拥有模型的目标,而train_x
包含模型输入中的功能。
模型的权重(经过培训的权重)为w
(代码中唯一的tf.Variable
),随机初始化。
您正在训练的模型是线性变量range(0, 5)
的4阶(最大month
)多项式,其范围为range(1, 13)
。剪切的代码从线性变量month
开始生成4次多项式的特征。