想象一下,我想训练模型,这可以最大限度地减少图像和查询之间的距离。从一侧我有来自CNN的图像特征,从另一边我有从字到嵌入向量的映射(例如w2v):
def raw_data_generator():
for row in network_data:
yield (row["cnn"], row["w2v_indices"])
dataset = tf.data.Dataset.from_generator(raw_data_generator, (tf.float32, tf.int32))
dataset = dataset.prefetch(1000)
这里我想创建批处理,但我想为cnn功能创建密集批处理,为w2v创建稀疏批处理,因为它显然具有可变长度(我想使用safe_embeddings_lookup_sparse)。密集的批函数和稀疏的 .apply(tf.contrib.data.dense_to_sparse_batch(..))函数,但如何同时使用它们?
答案 0 :(得分:0)
您可以尝试创建两个数据集(每个功能一个),对每个数据集应用适当的批处理,然后将其与tf.data.Dataset.zip一起压缩。
@staticmethod
zip(datasets)
通过将给定数据集压缩在一起来创建数据集。
此方法与内置的zip()函数具有相似的语义 Python,主要区别在于数据集参数可以 是数据集对象的任意嵌套结构。例如:
# NOTE: The following examples use `{ ... }` to represent the
# contents of a dataset.
a = { 1, 2, 3 }
b = { 4, 5, 6 }
c = { (7, 8), (9, 10), (11, 12) }
d = { 13, 14 }
# The nested structure of the `datasets` argument determines the
# structure of elements in the resulting dataset.
Dataset.zip((a, b)) == { (1, 4), (2, 5), (3, 6) }
Dataset.zip((b, a)) == { (4, 1), (5, 2), (6, 3) }
# The `datasets` argument may contain an arbitrary number of
# datasets.
Dataset.zip((a, b, c)) == { (1, 4, (7, 8)),
(2, 5, (9, 10)),
(3, 6, (11, 12)) }
# The number of elements in the resulting dataset is the same as
# the size of the smallest dataset in `datasets`.
Dataset.zip((a, d)) == { (1, 13), (2, 14) }