Why don't trainable_variables change after training?

Posted: 2019-12-24 00:47:55

Tags: tensorflow2.0 eager-execution

I went through a basic example of tf2.0, which contains very simple code:

from __future__ import absolute_import, division, print_function, unicode_literals
import os

import tensorflow as tf

import cProfile

# Fetch and format the mnist data
(mnist_images, mnist_labels), _ = tf.keras.datasets.mnist.load_data()

dataset = tf.data.Dataset.from_tensor_slices(
  (tf.cast(mnist_images[...,tf.newaxis]/255, tf.float32),
   tf.cast(mnist_labels,tf.int64)))
dataset = dataset.shuffle(1000).batch(32)

# Build the model
mnist_model = tf.keras.Sequential([
  tf.keras.layers.Conv2D(16,[3,3], activation='relu',
                         input_shape=(None, None, 1)),
  tf.keras.layers.Conv2D(16,[3,3], activation='relu'),
  tf.keras.layers.GlobalAveragePooling2D(),
  tf.keras.layers.Dense(10)
])

for images,labels in dataset.take(1):
    print("Logits: ", mnist_model(images[0:1]).numpy())

optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

loss_history = []


def train_step(model, images, labels):

    with tf.GradientTape() as tape:
        logits = model(images, training=True)

        # Add asserts to check the shape of the output.
        tf.debugging.assert_equal(logits.shape, (32, 10))

        loss_value = loss_object(labels, logits)

    loss_history.append(loss_value.numpy().mean())
    grads = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))


def train(epochs):
  for epoch in range(epochs):
    for (batch, (images, labels)) in enumerate(dataset):
      train_step(mnist_model, images, labels)
    print ('Epoch {} finished'.format(epoch))

I trained it and saved trainable_variables before and after training, as follows:


t0=mnist_model.trainable_variables  
train(epochs = 3)
t1=mnist_model.trainable_variables
diff = tf.reduce_mean(tf.abs(t0[0] - t1[0])) 
# whether I index with [0], [1], etc., diff comes out the same
print(diff.numpy())

They are the same!!! Am I checking this incorrectly? If so, how can I correctly observe the updated variables?
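The zero difference can be reproduced without TensorFlow. In this framework-free sketch, a hypothetical `TinyModel` class stands in for the Keras model and NumPy arrays stand in for `tf.Variable` objects. Its `train_step` mutates the arrays in place, much as `optimizer.apply_gradients` assigns new values to the existing variables.

```python
import numpy as np


class TinyModel:
    """Hypothetical stand-in for a Keras model whose weights are updated in place."""

    def __init__(self):
        self._weights = [np.zeros(3), np.zeros(2)]

    @property
    def trainable_variables(self):
        # A new list on each call, but it holds references to the live weight arrays.
        return list(self._weights)

    def train_step(self, lr=0.1):
        # Mutate each weight in place (stand-in for optimizer.apply_gradients).
        for w in self._weights:
            w -= lr


model = TinyModel()

# The question's approach: keep references before and after training.
t0 = model.trainable_variables
model.train_step()
t1 = model.trainable_variables
print(t0 is t1)        # False: two distinct lists...
print(t0[0] is t1[0])  # True: ...holding the same weight object
stale_diff = np.mean(np.abs(t0[0] - t1[0]))  # always 0.0

# Copy the values first, then train again: the change becomes visible.
snap = [w.copy() for w in model.trainable_variables]
model.train_step()
real_diff = np.mean(np.abs(snap[0] - model.trainable_variables[0]))
print(stale_diff, real_diff)
```

Keeping references compares each weight with itself. Copying before the update preserves the old values.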

1 answer:

Answer 0 (score: 0)

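You are not creating a new array of variables. You are creating two references to the same objects. `trainable_variables` returns the model's existing `tf.Variable` objects, and the optimizer updates those variables in place. As a result, `t0[0]` and `t1[0]` are the same variable, and their difference is always zero. Copy the values before training instead. Try this:

t0 = [v.numpy().copy() for v in mnist_model.trainable_variables]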
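Here is a minimal, self-contained sketch of the snapshot approach. It assumes TF 2.x with Keras. The tiny `Dense` model, the random data, and the helper `snapshot` are hypothetical stand-ins for the MNIST setup. The sketch copies each variable on its own with `.numpy().copy()` instead of building a single NumPy array from the whole list, because the variables have different shapes.

```python
import numpy as np
import tensorflow as tf

# Hypothetical tiny model and random data standing in for the MNIST setup.
tf.random.set_seed(0)
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(3)])
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
x = tf.random.normal((8, 4))
y = tf.random.uniform((8,), maxval=3, dtype=tf.int64)


def snapshot(variables):
    """Copy each variable's current value into an independent NumPy array."""
    return [v.numpy().copy() for v in variables]


t0 = model.trainable_variables                   # references (the question's approach)
t0_values = snapshot(model.trainable_variables)  # copied values (the fix)

# One training step, following the same pattern as train_step in the question.
with tf.GradientTape() as tape:
    loss_value = loss_object(y, model(x, training=True))
grads = tape.gradient(loss_value, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))

t1_values = snapshot(model.trainable_variables)

# Via references: both sides read the same, already-updated variable.
stale_diff = float(np.mean(np.abs(t0[0].numpy() - model.trainable_variables[0].numpy())))
# Via snapshots: shows the change the optimizer made.
real_diff = float(np.mean(np.abs(t0_values[0] - t1_values[0])))
print("diff via references:", stale_diff)
print("diff via snapshots:", real_diff)
```

The reference-based difference stays at zero no matter how long you train. The snapshot-based difference reflects the optimizer's update.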