我正在尝试学习Tensorflow。我已经整理了一个简单模型的玩具演示。我注意到,如果我采用最终预测并将其与实际值进行比较,我无法达到与Tensorflow相同的损失值。
玩具模型是时间序列回归。试图预测时间序列中的下一个值。这个模型没有任何花哨的东西,我不是在寻找模型构建的技巧。我只是想了解Tensorflow正在做什么,这段代码正在计算一个我无法复制的损失值。
我在R中使用tensorflow包执行此操作,该包是对python库的直接绑定。所以代码很容易移植。
我使用的时间序列值是:
prices = c(104.285380,101.347495,101.357033,102.778283,106.727258,106.841721,104.209071,105.134314,104.733694,101.891194,
101.099492,103.703526,104.495229,107.213726,107.766964,107.881427,104.104147,109.989455,113.413808,111.754094,
113.156266,113.175344,114.043355,114.405821,113.886963,114.643464,116.845936,119.584662,121.097665,121.691375,
122.409572,123.257045,123.003282,124.003971,127.360347,126.565541,123.328865,124.884959,123.012858,123.616144,
123.874695,123.089466,121.049785,121.231728,121.748831,119.230352,117.056607,119.172896,118.349363,119.651694,
121.653071,123.022434,122.088777,120.561411,121.815862,121.317912,118.148267,118.971800,118.023780,121.011481,
119.153744,118.981376,120.006005,121.949926,120.666746,120.274132,121.193425,121.710527,121.471128,120.944449,
121.404096,120.819962,119.460175,122.189325,121.528583,123.166074,124.171550,124.755684,127.025188,125.023811,
123.185225,119.843213,123.482080,123.242681,120.465651,119.709150,119.948549,122.715809,121.465765,121.028250,
121.167678,123.994700,123.821617,125.187049,125.071660,125.062044,126.340935,127.446743,124.638953,126.970765,
126.715948,125.273590,125.518791,124.965887,125.119739,124.388944,123.706228,122.888892,122.523495,123.927390,
123.648534,122.283102,122.042709,122.696577,122.408106,122.965818,121.735006,122.706193,122.148482,123.186979,
122.600420,121.879241,119.744552,120.605159,121.735006,121.581154,121.158062,120.859975,117.859871,115.455941,
118.542587,120.831128,120.783049,121.946551,123.571608,124.638953,126.994804,125.725529,120.408036,120.350342,
119.715705,118.052185,118.638744,118.263731,117.667556,116.638674,113.888579,110.234605,110.965400,110.705776,
111.582500,115.639343,109.621692,111.312044,111.225111,112.007502,113.166600,112.529097,111.089883,108.810324,
102.155170,99.605154,100.204021,105.951215,109.071121,109.428509,108.916574,104.048363,108.510890,106.608038,
105.545531,108.481913,106.395536,108.733051,110.317151,111.379658,112.316595,112.442164,110.037036,109.583056,
111.283066,109.534760,110.423402,111.080224,110.800109,108.607482,105.342689,106.202353,105.844965,106.617697,
107.004063,107.515998,107.004063,105.767692,108.298389,107.796113,107.979637,106.453491,108.047251,107.255201,
107.921682,109.892149,109.882489,111.563182,115.021157,111.350680,110.645562,115.204681,116.421734,115.426842,
117.049579,118.392201,117.841629,116.798441,117.436526,116.961193,113.274931,112.634686,112.256359,108.977526,
110.757603,110.287119,113.779367,115.224769,115.729205,114.225599,115.321776,114.497218,114.283803,114.759136,
113.827870,112.799597,111.751923,115.467287,114.739735,114.691232,112.159352,112.692890,109.792384,109.113336,
107.182899,108.007458,105.718095,102.856393,104.117482,104.020475,105.359170,104.796530,103.622747,105.485279,
104.107781,102.109440,102.196746,99.635764,97.685926,93.563134,94.057869,95.580877,96.968075,94.474998,96.541245,
94.222780,93.766848,93.892957,93.417623,98.384375,96.463639,96.997177,90.623825,91.273771,94.426495,93.543732,
91.652098,93.466127,93.708644,91.696830,92.662368,92.642862,91.940652,91.384737,91.667571,94.252091,95.695522,
93.881481,93.666917,94.486161,92.350274,93.725434,94.369126,94.515420,94.300856,98.045972,98.260536,98.992004,
100.464693,99.352862,98.533617,98.621394,98.670158,99.733225,99.986801,101.995899,103.351553,103.185754,103.302789,
103.293036,104.083021,103.507600,103.058966,102.590827,105.019300,106.852847,106.296931,107.272222,108.374300,
107.096670,108.218254,105.858050,105.975085,106.326190,107.711103,109.271568,109.330085,107.135681,104.824242,
104.268327,104.482891,103.351553,103.068719,102.483545,101.771582,95.402934,92.486815,91.423748,91.326219,92.828167,
91.862629,90.936103,90.981767,91.050455,91.668644,90.775704,88.646385,88.823011,92.120021,91.737332,92.787273,
92.434021,93.434899,94.622215,96.064657,97.752412,98.527602,98.468727,97.987913,96.614159,95.888032,96.084282,
96.780972,97.173473,97.085160,97.781850,96.977222,95.515156,95.632906,95.318905,95.721219,93.542837,93.317149,
94.111964,93.758713,94.298402,91.649019,90.314515,91.835457,92.630272,93.807775,94.092339,93.209211,93.739088,
94.141401,94.867529,95.161904,95.593656,95.053967,96.937972,96.928160,97.958475,97.997725,98.086038,97.565974,
96.810409,95.515156,94.857716,101.019984,102.383926,102.256363,104.061868,102.521301,103.806742,103.885243,106.032880,
106.910897,107.344972,106.545878,106.476821,106.723455,108.005951,107.907298,107.749452,107.611337,107.887567,
107.049012,107.384434,106.575474,106.121668,105.500150,105.381766,104.572806,104.671460,105.292978,106.279514,
106.249917,106.901031,104.099269,101.741448,104.020346,106.496551,110.265119,114.013955,113.372707,112.050749,
112.040883,112.021153,113.076746,111.192462,111.360173,111.567346,112.415767,110.669598,111.527885,111.005021,
111.478558,111.527885,112.356575,112.524286,114.487492,114.734126,115.760124,115.404971,116.046219,115.967296,
115.888373,115.543086,115.483894,115.030087,116.065950,116.657871,114.033686,112.938631,112.188864,112.011287,
109.988889,110.087542,108.351239,107.931825,109.488725,110.133301,109.954803,106.890586,107.525246,104.827942,
106.216260,109.072229,109.032563,109.141645,110.797711,110.867126,110.301883,110.857210,110.639046,110.529963,
109.597807,108.576401,108.982980,108.199572,109.032563,110.103551,111.184456,112.999187,112.354610,114.228840,
114.228840,114.853583,115.002331,115.666741,115.974154,116.083236,115.319661,115.547742,116.281568,115.785740,
115.755990,114.853583,115.180830,115.051914,115.636991,116.926144,117.997132,118.116131,118.750791,118.254963,
118.046715,118.998705,118.988788,118.780540,118.998705,119.078037,118.968955,120.863018,120.922517,120.932434,
120.615104,120.337440,127.675694,127.457529,128.002940,129.202844,130.432497,130.938241,131.315071,131.581537,
132.746769,134.469718,134.957721,134.793393,135.166865,136.142871,136.551200,135.973564,136.103034,136.371934,
136.431689,139.220278,138.393660,139.210318,138.772112,138.951378,138.433497,138.114801,138.572927,138.632682,
138.423538,139.887547,140.116610,139.419462,140.883471,139.270074,140.843634,140.345672,140.066813,140.305835,
143.213935,143.532630,143.343405,143.074505,143.114342,144.179981,143.433038,143.074505,142.755809,142.586502,
141.052778,141.222086,140.475142,141.251963,140.624531,140.106650,141.859477,141.690170,143.054587,143.950919,
143.065343,143.203975,143.064546,146.002523,146.908814,146.460648,145.932808,148.352905,152.376439,153.332527,
152.635380,153.322568,156.100000,155.700000,155.470000,150.250000,152.540000,152.960000,153.990000,153.800000,
153.340000,153.870000,153.610000,153.670000,152.760000,153.180000,155.450000,153.930000,154.450000,155.370000,
154.990000,148.980000,145.320000,146.590000)
我正在使用的代码是
#split into train and test
v.train = prices[1:500]
v.test = prices[-(1:500)]
sampleSize = 30
getJoinedLaggedData = function(dataVector){
data.len = length(dataVector)
result = NULL
for (i in seq(1, data.len - sampleSize + 1, by = 1)) {
result = rbind(result, dataVector[seq(i, i + sampleSize - 1, by = 1)])
}
list(input=result[,-sampleSize], output = result[,sampleSize])
}
m.train = getJoinedLaggedData(v.train)
#build the model
indata = tf$placeholder(tf$float32, shape(471, sampleSize - 1))
outData = tf$placeholder(tf$float32, shape(471))
w1 <- tf$Variable(tf$truncated_normal(shape(sampleSize - 1, 300),stddev = 1.0 / sqrt(sampleSize)))
b1 = tf$Variable(tf$zeros(shape(300)))
l1 <- tf$matmul(indata, w1) + b1
w2 <- tf$Variable(tf$truncated_normal(shape(300, 50),stddev = 1.0 / sqrt(300)))
b2 <- tf$Variable(tf$zeros(shape(50)))
l2 <- tf$matmul(l1, w2) + b2
w3 <- tf$Variable(tf$truncated_normal(shape(50, 1),stddev = 1.0 / sqrt(50)))
b3 <- tf$Variable(tf$zeros(shape(1)))
pred <- tf$matmul(l2, w3) + b3
#loss function
loss <- tf$reduce_mean(tf$abs(tf$subtract(x = pred, y = outData)))
#trainer
optimizer <- tf$train$GradientDescentOptimizer(0.00003)
train.op <- optimizer$minimize(loss)
#run the model
init = tf$global_variables_initializer()
sess = tf$Session()
sess$run(init)
for (i in 1:1000){
values = sess$run(list(train.op, loss, pred), feed_dict = dict(indata = m.train$input, outData = m.train$output))
loss_value = values[[2]]
pred_value = values[[3]]
print(loss_value)
}
myLoss = mean(abs(pred_value - m.train$output))
print(sprintf("Tensorflow Loss: %f, My Loss: %f", loss_value, myLoss))
当我运行此操作时,Tensorflow会计算出大约11.71的损失,并计算出大约5.03的损失
我尝试过的其他损失函数也是如此。
非常感谢任何建议!
答案 0 :(得分:0)
差异是由于m.train $ output,tensorflow将其读作矢量而不是矩阵。
我已根据需要修改了代码。请检查
prices = c(104.285380,101.347495,101.357033,102.778283,106.727258,106.841721,104.209071,105.134314,104.733694,101.891194,
101.099492,103.703526,104.495229,107.213726,107.766964,107.881427,104.104147,109.989455,113.413808,111.754094,
113.156266,113.175344,114.043355,114.405821,113.886963,114.643464,116.845936,119.584662,121.097665,121.691375,
122.409572,123.257045,123.003282,124.003971,127.360347,126.565541,123.328865,124.884959,123.012858,123.616144,
123.874695,123.089466,121.049785,121.231728,121.748831,119.230352,117.056607,119.172896,118.349363,119.651694,
121.653071,123.022434,122.088777,120.561411,121.815862,121.317912,118.148267,118.971800,118.023780,121.011481,
119.153744,118.981376,120.006005,121.949926,120.666746,120.274132,121.193425,121.710527,121.471128,120.944449,
121.404096,120.819962,119.460175,122.189325,121.528583,123.166074,124.171550,124.755684,127.025188,125.023811,
123.185225,119.843213,123.482080,123.242681,120.465651,119.709150,119.948549,122.715809,121.465765,121.028250,
121.167678,123.994700,123.821617,125.187049,125.071660,125.062044,126.340935,127.446743,124.638953,126.970765,
126.715948,125.273590,125.518791,124.965887,125.119739,124.388944,123.706228,122.888892,122.523495,123.927390,
123.648534,122.283102,122.042709,122.696577,122.408106,122.965818,121.735006,122.706193,122.148482,123.186979,
122.600420,121.879241,119.744552,120.605159,121.735006,121.581154,121.158062,120.859975,117.859871,115.455941,
118.542587,120.831128,120.783049,121.946551,123.571608,124.638953,126.994804,125.725529,120.408036,120.350342,
119.715705,118.052185,118.638744,118.263731,117.667556,116.638674,113.888579,110.234605,110.965400,110.705776,
111.582500,115.639343,109.621692,111.312044,111.225111,112.007502,113.166600,112.529097,111.089883,108.810324,
102.155170,99.605154,100.204021,105.951215,109.071121,109.428509,108.916574,104.048363,108.510890,106.608038,
105.545531,108.481913,106.395536,108.733051,110.317151,111.379658,112.316595,112.442164,110.037036,109.583056,
111.283066,109.534760,110.423402,111.080224,110.800109,108.607482,105.342689,106.202353,105.844965,106.617697,
107.004063,107.515998,107.004063,105.767692,108.298389,107.796113,107.979637,106.453491,108.047251,107.255201,
107.921682,109.892149,109.882489,111.563182,115.021157,111.350680,110.645562,115.204681,116.421734,115.426842,
117.049579,118.392201,117.841629,116.798441,117.436526,116.961193,113.274931,112.634686,112.256359,108.977526,
110.757603,110.287119,113.779367,115.224769,115.729205,114.225599,115.321776,114.497218,114.283803,114.759136,
113.827870,112.799597,111.751923,115.467287,114.739735,114.691232,112.159352,112.692890,109.792384,109.113336,
107.182899,108.007458,105.718095,102.856393,104.117482,104.020475,105.359170,104.796530,103.622747,105.485279,
104.107781,102.109440,102.196746,99.635764,97.685926,93.563134,94.057869,95.580877,96.968075,94.474998,96.541245,
94.222780,93.766848,93.892957,93.417623,98.384375,96.463639,96.997177,90.623825,91.273771,94.426495,93.543732,
91.652098,93.466127,93.708644,91.696830,92.662368,92.642862,91.940652,91.384737,91.667571,94.252091,95.695522,
93.881481,93.666917,94.486161,92.350274,93.725434,94.369126,94.515420,94.300856,98.045972,98.260536,98.992004,
100.464693,99.352862,98.533617,98.621394,98.670158,99.733225,99.986801,101.995899,103.351553,103.185754,103.302789,
103.293036,104.083021,103.507600,103.058966,102.590827,105.019300,106.852847,106.296931,107.272222,108.374300,
107.096670,108.218254,105.858050,105.975085,106.326190,107.711103,109.271568,109.330085,107.135681,104.824242,
104.268327,104.482891,103.351553,103.068719,102.483545,101.771582,95.402934,92.486815,91.423748,91.326219,92.828167,
91.862629,90.936103,90.981767,91.050455,91.668644,90.775704,88.646385,88.823011,92.120021,91.737332,92.787273,
92.434021,93.434899,94.622215,96.064657,97.752412,98.527602,98.468727,97.987913,96.614159,95.888032,96.084282,
96.780972,97.173473,97.085160,97.781850,96.977222,95.515156,95.632906,95.318905,95.721219,93.542837,93.317149,
94.111964,93.758713,94.298402,91.649019,90.314515,91.835457,92.630272,93.807775,94.092339,93.209211,93.739088,
94.141401,94.867529,95.161904,95.593656,95.053967,96.937972,96.928160,97.958475,97.997725,98.086038,97.565974,
96.810409,95.515156,94.857716,101.019984,102.383926,102.256363,104.061868,102.521301,103.806742,103.885243,106.032880,
106.910897,107.344972,106.545878,106.476821,106.723455,108.005951,107.907298,107.749452,107.611337,107.887567,
107.049012,107.384434,106.575474,106.121668,105.500150,105.381766,104.572806,104.671460,105.292978,106.279514,
106.249917,106.901031,104.099269,101.741448,104.020346,106.496551,110.265119,114.013955,113.372707,112.050749,
112.040883,112.021153,113.076746,111.192462,111.360173,111.567346,112.415767,110.669598,111.527885,111.005021,
111.478558,111.527885,112.356575,112.524286,114.487492,114.734126,115.760124,115.404971,116.046219,115.967296,
115.888373,115.543086,115.483894,115.030087,116.065950,116.657871,114.033686,112.938631,112.188864,112.011287,
109.988889,110.087542,108.351239,107.931825,109.488725,110.133301,109.954803,106.890586,107.525246,104.827942,
106.216260,109.072229,109.032563,109.141645,110.797711,110.867126,110.301883,110.857210,110.639046,110.529963,
109.597807,108.576401,108.982980,108.199572,109.032563,110.103551,111.184456,112.999187,112.354610,114.228840,
114.228840,114.853583,115.002331,115.666741,115.974154,116.083236,115.319661,115.547742,116.281568,115.785740,
115.755990,114.853583,115.180830,115.051914,115.636991,116.926144,117.997132,118.116131,118.750791,118.254963,
118.046715,118.998705,118.988788,118.780540,118.998705,119.078037,118.968955,120.863018,120.922517,120.932434,
120.615104,120.337440,127.675694,127.457529,128.002940,129.202844,130.432497,130.938241,131.315071,131.581537,
132.746769,134.469718,134.957721,134.793393,135.166865,136.142871,136.551200,135.973564,136.103034,136.371934,
136.431689,139.220278,138.393660,139.210318,138.772112,138.951378,138.433497,138.114801,138.572927,138.632682,
138.423538,139.887547,140.116610,139.419462,140.883471,139.270074,140.843634,140.345672,140.066813,140.305835,
143.213935,143.532630,143.343405,143.074505,143.114342,144.179981,143.433038,143.074505,142.755809,142.586502,
141.052778,141.222086,140.475142,141.251963,140.624531,140.106650,141.859477,141.690170,143.054587,143.950919,
143.065343,143.203975,143.064546,146.002523,146.908814,146.460648,145.932808,148.352905,152.376439,153.332527,
152.635380,153.322568,156.100000,155.700000,155.470000,150.250000,152.540000,152.960000,153.990000,153.800000,
153.340000,153.870000,153.610000,153.670000,152.760000,153.180000,155.450000,153.930000,154.450000,155.370000,
154.990000,148.980000,145.320000,146.590000)
v.train = prices[1:500]
v.test = prices[-(1:500)]
sampleSize = 30
getJoinedLaggedData = function(dataVector){
data.len = length(dataVector)
result = NULL
for (i in seq(1, data.len - sampleSize + 1, by = 1)) {
result = rbind(result, dataVector[seq(i, i + sampleSize - 1, by = 1)])
}
list(input=matrix(result[,-sampleSize], ncol = sampleSize - 1), output = matrix(result[,sampleSize]))
}
m.train = getJoinedLaggedData(v.train)
#build the model
indata = tf$placeholder(tf$float32, shape(471, sampleSize - 1))
outData = tf$placeholder(tf$float32, shape(471, 1))
w1 <- tf$Variable(tf$truncated_normal(shape(sampleSize - 1, 300),stddev = 1.0 / sqrt(sampleSize)))
b1 = tf$Variable(tf$zeros(shape(300)))
l1 <- tf$matmul(indata, w1) + b1
w2 <- tf$Variable(tf$truncated_normal(shape(300, 50),stddev = 1.0 / sqrt(300)))
b2 <- tf$Variable(tf$zeros(shape(50)))
l2 <- tf$matmul(l1, w2) + b2
w3 <- tf$Variable(tf$truncated_normal(shape(50, 1),stddev = 1.0 / sqrt(50)))
b3 <- tf$Variable(tf$zeros(shape(1)))
pred <- tf$matmul(l2, w3) + b3
#loss function
loss <- tf$reduce_mean(tf$abs(tf$sub(x = pred, y = outData)))
#trainer
optimizer <- tf$train$GradientDescentOptimizer(0.00003)
train.op <- optimizer$minimize(loss)
#run the model
init = tf$initialize_all_variables()
sess = tf$Session()
sess$run(init)
for (i in 1:1000){
values = sess$run(list(train.op, loss, pred), feed_dict = dict(indata = m.train$input, outData = m.train$output))
loss_value = values[[2]]
pred_value = values[[3]]
print(loss_value)
}
myLoss = mean(abs(pred_value - m.train$output))
print(sprintf("Tensorflow Loss: %f, My Loss: %f", loss_value, myLoss))