
时间:2018-08-08 23:08:45

标签: python tensorflow neural-network gradient-descent





def get_inputs(self, output_vals, learning_rate=.05, tolerance=.0001,
               max_iters=100000, log_every=1000,
               checkpoint_path="./model/model.ckpt"):
    """Search for generator inputs whose output matches ``output_vals``.

    Restores the trained weights, then creates a fresh ``[1, self.num_input]``
    input variable and runs gradient descent on it alone (``var_list``), so
    the network weights stay fixed while the input is optimized.

    Args:
        output_vals: desired outputs, in the form ``self.data_ob.normalize_one``
            accepts (they are normalized before being compared).
        learning_rate: step size for ``GradientDescentOptimizer``.
        tolerance: stop once the mean-squared error is at or below this.
        max_iters: hard cap on optimizer steps; ``None`` removes the cap.
            Without a cap the loop never ends when the cost plateaus above
            ``tolerance`` (unreachable target / vanishing gradients).
        log_every: print progress every this many steps; falsy disables it.
        checkpoint_path: checkpoint to restore with ``self.saver``.

    Returns:
        The value of ``self.x_noise`` (the searched-for inputs) when the
        search stopped, as returned by ``self.sess.run``.
    """
    self.saver.restore(self.sess, checkpoint_path)
    print("Model Restored")

    # normalize the values to work with the net
    desired_values = self.data_ob.normalize_one(output_vals, 1)
    self.x_noise = tf.Variable(tf.random_normal([1, self.num_input]))
    self.de_ = self.generator(self.x_noise)
    self.de_loss = tf.reduce_mean(tf.square(desired_values - self.de_))
    # Only x_noise is optimized; the restored weights are left untouched.
    self.d_optim = tf.train.GradientDescentOptimizer(
        learning_rate=learning_rate).minimize(self.de_loss, var_list=[self.x_noise])
    # x_noise was created after the restore, so it has no value yet; without
    # this the first sess.run fails with an "uninitialized value" error.
    self.sess.run(self.x_noise.initializer)

    cost = float("inf")
    itern = 0
    while cost > tolerance and (max_iters is None or itern < max_iters):
        _, cost, y_ = self.sess.run([self.d_optim, self.de_loss, self.de_])
        itern += 1
        if log_every and itern % log_every == 0:
            # denormalizes output values so I can more clearly see what is being generated
            desired = self.data_ob.denorm_one(y_, 1)
            print("Cost: ", cost, " Current Prediction: ", desired,
                  " Using inputs: ", self.sess.run(self.x_noise))

    if cost > tolerance:
        print("Stopped after", itern, "iterations without reaching cost <=",
              tolerance, "- last cost:", cost)
    return self.sess.run(self.x_noise)

def generator(self, inputs, reuse=False):
    """Feed ``inputs`` through three sigmoid hidden layers and a linear output.

    Each layer computes ``matmul(x, W) + b`` using ``self.weights[key]`` and
    ``self.biases[key]``. The keys ``'dh1'``, ``'dh2'`` and ``'dh3'`` are
    sigmoid-activated, and ``'dout'`` has no activation.

    Args:
        inputs: tensor fed to the first hidden layer.
        reuse: when true, calls ``reuse_variables()`` on the ``"generator"``
            scope. NOTE(review): the layer parameters come from
            ``self.weights``/``self.biases``, not ``tf.get_variable``, so this
            flag appears not to change which variables are used — confirm.

    Returns:
        The output-layer tensor (no activation applied).
    """
    with tf.variable_scope("generator") as scope:
        if reuse:
            scope.reuse_variables()

        activation = inputs
        for layer_key in ('dh1', 'dh2', 'dh3'):
            pre_activation = tf.add(tf.matmul(activation, self.weights[layer_key]),
                                    self.biases[layer_key])
            activation = tf.nn.sigmoid(pre_activation)

        return tf.add(tf.matmul(activation, self.weights['dout']), self.biases['dout'])

def initialize_uninitialized_vars(self, sess):
    """Initialize only the global variables that ``sess`` has not initialized.

    Useful after restoring a checkpoint and then creating new variables: a
    blanket ``global_variables_initializer`` would overwrite the restored
    values, whereas this touches only the variables that still lack a value.

    Args:
        sess: the ``tf.Session`` to query and run the initializer in.
    """
    from itertools import compress
    global_vars = tf.global_variables()
    # Fetch every "is it uninitialized?" flag in a single sess.run round-trip.
    not_initialized = sess.run(
        [tf.logical_not(tf.is_variable_initialized(var)) for var in global_vars])
    not_initialized_vars = list(compress(global_vars, not_initialized))
    if not_initialized_vars:
        sess.run(tf.variables_initializer(not_initialized_vars))


Cost:  0.15683977  Current Prediction:  [[8.24184   6.2319293]]  Using inputs:  [[-1.3438836  -0.29747388  0.15460747  0.5450186  -2.0147917  -0.1637771 ]]
Cost:  0.15683973  Current Prediction:  [[8.241841  6.2319293]]  Using inputs:  [[-1.3440028  -0.29753348  0.15459257  0.54543585 -2.0145533  -0.16362809]]
Cost:  0.15683973  Current Prediction:  [[8.241841  6.2319293]]  Using inputs:  [[-1.344122   -0.2975931   0.15457767  0.5458531  -2.014315   -0.16347907]]
Cost:  0.15683973  Current Prediction:  [[8.241841  6.2319293]]  Using inputs:  [[-1.3442413  -0.2976527   0.15456277  0.5462703  -2.0140765  -0.16333006]]
Cost:  0.15683973  Current Prediction:  [[8.241841  6.2319293]]  Using inputs:  [[-1.3443605  -0.2977123   0.15454787  0.54668754 -2.013838   -0.16318105]]
Cost:  0.15683973  Current Prediction:  [[8.241841  6.2319293]]  Using inputs:  [[-1.3444797  -0.2977719   0.15453297  0.5471048  -2.0135996  -0.16303204]]
Cost:  0.15683973  Current Prediction:  [[8.241842 6.231929]]  Using inputs:  [[-1.3445989  -0.2978315   0.15451807  0.547522   -2.0133612  -0.16288303]]
Cost:  0.15683973  Current Prediction:  [[8.241842 6.231929]]  Using inputs:  [[-1.3447181  -0.2978911   0.15450317  0.54793924 -2.0131228  -0.16273402]]
Cost:  0.15683973  Current Prediction:  [[8.241842  6.2319293]]  Using inputs:  [[-1.3448373  -0.2979507   0.15448827  0.5483565  -2.0128844  -0.162585  ]]
Cost:  0.15683973  Current Prediction:  [[8.241842  6.2319293]]  Using inputs:  [[-1.3449565  -0.29801032  0.15447336  0.5487737  -2.012646   -0.162436  ]]
Cost:  0.15683973  Current Prediction:  [[8.241842  6.2319293]]  Using inputs:  [[-1.3450757  -0.29806992  0.15445846  0.54919094 -2.0124075  -0.16228698]]
Cost:  0.15683973  Current Prediction:  [[8.241842  6.2319293]]  Using inputs:  [[-1.3451949  -0.29812953  0.15444356  0.5496082  -2.0121691  -0.16213797]]
Cost:  0.15683973  Current Prediction:  [[8.241842  6.2319293]]  Using inputs:  [[-1.3453141  -0.29818913  0.15442866  0.5500254  -2.0119307  -0.16198896]]
Cost:  0.15683973  Current Prediction:  [[8.241842  6.2319293]]  Using inputs:  [[-1.3454334  -0.29824874  0.15441376  0.55044264 -2.0116923  -0.16183995]]
Cost:  0.1568397  Current Prediction:  [[8.241843  6.2319293]]  Using inputs:  [[-1.3455526  -0.29830834  0.15439886  0.55085987 -2.0114539  -0.16169094]]
Cost:  0.1568397  Current Prediction:  [[8.241843 6.23193 ]]  Using inputs:  [[-1.3456718  -0.29836795  0.15438396  0.5512771  -2.0112154  -0.16154192]]
Cost:  0.1568397  Current Prediction:  [[8.241843 6.23193 ]]  Using inputs:  [[-1.345791   -0.29842755  0.15436906  0.55169433 -2.010977   -0.16139291]]
Cost:  0.1568397  Current Prediction:  [[8.241843 6.23193 ]]  Using inputs:  [[-1.3459102  -0.29848716  0.15435416  0.55211157 -2.0107386  -0.1612439 ]]
Cost:  0.1568397  Current Prediction:  [[8.241843 6.23193 ]]  Using inputs:  [[-1.3460294  -0.29854676  0.15433925  0.5525288  -2.0105002  -0.16109489]]
Cost:  0.15683967  Current Prediction:  [[8.241844 6.23193 ]]  Using inputs:  [[-1.3461486  -0.29860637  0.15432435  0.55294603 -2.0102618  -0.16094588]]
Cost:  0.15683967  Current Prediction:  [[8.241844 6.23193 ]]  Using inputs:  [[-1.3462678  -0.29866597  0.15430945  0.55336326 -2.0100234  -0.16079687]]
Cost:  0.15683967  Current Prediction:  [[8.241844 6.23193 ]]  Using inputs:  [[-1.346387   -0.29872558  0.15429455  0.5537805  -2.009785   -0.16064785]]
Cost:  0.15683967  Current Prediction:  [[8.241844 6.23193 ]]  Using inputs:  [[-1.3465062  -0.29878518  0.15427965  0.5541977  -2.0095465  -0.16049884]]
Cost:  0.15683967  Current Prediction:  [[8.241844 6.23193 ]]  Using inputs:  [[-1.3466254  -0.29884478  0.15426475  0.55461496 -2.009308   -0.16034983]]
Cost:  0.15683967  Current Prediction:  [[8.241844 6.23193 ]]  Using inputs:  [[-1.3467447  -0.2989044   0.15424985  0.5550322  -2.0090697  -0.16020082]]
Cost:  0.15683964  Current Prediction:  [[8.241845 6.23193 ]]  Using inputs:  [[-1.3468639  -0.298964    0.15423495  0.5554494  -2.0088313  -0.16005181]]
Cost:  0.15683961  Current Prediction:  [[8.241846 6.23193 ]]  Using inputs:  [[-1.3469831  -0.2990236   0.15422004  0.55586666 -2.0085928  -0.1599028 ]]
Cost:  0.15683961  Current Prediction:  [[8.241846 6.23193 ]]  Using inputs:  [[-1.3471023  -0.2990832   0.15420514  0.5562839  -2.0083544  -0.15975378]]
Cost:  0.15683961  Current Prediction:  [[8.241846 6.23193 ]]  Using inputs:  [[-1.3472215  -0.2991428   0.15419024  0.5567011  -2.008116   -0.15960477]]
Cost:  0.15683961  Current Prediction:  [[8.241846 6.23193 ]]  Using inputs:  [[-1.3473407  -0.2992024   0.15417534  0.55711836 -2.0078776  -0.15945576]]
Cost:  0.15683961  Current Prediction:  [[8.241847 6.23193 ]]  Using inputs:  [[-1.3474599  -0.29926202  0.15416044  0.5575356  -2.0076392  -0.15930675]]
Cost:  0.15683961  Current Prediction:  [[8.241847 6.23193 ]]  Using inputs:  [[-1.3475791  -0.29932162  0.15414554  0.5579528  -2.0074008  -0.15915774]]
Cost:  0.15683961  Current Prediction:  [[8.241847 6.23193 ]]  Using inputs:  [[-1.3476983  -0.29938123  0.15413064  0.55837005 -2.0071623  -0.15900873]]
Cost:  0.15683961  Current Prediction:  [[8.241847 6.23193 ]]  Using inputs:  [[-1.3478175  -0.29944083  0.15411574  0.5587873  -2.006924   -0.15885971]]
Cost:  0.15683961  Current Prediction:  [[8.241847 6.23193 ]]  Using inputs:  [[-1.3479367  -0.29950044  0.15410084  0.5592045  -2.0066855  -0.1587107 ]]
Cost:  0.15683953  Current Prediction:  [[8.241849 6.23193 ]]  Using inputs:  [[-1.348056   -0.29956004  0.15408593  0.55962175 -2.006447   -0.15856169]]
Cost:  0.15683953  Current Prediction:  [[8.241849 6.23193 ]]  Using inputs:  [[-1.3481752  -0.29961964  0.15407103  0.560039   -2.0062087  -0.15841268]]
Cost:  0.15683953  Current Prediction:  [[8.241849 6.23193 ]]  Using inputs:  [[-1.3482944  -0.29967925  0.15405613  0.5604562  -2.0059702  -0.15826367]]
Cost:  0.15683953  Current Prediction:  [[8.241849  6.2319293]]  Using inputs:  [[-1.3484136  -0.29973885  0.15404123  0.56087345 -2.0057318  -0.15811466]]




0 个答案:
