我是 TensorFlow 的新手,但我认为理解 TensorFlow 的核心操作是必要的。如果以面向对象的方式使用 TensorFlow 的 Python API,我们可以把构建计算图的各个操作分别写成不同的方法(def)。
def _create_placeholders(self):
    """Step 1: build the input and output placeholders.

    Creates two int32 placeholders under the "data" name scope and stores
    them on the instance:

    * ``self.center_words`` with shape ``[self.batch_size]``
    * ``self.target_words`` with shape ``[self.batch_size, 1]``

    The op backing ``center_words`` is printed once after each placeholder
    is created (debug output).
    """
    batch = self.batch_size
    with tf.name_scope("data"):
        centers = tf.placeholder(tf.int32, shape=[batch], name='center_words')
        self.center_words = centers
        print("Extracting the op", self.center_words.op)

        targets = tf.placeholder(tf.int32, shape=[batch, 1], name='target_words')
        self.target_words = targets
        print("so", self.center_words.op)
def _create_embedding(self):
    """Step 2: build the embedding matrix (the weights word2vec is after).

    Stores a ``[self.vocab_size, self.embed_size]`` variable named
    'embed_matrix' on ``self.embed_matrix``, initialised from
    ``tf.random_uniform`` with bounds -1.0 and 1.0.
    """
    matrix_shape = [self.vocab_size, self.embed_size]
    # This part of the graph is pinned to the CPU; switch the device string
    # to a GPU if one is available.
    with tf.device('/cpu:0'), tf.name_scope("embed"):
        uniform_init = tf.random_uniform(matrix_shape, -1.0, 1.0)
        self.embed_matrix = tf.Variable(uniform_init, name='embed_matrix')
def _create_loss(self):
    """Step 3 + 4: define the model (inference) and the NCE loss.

    Looks up the embeddings of ``self.center_words`` in ``self.embed_matrix``,
    creates the NCE weight and bias variables, and stores the mean NCE loss
    on ``self.loss``. Everything is built under the "loss" name scope on the
    CPU device.

    Reads ``self.vocab_size``, ``self.embed_size``, ``self.num_sampled``,
    ``self.embed_matrix``, ``self.center_words`` and ``self.target_words``.
    """
    with tf.device('/cpu:0'):
        with tf.name_scope("loss"):
            # Step 3: define the inference
            embed = tf.nn.embedding_lookup(self.embed_matrix, self.center_words, name='embed')

            # Step 4: define loss function
            # construct variables for NCE loss
            nce_weight = tf.Variable(
                tf.truncated_normal([self.vocab_size, self.embed_size],
                                    stddev=1.0 / (self.embed_size ** 0.5)),
                name='nce_weight')
            # Size the bias from the instance's vocabulary size (not a
            # module-level constant) so it always matches nce_weight and the
            # num_classes passed to nce_loss below.
            nce_bias = tf.Variable(tf.zeros([self.vocab_size]), name='nce_bias')

            # define loss function to be NCE loss function
            self.loss = tf.reduce_mean(
                tf.nn.nce_loss(weights=nce_weight,
                               biases=nce_bias,
                               labels=self.target_words,
                               inputs=embed,
                               num_sampled=self.num_sampled,
                               num_classes=self.vocab_size),
                name='loss')
这里我列出了两个方法,分别用于创建嵌入和计算损失。因此,如果我调用其中一个方法(例如 _create_loss()),它就会在计算图中创建节点。我查看了源代码,发现在图构建阶段,每个操作都会被加载到某种缓冲区中;然后在会话(session)阶段,我们只是用真实数据把所有这些操作重新运行一遍。
# Run a single training step inside a session.
# `model`, `saver`, `batch_gen` and `LEARNING_RATE` are not defined in this
# snippet and must already exist in the enclosing scope.
with tf.Session(config=tf.ConfigProto(log_device_placement=False)) as sess:
    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(os.path.dirname('c/checkpointsq'))
    # if that checkpoint exists, restore from checkpoint
    if ckpt and ckpt.model_checkpoint_path:
        print("Restoring the checkpoins")
        saver.restore(sess, ckpt.model_checkpoint_path)

    total_loss = 0.0  # we use this to calculate late average loss in the last SKIP_STEP steps
    # NOTE(review): the log directory string contains a space after './' --
    # confirm that this is intentional.
    writer = tf.summary.FileWriter('./ improved_graph/lr' + str(LEARNING_RATE), sess.graph)
    try:
        initial_step = model.global_step.eval()
        for index in range(1):
            # Use the built-in next() rather than calling __next__ directly.
            centers, targets = next(batch_gen)
            feed_dict = {model.center_words: centers, model.target_words: targets}
            loss_batch, _, summary = sess.run(
                [model.loss, model.optimizer, model.summary_op],
                feed_dict=feed_dict)
    finally:
        # Always close the writer so pending events are flushed and the
        # event file handle is released, even if a step raises.
        writer.close()
这是我的问题。在 sess.run 中,TensorFlow 根本不关心 Python API,它只关心由上面的图构建代码加载的图操作。我的问题是:在会话对象中,所有这些操作究竟是在哪里执行的?我的理解是,这些操作是在它的核心(底层)中执行的。我们能访问到这部分吗?