无法更改KMeansClustering Tensorflow中的群集数量

时间:2019-05-28 08:17:12

标签: python tensorflow batch-processing k-means

我找到了这段代码,它运行完美。这个想法-拆分我的数据并对其进行训练KMeansClustering。因此,我创建了InitHook和迭代器,并将其用于培训。

class _IteratorInitHook(tf.train.SessionRunHook):
    """Hook to initialize data iterator after session is created."""

    def __init__(self):
        super(_IteratorInitHook, self).__init__()
        self.iterator_initializer_fn = None

    def after_create_session(self, session, coord):
        """Initialize the iterator after the session has been created."""
        del coord
        self.iterator_initializer_fn(session)


# Run K-means clustering.
def _get_input_fn():
    """Helper function to create input function and hook for training.
    Returns:
        input_fn: Input function for k-means Estimator training.
        init_hook: Hook used to load data during training.
    """
    init_hook = _IteratorInitHook()

    def _input_fn():
        """Produces tf.data.Dataset object for k-means training.
        Returns:
            Tensor with the data for training.
        """
        features_placeholder = tf.placeholder(tf.float32,
                                                my_data.shape)
        delf_dataset = tf.data.Dataset.from_tensor_slices((features_placeholder))
        delf_dataset = delf_dataset.shuffle(1000).batch(
            my_data.shape[0])
        iterator = delf_dataset.make_initializable_iterator()

        def _initializer_fn(sess):
            """Initialize dataset iterator, feed in the data."""
            sess.run(
                iterator.initializer,
                feed_dict={features_placeholder: my_data})

        init_hook.iterator_initializer_fn = _initializer_fn
        return iterator.get_next()

    return _input_fn, init_hook


input_fn, init_hook = _get_input_fn()

output_cluster_dir = 'parameters/clusters'

kmeans = tf.contrib.factorization.KMeansClustering(
    num_clusters=1024,
    model_dir=output_cluster_dir,
    use_mini_batch=False,
)


print('Starting K-means clustering...')
kmeans.train(input_fn, hooks=[init_hook])

但是如果我将num_clusters更改为512或256,则会出现下一个错误:

  

InvalidArgumentError:segment_ids [0] = 600超出范围[0,256)
  [[node UnsortedSegmentSum(定义为   /home/mikhail/.conda/envs/tf2/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py:1112)   ]] [[节点压缩(定义为   /home/mikhail/.conda/envs/tf2/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py:1112)   ]]

看起来我在将数据拆分为批次时遇到了一些问题,或者即使我设置了另一个值,我的KMeans在默认情况下也使用1024个群集!

我不知道要进行哪些更改才能使其正常工作。 追溯非常大,如果需要,我可以将其附加为文件。

1 个答案:

答案 0 :(得分:0)

我发现了问题: 如您所见,我将密码本保存到parameters/clusters。当它创建了tensorflow时,也在这里保存图形。 因此,张量流的默认行为-如果已经存在,则不要创建新图!

因此,每次我尝试运行KMeansClustering时,它仍然使用从密码本加载的图形。 我每次运行clusters时都删除了文件夹KMeansClustering,从而解决了这个问题。

我仍然有一些问题:我创建了新集群,并并行启动2个脚本来使用它创建功能:其中一个为旧的代码簿创建,另一个为新的代码簿创建!仍然强制执行此操作,但是我的建议是在创建新的代码本后重新启动所有功能(也许某些信息仍在tensorflow中加载)。