colocate_with deprecation warning

Time: 2019-10-17 09:11:07

Tags: python tensorflow autoencoder

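I'm getting the following warning:

WARNING:tensorflow:From C:\Users\hadas\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.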
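The log line above points at TensorFlow's own op_def_library.py rather than at a line of user code, and the update instruction says colocation is now handled automatically by the placer, so it reads as an internal deprecation notice rather than an error. If the goal is only to hide such messages, one option in TF 1.x is tf.logging.set_verbosity(tf.logging.ERROR). A minimal sketch of the same idea with the standard logging module follows, assuming that TensorFlow 1.x routes its Python-side messages through the logger named "tensorflow"; the helper name silence_tensorflow_warnings is hypothetical.

```python
import logging


def silence_tensorflow_warnings(level=logging.ERROR):
    """Raise the threshold of the "tensorflow" logger so WARNING records are dropped.

    Assumption: TensorFlow 1.x emits deprecation notices through
    logging.getLogger("tensorflow"). Call this before building the graph.
    """
    logger = logging.getLogger("tensorflow")
    logger.setLevel(level)  # records below `level` (e.g. WARNING) are ignored
    return logger


# Hypothetical usage: call once at the top of the script, before creating the model.
tf_logger = silence_tensorflow_warnings()
```

This only hides the message; it does not change what the graph does, and genuine errors are still reported.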

Here is my code. Could you tell me what I should change? Thanks!

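import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)


class Autoencoder:
    def __init__(self, D, d):
        # Batch of D-dimensional inputs; the batch size is left open
        self.X = tf.placeholder(tf.float32, shape=(None, D))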

        # Input to hidden (D -> d)
        self.W1 = tf.Variable(tf.random_normal(shape=(D,d)))
        self.b1 = tf.Variable(np.zeros(d).astype(np.float32))

        # Hidden -> output (d -> D)
        self.W2 = tf.Variable(tf.random_normal(shape=(d,D)))
        self.b2 = tf.Variable(np.zeros(D).astype(np.float32))

        # Output
        self.Z = tf.nn.relu(tf.add(tf.matmul(self.X, self.W1), self.b1))
        logits = tf.add(tf.matmul(self.Z, self.W2), self.b2)
        self.X_hat = tf.nn.sigmoid(logits)

        # Loss: sigmoid cross-entropy between inputs and reconstruction logits, summed over the batch
        self.loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=self.X, logits=logits))

        self.optimizer = tf.train.RMSPropOptimizer(learning_rate=0.005).minimize(self.loss)

        self.init_op = tf.global_variables_initializer()
        # Reuse an enclosing default session if there is one, otherwise create a new one
        self.sess = tf.get_default_session()
        if self.sess is None:
            self.sess = tf.Session()
        self.sess.run(self.init_op)


    def fit(self, X, epochs=10, bs=64):
        n_batches = len(X) // bs
        print("Training {} batches".format(n_batches))

        for i in range(epochs):
            print("Epoch: ", i)
            X_perm = np.random.permutation(X)
            for j in range(n_batches):
                batch = X_perm[j*bs:(j+1)*bs]
                _, _ = self.sess.run((self.optimizer, self.loss),
                                    feed_dict={self.X: batch})


    def predict(self, X):
        return self.sess.run(self.X_hat, feed_dict={self.X: X})

    def encode(self, X):
        return self.sess.run(self.Z, feed_dict={self.X: X})

    def decode(self, Z):
        return self.sess.run(self.X_hat, feed_dict={self.Z: Z})

    def terminate(self):
        self.sess.close()
        del self.sess
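To check what the graph built in __init__ computes without a TensorFlow session, the forward pass and loss can be mirrored in NumPy. This is a sketch under stated assumptions: the sizes, random data, and function names below are hypothetical; the inputs are assumed to lie in [0, 1], since sigmoid cross-entropy treats them as per-feature targets; and the loss uses the numerically stable formula documented for tf.nn.sigmoid_cross_entropy_with_logits, max(x, 0) - x * z + log(1 + exp(-|x|)).

```python
import numpy as np


def sigmoid_cross_entropy_with_logits(labels, logits):
    """Element-wise stable form: max(x, 0) - x * z + log(1 + exp(-|x|)),
    with x = logits and z = labels."""
    x, z = logits, labels
    return np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))


def forward(X, W1, b1, W2, b2):
    """Hypothetical NumPy mirror of the D -> d -> D graph in Autoencoder.__init__."""
    Z = np.maximum(X @ W1 + b1, 0.0)         # ReLU hidden code, shape (N, d)
    logits = Z @ W2 + b2                      # shape (N, D)
    X_hat = 1.0 / (1.0 + np.exp(-logits))     # sigmoid reconstruction, shape (N, D)
    # Summed (not averaged) over batch and features, like tf.reduce_sum
    loss = sigmoid_cross_entropy_with_logits(X, logits).sum()
    return Z, X_hat, loss


# Hypothetical sizes and data, for illustration only
rng = np.random.default_rng(0)
N, D, d = 4, 6, 2
X = rng.random((N, D)).astype(np.float32)     # inputs assumed to be in [0, 1]
W1 = rng.standard_normal((D, d)).astype(np.float32)
b1 = np.zeros(d, dtype=np.float32)
W2 = rng.standard_normal((d, D)).astype(np.float32)
b2 = np.zeros(D, dtype=np.float32)

Z, X_hat, loss = forward(X, W1, b1, W2, b2)
```

Here Z has shape (N, d) and X_hat has shape (N, D), matching what encode and predict return. Because the original loss uses tf.reduce_sum rather than a mean, its magnitude grows with the batch size, which also scales the gradients the optimizer sees.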

0 answers