Where did the Keras function `slice_X()` go?

Date: 2017-03-17 08:22:48

Tags: python-2.7 slice keras

I have some code (see here) that uses `slice_X()` (located in `keras.engine.training`) to modify the stock TensorBoard callback, which otherwise causes OOM errors on the GPU, so that TensorBoard can be used during GPU training. Unfortunately, since I upgraded to Keras 2.0.0 this raises an ImportError, because `keras.engine.training` no longer contains `slice_X()`. Where did it go, and what are possible alternative solutions?

Thank you very much for your help.

EDIT:

I have updated the code (see here) for Keras 2.0.0 and TensorFlow r1.0.

import numpy as np
import tensorflow as tf
import keras
import keras.backend as K
from pkg_resources import parse_version


class TensorBoard(keras.callbacks.Callback):
    '''
    Avoids the OOM problem.
    Adapted from: https://github.com/Vladimir-Yashin/keras/blob/13e6a1f99f33a3cc7bc0a44d285fda457cc808e4/keras/callbacks.py
    Updated according to discussion:
    http://stackoverflow.com/questions/42852495/where-did-keras-function-slice-x-go/42855104?noredirect=1#42855104

    TensorBoard basic visualizations.
    This callback writes a log for TensorBoard, which allows
    you to visualize dynamic graphs of your training and test
    metrics, as well as activation histograms for the different
    layers in your model.
    TensorBoard is a visualization tool provided with TensorFlow.
    If you have installed TensorFlow with pip, you should be able
    to launch TensorBoard from the command line:
    ```
    tensorboard --logdir=/full_path_to_your_logs
    ```
    You can find more information about TensorBoard
    [here](https://www.tensorflow.org/versions/master/how_tos/summaries_and_tensorboard/index.html).

    # Arguments
        log_dir: the path of the directory where to save the log
            files to be parsed by TensorBoard
        histogram_freq: frequency (in epochs) at which to compute activation
            histograms for the layers of the model. If set to 0,
            histograms won't be computed.
        write_graph: whether to visualize the graph in TensorBoard.
            The log file can become quite large when
            write_graph is set to True.
    '''

    def __init__(self, log_dir='./logs', histogram_freq=0, write_graph=True, write_images=False):
        # The class is named TensorBoard, not BatchedTensorBoard
        super(TensorBoard, self).__init__()
        if K.backend() != 'tensorflow':
            raise RuntimeError('TensorBoard callback only works '
                               'with the TensorFlow backend.')
        self.log_dir = log_dir
        self.histogram_freq = histogram_freq
        self.merged = None
        self.write_graph = write_graph
        self.write_images = write_images

    def set_model(self, model):
        import keras.backend.tensorflow_backend as KTF

        self.model = model
        self.sess = KTF.get_session()
        if self.histogram_freq and self.merged is None:
            for layer in self.model.layers:

                for weight in layer.weights:
                    tf.summary.histogram(weight.name, weight)

                    if self.write_images:
                        w_img = tf.squeeze(weight)

                        shape = w_img.get_shape()
                        if len(shape) > 1 and shape[0] > shape[1]:
                            w_img = tf.transpose(w_img)

                        if len(shape) == 1:
                            w_img = tf.expand_dims(w_img, 0)

                        w_img = tf.expand_dims(tf.expand_dims(w_img, 0), -1)

                        # tf.image_summary() was removed in TF r1.0;
                        # tf.summary.image() is its replacement
                        tf.summary.image(weight.name, w_img)

                if hasattr(layer, 'output'):
                    tf.summary.histogram('{}_out'.format(layer.name),
                                         layer.output)
        if parse_version(tf.__version__) >= parse_version('0.12.0'):
            self.merged = tf.summary.merge_all()
        else:
            self.merged = tf.merge_all_summaries()
        if self.write_graph:
            if parse_version(tf.__version__) >= parse_version('0.12.0'):
                self.writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)
            elif parse_version(tf.__version__) >= parse_version('0.8.0'):
                self.writer = tf.train.SummaryWriter(self.log_dir,
                                                     self.sess.graph)
            else:
                self.writer = tf.train.SummaryWriter(self.log_dir,
                                                     self.sess.graph_def)
        else:
            if parse_version(tf.__version__) >= parse_version('0.12.0'):
                self.writer = tf.summary.FileWriter(self.log_dir)
            else:
                self.writer = tf.train.SummaryWriter(self.log_dir)

    def on_epoch_end(self, epoch, logs={}):
        # slice_X() was renamed to the private helper _slice_arrays() in Keras 2.0.0
        from keras.engine.training import _slice_arrays
        tf_session = K.get_session()

        if self.validation_data and self.histogram_freq:
            if epoch % self.histogram_freq == 0:
                if self.model.uses_learning_phase:
                    cut_v_data = len(self.model.inputs)
                    val_data = self.validation_data[:cut_v_data] + [0]
                    tensors = self.model.inputs + [K.learning_phase()]
                else:
                    val_data = self.validation_data
                    tensors = self.model.inputs
                # Sample one batch of validation data to avoid OOM on the GPU
                if 'batch_size' in self.params:
                    index_array = np.arange(len(val_data[0]))
                    batch_ids = np.random.choice(index_array, self.params['batch_size'])
                    if self.model.uses_learning_phase:
                        # Keep the learning-phase flag appended above out of the slice
                        ins_batch = _slice_arrays(val_data[:-1], batch_ids) + [val_data[-1]]
                    else:
                        ins_batch = _slice_arrays(val_data, batch_ids)
                else:
                    # Generators yield one batch at a time and don't provide batch_size
                    ins_batch = val_data
                my_feed_dict = dict(zip(tensors, ins_batch))

                result = tf_session.run([self.merged], feed_dict=my_feed_dict)
                summary_str = result[0]
                self.writer.add_summary(summary_str, epoch)

        for name, value in logs.items():
            if name in ['batch', 'size']:
                continue
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.writer.add_summary(summary, epoch)
        self.writer.flush()

    def on_train_end(self, _):
        self.writer.close()
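
For completeness, here is a minimal sketch of how the callback might be wired into a training run. The model architecture, the random data, and all hyperparameters are placeholders, not part of the original code:

```
# Hypothetical usage sketch; model and data are placeholders.
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

x_train = np.random.rand(1000, 20)
y_train = np.random.randint(2, size=(1000, 1))
x_val = np.random.rand(200, 20)
y_val = np.random.randint(2, size=(200, 1))

model = Sequential([Dense(16, activation='relu', input_shape=(20,)),
                    Dense(1, activation='sigmoid')])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# histogram_freq=1 exercises the OOM-avoiding batch-sampling path every epoch
tb = TensorBoard(log_dir='./logs', histogram_freq=1)
model.fit(x_train, y_train,
          batch_size=32, epochs=5,
          validation_data=(x_val, y_val),
          callbacks=[tb])
```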

1 Answer:

Answer 0 (score: 4):

It seems that slice_X() no longer exists, but there is an internal function _slice_arrays() in keras.engine.training that does the same slicing job. See the code here.

If you have any further questions, don't hesitate to ask.

EDIT:

Here are the two functions. The old one:

def slice_X(X, start=None, stop=None):
    """This takes an array-like, or a list of
    array-likes, and outputs:
        - X[start:stop] if X is an array-like
        - [x[start:stop] for x in X] if X in a list
    Can also work on list/array of indices: `slice_X(x, indices)`
    # Arguments
        start: can be an integer index (start index)
            or a list/array of indices
        stop: integer (stop index); should be None if
            `start` was a list.
    """

The new one:

def _slice_arrays(arrays, start=None, stop=None):
    """Slice an array or list of arrays.
    This takes an array-like, or a list of
    array-likes, and outputs:
        - arrays[start:stop] if `arrays` is an array-like
        - [x[start:stop] for x in arrays] if `arrays` is a list
    Can also work on list/array of indices: `_slice_arrays(x, indices)`
    # Arguments
        arrays: Single array or list of arrays.
        start: can be an integer index (start index)
            or a list/array of indices
        stop: integer (stop index); should be None if
            `start` was a list.
    # Returns
        A slice of the array(s).
    """

The thing to understand here is that they essentially just renamed the function. To fix Vladimir's code, simply replace slice_X() with _slice_arrays(), keeping the same arguments, and change the import to

from keras.engine.training import _slice_arrays
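
Since _slice_arrays() is a private helper (the leading underscore signals it could move again in a future release), one option is a fallback import so the same code runs on both Keras 1.x and 2.0.0:

```
# Fallback import: Keras 2.0.0 provides _slice_arrays, Keras 1.x provides slice_X
try:
    from keras.engine.training import _slice_arrays as slice_X
except ImportError:
    from keras.engine.training import slice_X
```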

I hope it works now.