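I'm trying to implement "batch hard" mining as described in https://arxiv.org/pdf/1703.07737.pdf so that I can use a triplet loss. The input has shape [batch_size, 32], and the output should be a list representing the triplets, i.e. [[batch_size, 32], [batch_size, 32], [batch_size, 32]], where each individual example has size (32,).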
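For reference, the per-anchor selection described above can be sketched in plain Python without TensorFlow, which is handy for checking a graph implementation's output on a tiny batch. This is a minimal sketch under stated assumptions: the helper name `batch_hard_reference` is hypothetical, the distance is Euclidean (`math.dist`, Python 3.8+), the positive set includes the anchor itself (as a same-class mask would), and every anchor is assumed to have at least one sample of a different class in the batch.

```python
import math


def batch_hard_reference(class_ids, features):
    """Hypothetical reference: return one (anchor, hardest_pos, hardest_neg) per sample."""
    triplets = []
    for cid, anchor in zip(class_ids, features):
        # Same-class samples (including the anchor itself) and different-class samples.
        positives = [f for c, f in zip(class_ids, features) if c == cid]
        negatives = [f for c, f in zip(class_ids, features) if c != cid]
        # Hardest positive: the same-class sample farthest from the anchor.
        hardest_pos = max(positives, key=lambda f: math.dist(f, anchor))
        # Hardest negative: the different-class sample closest to the anchor.
        hardest_neg = min(negatives, key=lambda f: math.dist(f, anchor))
        triplets.append((anchor, hardest_pos, hardest_neg))
    return triplets


# Usage on a toy batch of four 2-D embeddings from two classes.
triplets = batch_hard_reference([0, 0, 1, 1], [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (3.0, 3.0)])
# triplets[0] == ((0.0, 0.0), (1.0, 0.0), (0.0, 2.0))
```

With 32-dimensional embeddings the same function works unchanged; only the tuple length differs.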
I implemented this with the following function, which basically uses tf.map_fn:
import tensorflow as tf


def batch_hard(inputs):
    """
    Batch Hard triplets as described in https://arxiv.org/pdf/1703.07737.pdf.
    For each sample in input the hardest positive and hardest negative
    in the given batch will be selected. A triplet is returned.
    """
    class_ids, f_anchor = inputs[0], inputs[1]

    def body(x):
        class_id, f = x[0], x[1]
        # Boolean masks over the batch: same class vs. different class.
        same_class = tf.equal(class_ids, class_id)
        positive = same_class
        negative = tf.logical_not(same_class)
        positive = tf.squeeze(positive)
        negative = tf.squeeze(negative)
        positive.set_shape([None])
        negative.set_shape([None])
        samples_pos = tf.boolean_mask(f_anchor, positive)
        samples_neg = tf.boolean_mask(f_anchor, negative)
        # Select hardest positive example
        # (euclidean_distance is defined in the linked gist)
        distances = euclidean_distance(samples_pos, f)
        hardest_pos = samples_pos[tf.argmax(distances)]
        # Select hardest negative example
        distances = euclidean_distance(samples_neg, f)
        hardest_neg = samples_neg[tf.argmin(distances)]
        return [hardest_pos, hardest_neg]

    [f_pos, f_neg] = tf.map_fn(body, inputs, dtype=[tf.float32, tf.float32])
    return [f_anchor, f_pos, f_neg]
This works perfectly when I only run the forward pass without specifying a train_op. However, when I add this line:
train_op = optimizer.minimize(loss, global_step=global_step)
the following error occurs:
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gradients_impl.py", line 348, in _MaybeCompile
    xla_compile = op.get_attr("_XlaCompile")
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2003, in get_attr
    raise ValueError("No attr named '" + name + "' in " + str(self._node_def))
ValueError: No attr named '_XlaCompile' in name: "map/while/strided_slice"
op: "StridedSlice"
input: "map/while/boolean_mask/Gather"
input: "map/while/strided_slice/stack"
input: "map/while/strided_slice/stack_1"
input: "map/while/strided_slice/Cast"
attr {
  key: "Index"
  value {
    type: DT_INT64
  }
}
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "begin_mask"
  value {
    i: 0
  }
}
attr {
  key: "ellipsis_mask"
  value {
    i: 0
  }
}
attr {
  key: "end_mask"
  value {
    i: 0
  }
}
attr {
  key: "new_axis_mask"
  value {
    i: 0
  }
}
attr {
  key: "shrink_axis_mask"
  value {
    i: 1
  }
}
Does anyone know what is going wrong?
A complete example that reproduces this problem is at https://gist.github.com/anonymous/0b5e9194ebf09be7ad2f0a740bf369b8
Edit: The problem seems to come from this line:
hardest_pos = samples_pos[tf.argmax(distances)]
Replacing it with something like
hardest_pos = tf.zeros(32)
makes the error go away, but how do I fix it properly?