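I have a TensorFlow model that maps 6 input variables to 2 output variables with high accuracy. What I want to do now is provide a way to generate inputs when given the outputs.
My idea is to generate a noise vector the size of the input and feed it into the model I used for training. Instead of optimizing the weights and biases, I keep them fixed and optimize the noise vector.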
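To make that idea concrete, here is a minimal pure-Python sketch of optimizing the input of a frozen network. It is scaled down and hypothetical: `W1`, `b1`, `W2`, `b2`, `forward`, `loss_and_grad` and `invert` are invented stand-ins (2 inputs, 3 sigmoid hidden units, 1 linear output) rather than the 6-input, 2-output model, and the gradient is written out by hand instead of coming from TensorFlow. As in `get_inputs`, the loss is the mean squared error, and only the input vector is ever updated.

```python
import math
import random


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


# Hypothetical frozen parameters (stand-ins for self.weights / self.biases).
# They are read but never modified below.
W1 = [[0.5, -1.0, 0.8], [1.2, 0.3, -0.7]]  # [num_input][num_hidden]
b1 = [0.1, -0.2, 0.05]
W2 = [[1.5], [-0.8], [2.0]]                # [num_hidden][num_output]
b2 = [0.3]


def forward(x):
    """Sigmoid hidden layer followed by a linear output layer."""
    h = [sigmoid(sum(x[i] * W1[i][k] for i in range(len(x))) + b1[k])
         for k in range(len(b1))]
    y = [sum(h[k] * W2[k][j] for k in range(len(h))) + b2[j]
         for j in range(len(b2))]
    return h, y


def loss_and_grad(x, target):
    """MSE between forward(x) and target, plus its gradient w.r.t. x only."""
    h, y = forward(x)
    m = len(y)
    err = [y[j] - target[j] for j in range(m)]
    loss = sum(e * e for e in err) / m
    dy = [2.0 * e / m for e in err]                       # dL/dy_j
    dh = [sum(dy[j] * W2[k][j] for j in range(m))          # dL/dh_k
          for k in range(len(h))]
    dpre = [dh[k] * h[k] * (1.0 - h[k])                    # through the sigmoid
            for k in range(len(h))]
    dx = [sum(dpre[k] * W1[i][k] for k in range(len(h)))   # dL/dx_i
          for i in range(len(x))]
    return loss, dx


def invert(target, lr=0.05, steps=20000, tol=1e-4, seed=0):
    """Gradient descent on the input vector; the weights stay fixed."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(len(W1))]  # random start, like tf.random_normal
    for _ in range(steps):
        loss, dx = loss_and_grad(x, target)
        if loss < tol:
            break
        x = [xi - lr * gi for xi, gi in zip(x, dx)]  # only x is updated
    loss, _ = loss_and_grad(x, target)  # loss at the returned x
    return x, loss


if __name__ == "__main__":
    target = forward([0.4, -0.6])[1]  # an output the net can actually produce
    x, final_loss = invert(target)
    print("inputs:", x, "loss:", final_loss)
```

Checking that `W1` is unchanged after a call is an easy way to confirm that only `x` moved; in the TensorFlow version this is the job of `var_list=[self.x_noise]`.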
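I tried implementing this, but all it does is produce values close to whatever the initial noise vector maps to (i.e., if the first input maps to [10.5, 23.4], every generated value stays around that). My cost function is the MSE between the generated output and the desired value passed in.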
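Written out, the procedure described above is plain gradient descent on the input. Here $f_\theta$ is the restored network (`generator`) with its weights and biases $\theta$ held fixed, $\mathbf{x}$ is `x_noise`, $\mathbf{y}^{*}$ is `desired_values`, $m = 2$ is the number of outputs (so the average matches `tf.reduce_mean`), and $\eta = 0.05$ is the learning rate used in `get_inputs`:

$$
L(\mathbf{x}) = \frac{1}{m}\sum_{j=1}^{m}\bigl(y^{*}_{j} - f_{\theta}(\mathbf{x})_{j}\bigr)^{2},
\qquad
\nabla_{\mathbf{x}} L = -\frac{2}{m}\, J_{f_\theta}(\mathbf{x})^{\top}\bigl(\mathbf{y}^{*} - f_{\theta}(\mathbf{x})\bigr),
\qquad
\mathbf{x} \leftarrow \mathbf{x} - \eta\, \nabla_{\mathbf{x}} L
$$

$J_{f_\theta}(\mathbf{x})$ is the $m \times 6$ Jacobian of the network output with respect to its input. For this architecture it contains a factor $\sigma(z)(1-\sigma(z)) \le 1/4$ from each of the three sigmoid hidden layers, and $\theta$ never appears in an update.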
My code is below:
```python
from itertools import compress

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (sessions, tf.train, tf.variable_scope)


def get_inputs(self, output_vals):
    self.saver.restore(self.sess, "./model/model.ckpt")
    print("Model Restored")
    # normalize the values to work with the net
    desired_values = self.data_ob.normalize_one(output_vals, 1)
    # self.x_noise = tf.Variable(tf.constant(np.float32(self.data_ob.test_x[24].reshape(1, 6))))
    # the quantity being optimized: a random input vector of shape [1, num_input]
    self.x_noise = tf.Variable(tf.random_normal([1, self.num_input]))
    self.de_ = self.generator(self.x_noise)
    self.de_loss = tf.reduce_mean(tf.square(desired_values - self.de_))
    # var_list restricts updates to x_noise; the restored weights and biases stay fixed
    self.d_optim = tf.train.GradientDescentOptimizer(learning_rate=.05).minimize(
        self.de_loss, var_list=[self.x_noise])
    # initialize only the new variable(s), so the restored weights are not overwritten
    self.initialize_uninitialized_vars(self.sess)
    cost = 1
    itern = 0
    while cost > .0001:
        _, cost, y_ = self.sess.run([self.d_optim, self.de_loss, self.de_])
        itern += 1
        if itern % 1000 == 0:
            # denormalizes output values so I can more clearly see what is being generated
            desired = self.data_ob.denorm_one(y_, 1)
            print("Cost: ", cost, " Current Prediction: ", desired,
                  " Using inputs: ", self.sess.run(self.x_noise))


def generator(self, inputs, reuse=False):
    with tf.variable_scope("generator") as scope:
        if reuse:
            scope.reuse_variables()
        # three sigmoid hidden layers followed by a linear output layer
        hidden_layer1 = tf.nn.sigmoid(tf.add(tf.matmul(inputs, self.weights['dh1']), self.biases['dh1']))
        hidden_layer2 = tf.nn.sigmoid(tf.add(tf.matmul(hidden_layer1, self.weights['dh2']), self.biases['dh2']))
        hidden_layer3 = tf.nn.sigmoid(tf.add(tf.matmul(hidden_layer2, self.weights['dh3']), self.biases['dh3']))
        out_layer = tf.add(tf.matmul(hidden_layer3, self.weights['dout']), self.biases['dout'])
        return out_layer


def initialize_uninitialized_vars(self, sess):
    global_vars = tf.global_variables()
    # one boolean per variable: True if it has not been initialized yet
    not_initialized = sess.run(
        [tf.logical_not(tf.is_variable_initialized(var)) for var in global_vars])
    not_initialized_vars = list(compress(global_vars, not_initialized))
    if not_initialized_vars:
        sess.run(tf.variables_initializer(not_initialized_vars))
```
Example output:
Cost: 0.15683977 Current Prediction: [[8.24184 6.2319293]] Using inputs: [[-1.3438836 -0.29747388 0.15460747 0.5450186 -2.0147917 -0.1637771 ]]
Cost: 0.15683973 Current Prediction: [[8.241841 6.2319293]] Using inputs: [[-1.3440028 -0.29753348 0.15459257 0.54543585 -2.0145533 -0.16362809]]
Cost: 0.15683973 Current Prediction: [[8.241841 6.2319293]] Using inputs: [[-1.344122 -0.2975931 0.15457767 0.5458531 -2.014315 -0.16347907]]
Cost: 0.15683973 Current Prediction: [[8.241841 6.2319293]] Using inputs: [[-1.3442413 -0.2976527 0.15456277 0.5462703 -2.0140765 -0.16333006]]
Cost: 0.15683973 Current Prediction: [[8.241841 6.2319293]] Using inputs: [[-1.3443605 -0.2977123 0.15454787 0.54668754 -2.013838 -0.16318105]]
Cost: 0.15683973 Current Prediction: [[8.241841 6.2319293]] Using inputs: [[-1.3444797 -0.2977719 0.15453297 0.5471048 -2.0135996 -0.16303204]]
Cost: 0.15683973 Current Prediction: [[8.241842 6.231929]] Using inputs: [[-1.3445989 -0.2978315 0.15451807 0.547522 -2.0133612 -0.16288303]]
Cost: 0.15683973 Current Prediction: [[8.241842 6.231929]] Using inputs: [[-1.3447181 -0.2978911 0.15450317 0.54793924 -2.0131228 -0.16273402]]
Cost: 0.15683973 Current Prediction: [[8.241842 6.2319293]] Using inputs: [[-1.3448373 -0.2979507 0.15448827 0.5483565 -2.0128844 -0.162585 ]]
Cost: 0.15683973 Current Prediction: [[8.241842 6.2319293]] Using inputs: [[-1.3449565 -0.29801032 0.15447336 0.5487737 -2.012646 -0.162436 ]]
Cost: 0.15683973 Current Prediction: [[8.241842 6.2319293]] Using inputs: [[-1.3450757 -0.29806992 0.15445846 0.54919094 -2.0124075 -0.16228698]]
Cost: 0.15683973 Current Prediction: [[8.241842 6.2319293]] Using inputs: [[-1.3451949 -0.29812953 0.15444356 0.5496082 -2.0121691 -0.16213797]]
Cost: 0.15683973 Current Prediction: [[8.241842 6.2319293]] Using inputs: [[-1.3453141 -0.29818913 0.15442866 0.5500254 -2.0119307 -0.16198896]]
Cost: 0.15683973 Current Prediction: [[8.241842 6.2319293]] Using inputs: [[-1.3454334 -0.29824874 0.15441376 0.55044264 -2.0116923 -0.16183995]]
Cost: 0.1568397 Current Prediction: [[8.241843 6.2319293]] Using inputs: [[-1.3455526 -0.29830834 0.15439886 0.55085987 -2.0114539 -0.16169094]]
Cost: 0.1568397 Current Prediction: [[8.241843 6.23193 ]] Using inputs: [[-1.3456718 -0.29836795 0.15438396 0.5512771 -2.0112154 -0.16154192]]
Cost: 0.1568397 Current Prediction: [[8.241843 6.23193 ]] Using inputs: [[-1.345791 -0.29842755 0.15436906 0.55169433 -2.010977 -0.16139291]]
Cost: 0.1568397 Current Prediction: [[8.241843 6.23193 ]] Using inputs: [[-1.3459102 -0.29848716 0.15435416 0.55211157 -2.0107386 -0.1612439 ]]
Cost: 0.1568397 Current Prediction: [[8.241843 6.23193 ]] Using inputs: [[-1.3460294 -0.29854676 0.15433925 0.5525288 -2.0105002 -0.16109489]]
Cost: 0.15683967 Current Prediction: [[8.241844 6.23193 ]] Using inputs: [[-1.3461486 -0.29860637 0.15432435 0.55294603 -2.0102618 -0.16094588]]
Cost: 0.15683967 Current Prediction: [[8.241844 6.23193 ]] Using inputs: [[-1.3462678 -0.29866597 0.15430945 0.55336326 -2.0100234 -0.16079687]]
Cost: 0.15683967 Current Prediction: [[8.241844 6.23193 ]] Using inputs: [[-1.346387 -0.29872558 0.15429455 0.5537805 -2.009785 -0.16064785]]
Cost: 0.15683967 Current Prediction: [[8.241844 6.23193 ]] Using inputs: [[-1.3465062 -0.29878518 0.15427965 0.5541977 -2.0095465 -0.16049884]]
Cost: 0.15683967 Current Prediction: [[8.241844 6.23193 ]] Using inputs: [[-1.3466254 -0.29884478 0.15426475 0.55461496 -2.009308 -0.16034983]]
Cost: 0.15683967 Current Prediction: [[8.241844 6.23193 ]] Using inputs: [[-1.3467447 -0.2989044 0.15424985 0.5550322 -2.0090697 -0.16020082]]
Cost: 0.15683964 Current Prediction: [[8.241845 6.23193 ]] Using inputs: [[-1.3468639 -0.298964 0.15423495 0.5554494 -2.0088313 -0.16005181]]
Cost: 0.15683961 Current Prediction: [[8.241846 6.23193 ]] Using inputs: [[-1.3469831 -0.2990236 0.15422004 0.55586666 -2.0085928 -0.1599028 ]]
Cost: 0.15683961 Current Prediction: [[8.241846 6.23193 ]] Using inputs: [[-1.3471023 -0.2990832 0.15420514 0.5562839 -2.0083544 -0.15975378]]
Cost: 0.15683961 Current Prediction: [[8.241846 6.23193 ]] Using inputs: [[-1.3472215 -0.2991428 0.15419024 0.5567011 -2.008116 -0.15960477]]
Cost: 0.15683961 Current Prediction: [[8.241846 6.23193 ]] Using inputs: [[-1.3473407 -0.2992024 0.15417534 0.55711836 -2.0078776 -0.15945576]]
Cost: 0.15683961 Current Prediction: [[8.241847 6.23193 ]] Using inputs: [[-1.3474599 -0.29926202 0.15416044 0.5575356 -2.0076392 -0.15930675]]
Cost: 0.15683961 Current Prediction: [[8.241847 6.23193 ]] Using inputs: [[-1.3475791 -0.29932162 0.15414554 0.5579528 -2.0074008 -0.15915774]]
Cost: 0.15683961 Current Prediction: [[8.241847 6.23193 ]] Using inputs: [[-1.3476983 -0.29938123 0.15413064 0.55837005 -2.0071623 -0.15900873]]
Cost: 0.15683961 Current Prediction: [[8.241847 6.23193 ]] Using inputs: [[-1.3478175 -0.29944083 0.15411574 0.5587873 -2.006924 -0.15885971]]
Cost: 0.15683961 Current Prediction: [[8.241847 6.23193 ]] Using inputs: [[-1.3479367 -0.29950044 0.15410084 0.5592045 -2.0066855 -0.1587107 ]]
Cost: 0.15683953 Current Prediction: [[8.241849 6.23193 ]] Using inputs: [[-1.348056 -0.29956004 0.15408593 0.55962175 -2.006447 -0.15856169]]
Cost: 0.15683953 Current Prediction: [[8.241849 6.23193 ]] Using inputs: [[-1.3481752 -0.29961964 0.15407103 0.560039 -2.0062087 -0.15841268]]
Cost: 0.15683953 Current Prediction: [[8.241849 6.23193 ]] Using inputs: [[-1.3482944 -0.29967925 0.15405613 0.5604562 -2.0059702 -0.15826367]]
Cost: 0.15683953 Current Prediction: [[8.241849 6.2319293]] Using inputs: [[-1.3484136 -0.29973885 0.15404123 0.56087345 -2.0057318 -0.15811466]]
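Another thing to note: because the inputs are normalized to between 0 and 1, I would also like to constrain the generated values to that range, but I don't know how to do it.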
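One common way to keep generated inputs inside [0, 1] (a suggestion, not something from the code above) is reparameterization: optimize an unconstrained vector z and feed x = sigmoid(z) to the network, so every candidate input lies in (0, 1) by construction. In the TF1 graph above, that would mean letting the `tf.Variable` hold z and passing `tf.sigmoid(z)` to `self.generator`. The pure-Python sketch below shows the extra chain-rule factor this introduces; `WEIGHTS`, `BIAS`, `model` and `invert_bounded` are hypothetical, and the frozen "network" is a simple linear function rather than the trained model.

```python
import math
import random


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


# Hypothetical frozen linear model standing in for the trained network.
WEIGHTS = [0.9, -0.4, 0.6]
BIAS = 0.1


def model(x):
    return sum(w * xi for w, xi in zip(WEIGHTS, x)) + BIAS


def invert_bounded(target, lr=0.5, steps=5000, seed=0):
    """Search for x in (0, 1)^n with model(x) close to target.

    The optimizer works on an unconstrained z; the input fed to the
    model is always x = sigmoid(z), so it can never leave (0, 1).
    """
    rng = random.Random(seed)
    z = [rng.gauss(0.0, 1.0) for _ in WEIGHTS]
    for _ in range(steps):
        x = [sigmoid(zi) for zi in z]
        err = model(x) - target
        # dL/dx_i = 2 * err * w_i and dx_i/dz_i = x_i * (1 - x_i)
        z = [zi - lr * 2.0 * err * w * xi * (1.0 - xi)
             for zi, w, xi in zip(z, WEIGHTS, x)]
    x = [sigmoid(zi) for zi in z]
    return x, (model(x) - target) ** 2


if __name__ == "__main__":
    x, loss = invert_bounded(0.8)
    print("inputs:", x, "loss:", loss)
```

An alternative is projected gradient descent: take the normal gradient step on x and then clip it back into [0, 1] (for example with `tf.clip_by_value`). Reparameterization keeps the update smooth, while clipping can leave coordinates stuck exactly on the boundary.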
I'm scratching my head as to why this is happening, so any help would be greatly appreciated.