如何使用 tensorflow 数据集训练 sklearn 模型?

时间:2021-02-09 12:03:48

标签: tensorflow scikit-learn tensorflow2.0 tensorflow-datasets

我想知道是否可以使用 Tensorflow 数据集来训练 scikit-learn 和其他 ML 框架。

那么,例如,我可以使用 tf.data.dataset 来训练 xgboost、LogisticReg、RandomForest 分类器等吗? 即我可以将 tf.data.dataset 对象传递到这些模型的 .fit() 方法中进行训练吗?

我试过了:

    xs=np.asarray([i for i in range(10000)]).reshape(-1, 1)
    ys=np.asarray([int(i%2==0)for i in range(10000)])
    
    xs = tf.data.Dataset.from_tensor_slices(xs)
    ys = tf.data.Dataset.from_tensor_slices(ys)
    cls.fit(xs, ys)

我收到以下错误:

    TypeError: float() argument must be a string or a number, not 'TensorSliceDataset'

1 个答案:

答案 0 :(得分:0)

您可以使用as_numpy_iterator()方法;来自docs

<块引用>

返回一个迭代器,它将数据集的所有元素转换为 numpy。

按照您的示例:

from sklearn.svm import SVC

x = list(xs.as_numpy_iterator())
y = list(ys.as_numpy_iterator())

clf = SVC(gamma='auto')

clf.fit(x, y)