Tensorflow map_fn给出错误:ValueError:没有名为'_XlaCompile'的attr

时间:2017-11-30 13:19:00

标签: tensorflow

我尝试按照https://arxiv.org/pdf/1703.07737.pdf中的描述实施“批量硬”批次,以便使用三联体丢失。因此输入的形状为[batch_size,32],输出应该是表示三元组的列表,因此[[batch_size,32],[batch_size,32],[batch_size,32]]当每个单个示例的大小(32,)时

我用以下函数实现了这个,所以基本上使用tf.map_fn:

def batch_hard(inputs):
    """ 
    Batch Hard triplets as described in https://arxiv.org/pdf/1703.07737.pdf.
    For each sample in input the hardest positive and hardest negative
    in the given batch will be selected. A triplet is returned.
    """
    class_ids, f_anchor = inputs[0], inputs[1]

    def body(x):
        class_id, f = x[0], x[1]

        same_class = tf.equal(class_ids, class_id)

        positive = same_class
        negative = tf.logical_not(same_class)

        positive = tf.squeeze(positive)
        negative = tf.squeeze(negative)

        positive.set_shape([None])
        negative.set_shape([None])

        samples_pos = tf.boolean_mask(f_anchor, positive)
        samples_neg = tf.boolean_mask(f_anchor, negative)

        # Select hardest positive example
        distances = euclidean_distance(samples_pos, f)
        hardest_pos = samples_pos[tf.argmax(distances)]

        # Select hardest negative example
        distances = euclidean_distance(samples_neg, f)
        hardest_neg = samples_neg[tf.argmin(distances)]

        return [hardest_pos, hardest_neg]

    [f_pos, f_neg] = tf.map_fn(body, inputs, dtype=[tf.float32, tf.float32])
    return [f_anchor, f_pos, f_neg]

当我只执行正向传递时没有指定train_op,这非常有效。但是,当我添加此行train_op = optimizer.minimize(loss, global_step=global_step)时,会发生以下错误:

Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gradients_impl.py", line 348, in _MaybeCompile
    xla_compile = op.get_attr("_XlaCompile")
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2003, in get_attr
    raise ValueError("No attr named '" + name + "' in " + str(self._node_def))
ValueError: No attr named '_XlaCompile' in name: "map/while/strided_slice"
op: "StridedSlice"
input: "map/while/boolean_mask/Gather"
input: "map/while/strided_slice/stack"
input: "map/while/strided_slice/stack_1"
input: "map/while/strided_slice/Cast"
attr {
  key: "Index"
  value {
    type: DT_INT64
  }
}
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "begin_mask"
  value {
    i: 0
  }
}
attr {
  key: "ellipsis_mask"
  value {
    i: 0
  }
}
attr {
  key: "end_mask"
  value {
    i: 0
  }
}
attr {
  key: "new_axis_mask"
  value {
    i: 0
  }
}
attr {
  key: "shrink_axis_mask"
  value {
    i: 1
  }
}

有没有人知道出了什么问题?

此问题的完整示例是https://gist.github.com/anonymous/0b5e9194ebf09be7ad2f0a740bf369b8

编辑:似乎问题出现在这些行中

hardest_pos = samples_pos[tf.argmax(distances)]

之类的东西替换它
hardest_pos = tf.zeros(32)

没有错误,但如何解决?

0 个答案:

没有答案