如何对tf.feature_column.embedding_column
使用预训练的嵌入。
我在pre_trained
中使用了tf.feature_column.embedding_column
嵌入。但这是行不通的。错误是
错误是:
ValueError:如果已指定,则初始化方法必须是可调用的。嵌入column_name:itemx
这是我的代码:
weight, vocab_size, emb_size = _create_pretrained_emb_from_txt(FLAGS.vocab,
FLAGS.pre_emb)
W = tf.Variable(tf.constant(0.0, shape=[vocab_size, emb_size]),
trainable=False, name="W")
embedding_placeholder = tf.placeholder(tf.float32, [vocab_size, emb_size])
embedding_init = W.assign(embedding_placeholder)
sess = tf.Session()
sess.run(embedding_init, feed_dict={embedding_placeholder: weight})
itemx_vocab = tf.feature_column.categorical_column_with_vocabulary_file(
key='itemx',
vocabulary_file=FLAGS.vocabx)
itemx_emb = tf.feature_column.embedding_column(itemx_vocab,
dimension=emb_size,
initializer=W,
trainable=False)
我尝试了初始化= lambda w:W。像这样:
itemx_emb = tf.feature_column.embedding_column(itemx_vocab,
dimension=emb_size,
initializer=lambda w:W,
trainable=False)
它报告错误:
TypeError:()获得了意外的关键字参数'dtype'
答案 0 :(得分:1)
我在这里https://github.com/tensorflow/tensorflow/issues/20663也有问题
最后,我找到了解决之道。虽然。我不清楚为什么上面的答案无效!如果您知道问题,谢谢给我一些建议!
好的~~~~这是当前的解决方法。实际上是从这里Feature Columns Embedding lookup
代码:
itemx_vocab = tf.feature_column.categorical_column_with_vocabulary_file(
key='itemx',
vocabulary_file=FLAGS.vocabx)
embedding_initializer_x = tf.contrib.framework.load_embedding_initializer(
ckpt_path='model.ckpt',
embedding_tensor_name='w_in',
new_vocab_size=itemx_vocab.vocabulary_size,
embedding_dim=emb_size,
old_vocab_file='FLAGS.vocab_emb',
new_vocab_file=FLAGS.vocabx
)
itemx_emb = tf.feature_column.embedding_column(itemx_vocab,
dimension=128,
initializer=embedding_initializer_x,
trainable=False)
答案 1 :(得分:0)
您还可以将数组包装成如下函数:
some_matrix = np.array([[0,1,2],[0,2,3],[5,6,7]])
def custom_init(shape, dtype):
return some_matrix
embedding_feature = tf.feature_column.embedding_column(itemx_vocab,
dimension=3,
initializer=custom_init
)
这是一种骇人听闻的方法,但可以完成工作。