Tensorflow:训练后尝试生成给定输出向量的输入

时间:2018-08-08 23:08:45

标签: python tensorflow neural-network gradient-descent

我有一个tensorflow模型,可以将6个输入变量映射到2个输出变量,并且精度很高。我现在想做的是提供一种在给定输出时生成输入的方法。

我的想法是生成一个输入大小的噪声向量,并将其输入到用于训练的模型中,但是我没有对权重和偏差进行优化,而是将其保持不变并针对噪声进行优化向量。

我尝试实现这一点,但是它所做的只是生成重新创建类似于第一个输入噪声向量转换为的值的值(即,第一个输入转换为[10.5,23.4]的所有值都与此值有关)。我的成本函数是生成的输出和传递的期望值之间的MSE。

我的代码在下面。

def get_inputs(self, output_vals):

    self.saver.restore(self.sess, "./model/model.ckpt")
    print("Model Restored")

    # normalize the values to work with the net
    desired_values = self.data_ob.normalize_one(output_vals, 1)
    # self.x_noise = tf.Variable(tf.constant(np.float32(self.data_ob.test_x[24].reshape(1, 6))))
    self.x_noise = tf.Variable(tf.random_normal([1, self.num_input]))
    self.de_ = self.generator(self.x_noise)
    self.de_loss = tf.reduce_mean(tf.square(desired_values-self.de_))
    self.d_optim = tf.train.GradientDescentOptimizer(learning_rate=.05).minimize(self.de_loss, var_list=[self.x_noise])
    # initialize appropriate values
    self.initialize_uninitialized_vars(self.sess)
    cost = 1; itern = 0
    while (cost > .0001):
        _, cost, y_ = self.sess.run([self.d_optim, self.de_loss, self.de_])
        itern += 1
        if itern % 1000 == 0:
            # denormalizes output values so I can more clearly see what is being generated
            desired = self.data_ob.denorm_one(y_, 1);
            print("Cost: ", cost, " Current Prediction: ", desired, " Using inputs: ", self.sess.run(self.x_noise))


def generator(self, inputs, reuse=False):
    with tf.variable_scope("generator") as scope:
        if reuse:scope.reuse_variables()

        hidden_layer1 = tf.nn.sigmoid(tf.add(tf.matmul(inputs, self.weights['dh1']), self.biases['dh1']))
        hidden_layer2 = tf.nn.sigmoid(tf.add(tf.matmul(hidden_layer1, self.weights['dh2']), self.biases['dh2']))
        hidden_layer3 = tf.nn.sigmoid(tf.add(tf.matmul(hidden_layer2, self.weights['dh3']), self.biases['dh3']))
        out_layer = tf.add(tf.matmul(hidden_layer3, self.weights['dout']), self.biases['dout'])
        return out_layer

def initialize_uninitialized_vars(self, sess):
    from itertools import compress
    global_vars = tf.global_variables()
    not_initialized = sess.run([~(tf.is_variable_initialized(var)) for var in global_vars])
    not_initialized_vars = list(compress(global_vars, not_initialized))
    if (len(not_initialized_vars)):
        sess.run(tf.variables_initializer(not_initialized_vars))

输出示例:

Cost:  0.15683977  Current Prediction:  [[8.24184   6.2319293]]  Using inputs:  [[-1.3438836  -0.29747388  0.15460747  0.5450186  -2.0147917  -0.1637771 ]]
Cost:  0.15683973  Current Prediction:  [[8.241841  6.2319293]]  Using inputs:  [[-1.3440028  -0.29753348  0.15459257  0.54543585 -2.0145533  -0.16362809]]
Cost:  0.15683973  Current Prediction:  [[8.241841  6.2319293]]  Using inputs:  [[-1.344122   -0.2975931   0.15457767  0.5458531  -2.014315   -0.16347907]]
Cost:  0.15683973  Current Prediction:  [[8.241841  6.2319293]]  Using inputs:  [[-1.3442413  -0.2976527   0.15456277  0.5462703  -2.0140765  -0.16333006]]
Cost:  0.15683973  Current Prediction:  [[8.241841  6.2319293]]  Using inputs:  [[-1.3443605  -0.2977123   0.15454787  0.54668754 -2.013838   -0.16318105]]
Cost:  0.15683973  Current Prediction:  [[8.241841  6.2319293]]  Using inputs:  [[-1.3444797  -0.2977719   0.15453297  0.5471048  -2.0135996  -0.16303204]]
Cost:  0.15683973  Current Prediction:  [[8.241842 6.231929]]  Using inputs:  [[-1.3445989  -0.2978315   0.15451807  0.547522   -2.0133612  -0.16288303]]
Cost:  0.15683973  Current Prediction:  [[8.241842 6.231929]]  Using inputs:  [[-1.3447181  -0.2978911   0.15450317  0.54793924 -2.0131228  -0.16273402]]
Cost:  0.15683973  Current Prediction:  [[8.241842  6.2319293]]  Using inputs:  [[-1.3448373  -0.2979507   0.15448827  0.5483565  -2.0128844  -0.162585  ]]
Cost:  0.15683973  Current Prediction:  [[8.241842  6.2319293]]  Using inputs:  [[-1.3449565  -0.29801032  0.15447336  0.5487737  -2.012646   -0.162436  ]]
Cost:  0.15683973  Current Prediction:  [[8.241842  6.2319293]]  Using inputs:  [[-1.3450757  -0.29806992  0.15445846  0.54919094 -2.0124075  -0.16228698]]
Cost:  0.15683973  Current Prediction:  [[8.241842  6.2319293]]  Using inputs:  [[-1.3451949  -0.29812953  0.15444356  0.5496082  -2.0121691  -0.16213797]]
Cost:  0.15683973  Current Prediction:  [[8.241842  6.2319293]]  Using inputs:  [[-1.3453141  -0.29818913  0.15442866  0.5500254  -2.0119307  -0.16198896]]
Cost:  0.15683973  Current Prediction:  [[8.241842  6.2319293]]  Using inputs:  [[-1.3454334  -0.29824874  0.15441376  0.55044264 -2.0116923  -0.16183995]]
Cost:  0.1568397  Current Prediction:  [[8.241843  6.2319293]]  Using inputs:  [[-1.3455526  -0.29830834  0.15439886  0.55085987 -2.0114539  -0.16169094]]
Cost:  0.1568397  Current Prediction:  [[8.241843 6.23193 ]]  Using inputs:  [[-1.3456718  -0.29836795  0.15438396  0.5512771  -2.0112154  -0.16154192]]
Cost:  0.1568397  Current Prediction:  [[8.241843 6.23193 ]]  Using inputs:  [[-1.345791   -0.29842755  0.15436906  0.55169433 -2.010977   -0.16139291]]
Cost:  0.1568397  Current Prediction:  [[8.241843 6.23193 ]]  Using inputs:  [[-1.3459102  -0.29848716  0.15435416  0.55211157 -2.0107386  -0.1612439 ]]
Cost:  0.1568397  Current Prediction:  [[8.241843 6.23193 ]]  Using inputs:  [[-1.3460294  -0.29854676  0.15433925  0.5525288  -2.0105002  -0.16109489]]
Cost:  0.15683967  Current Prediction:  [[8.241844 6.23193 ]]  Using inputs:  [[-1.3461486  -0.29860637  0.15432435  0.55294603 -2.0102618  -0.16094588]]
Cost:  0.15683967  Current Prediction:  [[8.241844 6.23193 ]]  Using inputs:  [[-1.3462678  -0.29866597  0.15430945  0.55336326 -2.0100234  -0.16079687]]
Cost:  0.15683967  Current Prediction:  [[8.241844 6.23193 ]]  Using inputs:  [[-1.346387   -0.29872558  0.15429455  0.5537805  -2.009785   -0.16064785]]
Cost:  0.15683967  Current Prediction:  [[8.241844 6.23193 ]]  Using inputs:  [[-1.3465062  -0.29878518  0.15427965  0.5541977  -2.0095465  -0.16049884]]
Cost:  0.15683967  Current Prediction:  [[8.241844 6.23193 ]]  Using inputs:  [[-1.3466254  -0.29884478  0.15426475  0.55461496 -2.009308   -0.16034983]]
Cost:  0.15683967  Current Prediction:  [[8.241844 6.23193 ]]  Using inputs:  [[-1.3467447  -0.2989044   0.15424985  0.5550322  -2.0090697  -0.16020082]]
Cost:  0.15683964  Current Prediction:  [[8.241845 6.23193 ]]  Using inputs:  [[-1.3468639  -0.298964    0.15423495  0.5554494  -2.0088313  -0.16005181]]
Cost:  0.15683961  Current Prediction:  [[8.241846 6.23193 ]]  Using inputs:  [[-1.3469831  -0.2990236   0.15422004  0.55586666 -2.0085928  -0.1599028 ]]
Cost:  0.15683961  Current Prediction:  [[8.241846 6.23193 ]]  Using inputs:  [[-1.3471023  -0.2990832   0.15420514  0.5562839  -2.0083544  -0.15975378]]
Cost:  0.15683961  Current Prediction:  [[8.241846 6.23193 ]]  Using inputs:  [[-1.3472215  -0.2991428   0.15419024  0.5567011  -2.008116   -0.15960477]]
Cost:  0.15683961  Current Prediction:  [[8.241846 6.23193 ]]  Using inputs:  [[-1.3473407  -0.2992024   0.15417534  0.55711836 -2.0078776  -0.15945576]]
Cost:  0.15683961  Current Prediction:  [[8.241847 6.23193 ]]  Using inputs:  [[-1.3474599  -0.29926202  0.15416044  0.5575356  -2.0076392  -0.15930675]]
Cost:  0.15683961  Current Prediction:  [[8.241847 6.23193 ]]  Using inputs:  [[-1.3475791  -0.29932162  0.15414554  0.5579528  -2.0074008  -0.15915774]]
Cost:  0.15683961  Current Prediction:  [[8.241847 6.23193 ]]  Using inputs:  [[-1.3476983  -0.29938123  0.15413064  0.55837005 -2.0071623  -0.15900873]]
Cost:  0.15683961  Current Prediction:  [[8.241847 6.23193 ]]  Using inputs:  [[-1.3478175  -0.29944083  0.15411574  0.5587873  -2.006924   -0.15885971]]
Cost:  0.15683961  Current Prediction:  [[8.241847 6.23193 ]]  Using inputs:  [[-1.3479367  -0.29950044  0.15410084  0.5592045  -2.0066855  -0.1587107 ]]
Cost:  0.15683953  Current Prediction:  [[8.241849 6.23193 ]]  Using inputs:  [[-1.348056   -0.29956004  0.15408593  0.55962175 -2.006447   -0.15856169]]
Cost:  0.15683953  Current Prediction:  [[8.241849 6.23193 ]]  Using inputs:  [[-1.3481752  -0.29961964  0.15407103  0.560039   -2.0062087  -0.15841268]]
Cost:  0.15683953  Current Prediction:  [[8.241849 6.23193 ]]  Using inputs:  [[-1.3482944  -0.29967925  0.15405613  0.5604562  -2.0059702  -0.15826367]]
Cost:  0.15683953  Current Prediction:  [[8.241849  6.2319293]]  Using inputs:  [[-1.3484136  -0.29973885  0.15404123  0.56087345 -2.0057318  -0.15811466]]

另一个要注意的是,因为输入的值是在0到1之间标准化的,所以我也希望将生成的值也限制在此范围内,而不知道如何执行。

我正在摸索为什么会发生这种情况,因此将不胜感激。

任何帮助将不胜感激

0 个答案:

没有答案