对于转移学习,人们经常使用网络作为特征提取器来创建特征数据集,在其上训练另一个分类器(例如SVM)。
我想使用数据集API(tf.contrib.data
)和dataset.map()
来实现这一点:
# feature_extractor will create a CNN on top of the given tensor
def features(feature_extractor, ...):
dataset = inputs(...) # This creates a dataset of (image, label) pairs
def map_example(image, label):
features = feature_extractor(image, trainable=False)
# Leaving out initialization from a checkpoint here...
return features, label
dataset = dataset.map(map_example)
return dataset
在为数据集创建迭代器时执行此操作失败。
ValueError: Cannot capture a stateful node by value.
这是事实,网络的内核和偏见是变量,因此是有状态的。对于这个特殊的例子,他们不必这样做。
有没有办法让Ops和特别是tf.Variable
个对象无状态?
由于我使用tf.layers
我不能简单地将它们创建为常量,并且设置trainable=False
不会创建常量,但只是不会将变量添加到GraphKeys.TRAINABLE_VARIABLES
集合。
答案 0 :(得分:13)
不幸的是,tf.Variable
本质上是有状态的。但是,只有在使用Dataset.make_one_shot_iterator()
创建迭代器时才会出现此错误。*为了避免此问题,您可以使用Dataset.make_initializable_iterator()
,但需要注意的是,您还必须在iterator.initializer
上运行tf.Variable
在运行输入管道中使用的Dataset.make_one_shot_iterator()
对象的初始化程序后返回迭代器。
*此限制的原因是Defun
的实现细节以及它用于封装数据集定义的正在进行的TensorFlow函数(ProcessPoolExecutor
)支持。由于使用查找表和变量之类的有状态资源比我们最初想象的更受欢迎,我们正在研究如何放宽这种限制。