Suppose I have the following matrix: [[1, 2], [3, 4]].
So if I want to look up the second row (index 1), I should get [3 4].
This is how I implemented the lookup mechanism using TensorFlow ops:
import tensorflow as tf
import numpy as np
from tensorflow.python.estimator.model_fn import EstimatorSpec

def model_fn_1(features, labels, mode):
    x = tf.constant([[1]])
    labels = tf.constant([[10.]])
    with tf.name_scope('Embedding_Layer'):
        m = np.array([[1, 2], [3, 4]], np.float32)
        lookup = tf.nn.embedding_lookup(m, x, name='embedding_matrix_1')
        lookup = tf.Print(lookup, [lookup])
    preds = tf.keras.layers.Dense(1)(lookup)
    loss = tf.reduce_mean(labels - preds)
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss, tf.train.get_global_step())
    eval_metric_ops = {'accuracy': tf.metrics.accuracy(labels, preds)}
    return EstimatorSpec(mode=mode, loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops)

model_1 = tf.estimator.Estimator(model_fn_1)
model_1.train(input_fn=lambda: None, steps=1)
As expected, the output of lookup during training is:
2017-11-08 21:17:49.010728: I C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\36\tensorflow\core\kernels\logging_ops.cc:79] [[[3 4]]]
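For reference, an embedding lookup over a constant matrix is just row indexing; a minimal NumPy sketch of the same lookup (no TensorFlow required):

```python
import numpy as np

m = np.array([[1, 2], [3, 4]], np.float32)
x = np.array([[1]])  # index of the row to look up

# Gathering rows of m by index reproduces the printed [[[3 4]]]
lookup = m[x]
print(lookup.tolist())  # [[[3.0, 4.0]]]
```

Indexing with the (1, 1)-shaped x yields a (1, 1, 2) result, which matches the shape `embedding_lookup` printed above.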
The problem is that when I try to implement the embedding lookup with a Keras layer, I don't get the same output:
import tensorflow as tf
import numpy as np
from tensorflow.python.estimator.model_fn import EstimatorSpec

def model_fn(features, labels, mode):
    x = tf.constant([[1]])
    labels = tf.constant([[10.]])
    m = np.array([[1, 2], [3, 4]], np.float32)
    with tf.name_scope('Embedding_Layer'):
        n = tf.keras.layers.Embedding(2, 2, weights=[m], input_length=1, name='embedding_matrix_1', trainable=False)
        lookup = n(x)
        lookup = tf.Print(lookup, [lookup])
    preds = tf.keras.layers.Dense(1)(lookup)
    loss = tf.reduce_mean(labels - preds)
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss, tf.train.get_global_step())
    eval_metric_ops = {'accuracy': tf.metrics.accuracy(labels, preds)}
    return EstimatorSpec(mode=mode, loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops)

model = tf.estimator.Estimator(model_fn)
model.train(input_fn=lambda: None, steps=1)
Here the output of lookup is random numbers, like:
2017-11-08 21:20:59.046951: I C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\36\tensorflow\core\kernels\logging_ops.cc:79] [[[0.532017469 0.885832787]]]
As far as I can tell, the Keras implementation is identical to the TensorFlow one. Why am I not getting the same output, and how can I fix the Keras implementation?
Answer 0 (score: 1)
weights is currently deprecated; you can use embeddings_initializer instead:
n = tf.keras.layers.Embedding(2, 2,
                              embeddings_initializer=tf.initializers.constant(m),
                              input_length=1,
                              name='embedding_matrix_1', trainable=False)
Then Keras's Embedding(...)(x) works exactly like tf.nn.embedding_lookup.
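To see why the original Keras version printed random numbers: an Embedding layer's weight table is randomly initialized unless you supply an initializer, and (per this answer) the deprecated weights argument was not applied. A NumPy sketch of the two cases, as an illustration of the idea rather than actual Keras internals:

```python
import numpy as np

m = np.array([[1, 2], [3, 4]], np.float32)
x = 1  # row index to look up

# Random initialization: if `weights=[m]` is not applied, the layer's
# table is random, so the looked-up row has nothing to do with m.
rng = np.random.default_rng(0)
random_table = rng.uniform(0, 1, size=m.shape).astype(np.float32)

# Constant initialization: the embeddings_initializer fix -- the table
# starts out equal to m, so the lookup matches embedding_lookup.
constant_table = m.copy()

print(constant_table[x].tolist())  # [3.0, 4.0]
```

With trainable=False and a constant initializer, the table stays equal to m for the whole run, so both implementations return the same rows.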
Answer 1 (score: -1)
Do these two implementations really have anything in common? embedding_lookup is just a lookup: it simply extracts rows from m according to the indices in x. But keras.layers.Embedding is a computation with its own weight matrix. Just look at the official documentation: https://keras.io/layers/embeddings/
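The distinction this answer draws can be illustrated without TensorFlow: a pure lookup never changes m, whereas a trainable embedding's rows are updated by gradients during training. A toy NumPy sketch with an assumed loss (the mean of the looked-up row, so each selected entry gets a gradient of 1/row_size):

```python
import numpy as np

table = np.array([[1, 2], [3, 4]], np.float32)  # trainable embedding table
x = 1      # index of the row that is looked up
lr = 0.1   # learning rate for one SGD step

# Forward pass: look up row x; toy loss = mean of that row.
row = table[x]
loss = row.mean()

# Backward pass: d(loss)/d(row) = 1/row.size per entry; only row x moves.
grad = np.full_like(row, 1.0 / row.size)
table[x] -= lr * grad

print(table.tolist())  # row 1 shifted to ~[2.95, 3.95], row 0 untouched
```

This sparse-update behavior (only the looked-up rows receive gradients) is what makes an Embedding layer "a computation" rather than a plain lookup; with trainable=False it degenerates to the lookup case.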