我有一个Tensorflow DatasetV1Adapter
对象形式的数据集。
<DatasetV1Adapter shapes: OrderedDict([(labels, (6,)), (snippets, ())]), types: OrderedDict([(labels, tf.int32), (snippets, tf.string)])>
# Example Output
OrderedDict([('labels', <tf.Tensor: id=37, shape=(6,), dtype=int32, numpy=array([0, 0, 0, 0, 0, 0], dtype=int32)>), ('snippets', <tf.Tensor: id=38, shape=(), dtype=string, numpy=b'explanationwhy the edits made under my username hardcore metallica fan were reverted they werent vandalisms just closure on some gas after i voted at new york dolls fac and please dont remove the template from the talk page since im retired now892053827'>)])
OrderedDict([('labels', <tf.Tensor: id=41, shape=(6,), dtype=int32, numpy=array([0, 0, 0, 0, 0, 0], dtype=int32)>), ('snippets', <tf.Tensor: id=42, shape=(), dtype=string, numpy=b'daww he matches this background colour im seemingly stuck with thanks talk 2151 january 11 2016 utc'>)])
如您所见,它包含一个OrderedDict
对象,其键为labels
和snippets
。后者基本上很重要,因为它包含我希望使用句子嵌入将其转换为向量的文本字符串。
为此,我决定使用tensorflow集线器中的Universal Sentence Encoder(使用)。它基本上接受一个句子列表作为输入,并将输出长度为512的向量作为其输出。要注意的一件事是,如果启用了急切执行,则无法在期间执行tensorflow hub。因此,我们必须定义一个会话才能将USE与tensorflow hub一起使用。
但是,我希望使用tensorflow提供的map
。但是问题出在我应该如何定义其中包含张量流会话的函数?为了使用该函数并将其映射到数据集,我是否需要定义另一个张量流会话?
我的第一个方法是实际做到这一点。具体来说,通过定义一个包含张量流会话的函数。然后,启动一个新的tensorflow会话,并尝试将该函数映射到该会话中的数据集。
请注意,我在会话外部定义了USE句子嵌入模型。
# Sentence embedding model (USE)
embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder/2")
def to_vec(w):
x = w['snippets']
with tf.Session() as sess:
vector = sess.run(embed(x))
return vector
with tf.Session() as sess:
sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
# try_data is the DatasetV1Adapter object
sess.run(try_data.map(to_vec))
但是最后我得到了这个错误
RuntimeError: Module must be applied in the graph it was instantiated for.
或者,我尝试在tensorflow会话中定义函数,就像这样
with tf.Session() as sess:
sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
def to_vec(w):
x = w['snippets']
vector = sess.run(embed(x))
return vector
sess.run(try_data.map(to_vec))
但是那没有用,我仍然遇到同样的错误。经过一番搜索之后,我偶然发现了this post和this post,说我必须定义一个tf.Graph
并在会话中传递它。
graph = tf.Graph()
with graph.as_default():
with tf.Session(graph=graph) as sess:
sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
def to_vec(w):
x = w['snippets']
vector = sess.run(embed(x))
return vector
sess.run(try_data.map(to_vec))
但是,我仍然收到相同的错误。我还尝试在会话内定义USE,但仍然会导致相同的错误。
从那里,我对如何执行此操作感到非常困惑。有人对我错过的事情有任何想法吗?预先感谢。
答案 0 :(得分:0)
简短的回答:您没有。 Tensorflow将在图模式下调用传递给Dataset.map
的函数(它仅调用一次函数并为每个示例使用结果图,因此您不必担心可能运行与hub相关的准备工作(下载等))。
我对Tensorflow Hub不太熟悉,但是请尝试以下方法。
def map_fn(inputs):
snippets = inputs['snippets']
# you -may- be able to pull the line below outside of map_fn
# it probably won't affect performance
embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder/2")
vector = embed(snippets)
return vector
dataset = dataset.map(map_fn)