Why don't TensorFlow's boosted decision trees support dense int64 Tensors?

Posted: 2019-04-23 13:45:46

Tags: python numpy tensorflow

The relevant code is in https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py, lines 236-262:

    if isinstance(tensor, sparse_tensor.SparseTensor):
      if tensor.values.dtype == dtypes.float32:
        sparse_float_names.append(key)
        sparse_float_indices.append(tensor.indices)
        sparse_float_values.append(tensor.values)
        sparse_float_shapes.append(tensor.dense_shape)
      elif tensor.values.dtype == dtypes.int64:
        sparse_int_names.append(key)
        sparse_int_indices.append(tensor.indices)
        sparse_int_values.append(tensor.values)
        sparse_int_shapes.append(tensor.dense_shape)
      else:
        raise ValueError("Unsupported sparse feature %s with dtype %s." %
                         (tensor.indices.name, tensor.dtype))
    else:
      if tensor.dtype == dtypes.float32:
        if len(tensor.shape) > 1 and tensor.shape[1] > 1:
          unstacked = array_ops.unstack(tensor, axis=1)
          for i in range(len(unstacked)):
            dense_float_names.append(_FEATURE_NAME_TEMPLATE % (key, i))
            dense_floats.append(array_ops.reshape(unstacked[i], [-1, 1]))
        else:
          dense_float_names.append(key)
          dense_floats.append(tensor)
      else:
        raise ValueError("Unsupported dense feature %s with dtype %s." %
                         (tensor.name, tensor.dtype))
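To make the branching above easier to follow, here is a minimal pure-Python sketch that mirrors it without TensorFlow. The function `route_feature`, the string dtype names, and the bucket labels are hypothetical stand-ins; `FEATURE_NAME_TEMPLATE = "%s_%d"` is an assumption about the value of `_FEATURE_NAME_TEMPLATE`, which the excerpt does not show.

```python
# Hypothetical pure-Python mirror of the dtype dispatch in gbdt_batch.py.
# Strings such as "float32" stand in for TensorFlow dtypes.
FEATURE_NAME_TEMPLATE = "%s_%d"  # assumption: stands in for _FEATURE_NAME_TEMPLATE


def route_feature(key, is_sparse, dtype, shape=None):
    """Return a list of (bucket, feature_name) pairs, or raise ValueError."""
    if is_sparse:
        # Sparse features: both float32 and int64 values are accepted.
        if dtype == "float32":
            return [("sparse_float", key)]
        if dtype == "int64":
            return [("sparse_int", key)]
        raise ValueError("Unsupported sparse feature %s with dtype %s." % (key, dtype))
    if dtype == "float32":
        # A dense tensor with more than one known column is split per column,
        # and each column gets its own derived feature name.
        if shape is not None and len(shape) > 1 and shape[1] is not None and shape[1] > 1:
            return [("dense_float", FEATURE_NAME_TEMPLATE % (key, i))
                    for i in range(shape[1])]
        return [("dense_float", key)]
    # Every other dense dtype, int64 included, is rejected.
    raise ValueError("Unsupported dense feature %s with dtype %s." % (key, dtype))
```

For example, `route_feature("xy", False, "float32", (None, 3))` yields three dense float features named `xy_0`, `xy_1` and `xy_2`, `route_feature("ids", True, "int64")` lands in the sparse int bucket, and `route_feature("count", False, "int64", (None, 1))` raises the same "Unsupported dense feature" error the question is about.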

As we can see, sparse int64 tensors are supported, but in the outer else clause (the dense-tensor branch) only float32 is accepted. Why is that?
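Whatever the reason turns out to be, the branches above suggest two workarounds, depending on what the integers mean. Below is a minimal sketch in plain Python, with nested lists standing in for a dense 2-D tensor; the helper names are hypothetical. If the values are numeric, casting them to float32 (in TensorFlow, `tf.cast(x, tf.float32)`) routes the feature through the dense float32 branch. If they are category IDs, passing them as an int64 `tf.SparseTensor` built from `indices`, `values` and `dense_shape` routes them through the sparse int64 branch.

```python
# Hypothetical helpers; nested lists stand in for a rectangular dense int64 tensor.


def dense_int_to_float(matrix):
    """Numeric reading: cast every value to float so the feature would take
    the dense float32 branch. (Python floats are 64-bit; this only
    illustrates the dtype change that tf.cast would perform.)"""
    return [[float(v) for v in row] for row in matrix]


def dense_int_to_sparse_components(matrix):
    """Categorical reading: build the (indices, values, dense_shape) triple
    a SparseTensor needs. Every cell is kept, so a 0 value is not treated
    as missing. Assumes all rows have the same length."""
    indices, values = [], []
    for r, row in enumerate(matrix):
        for c, v in enumerate(row):
            indices.append([r, c])
            values.append(int(v))
    n_cols = len(matrix[0]) if matrix else 0
    return indices, values, [len(matrix), n_cols]


idx, vals, shape = dense_int_to_sparse_components([[5, 0], [7, 9]])
print(idx, vals, shape)
```

One caveat for the cast route: float32 represents every integer exactly only up to 2**24 (16,777,216), so larger IDs or counts can collide after casting. For the sparse route, a converter that keeps only nonzero entries would silently treat a 0 value as missing, which is why the sketch keeps every cell.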

0 answers
