Generating batches for a word encoder

Date: 2019-05-22 15:26:31

Tags: python-3.x machine-learning

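I am confused about skip_window in this function. It is clearly not the same thing as batch_size. I also don't understand what the second for loop does: if skip_window is a scalar, how can target end up in targets_to_avoid the way the nested for loop implies? Could someone summarize how this function works?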

from collections import deque

import numpy as np

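def generate_batch(batch_size, num_skips, skip_window):
    # `data` (sequence of word ids) and `data_index` (read cursor into it) are
    # module-level globals that this snippet assumes are defined elsewhere
    global data_index
    # these conditions must hold true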
    assert batch_size % num_skips == 0
    assert num_skips <= 2 * skip_window
    batch = np.ndarray(shape=[batch_size], dtype=np.int32)
    labels = np.ndarray(shape=[batch_size, 1], dtype=np.int32)
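    span = 2 * skip_window + 1  # skip_window words on each side plus the center word
    buffer = deque(maxlen=span)  # sliding window; appending drops the oldest word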
    for _ in range(span):
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)
    for i in range(batch_size // num_skips):
        target = skip_window  # buffer position of the center word
        targets_to_avoid = [skip_window]  # buffer positions already used, center included
        for j in range(num_skips):
            while target in targets_to_avoid:
                target = np.random.randint(0, span)
            targets_to_avoid.append(target)
            batch[i * num_skips + j] = buffer[skip_window]
            labels[i * num_skips + j, 0] = buffer[target]
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)
    return batch, labels
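
To make the question concrete, here is a minimal sketch of the same logic run on a toy sequence. It is not the original code: generate_batch_demo, its data/start/rng parameters, and the toy data are assumptions made for illustration, so that nothing depends on globals. Reading the function above, skip_window is the number of context words on each side of the center word, and it doubles as the index of the center position inside buffer (positions 0 .. span-1). targets_to_avoid holds buffer positions, not words: it starts with the center position, and the inner loop keeps drawing a random position until it finds one not yet used, so each center word is paired with num_skips distinct context words. batch_size is just the total number of (center, context) pairs, which is why the outer loop runs batch_size // num_skips times.

```python
from collections import deque

import numpy as np


def generate_batch_demo(data, start, batch_size, num_skips, skip_window, rng):
    """Hypothetical variant of generate_batch that takes its state as arguments."""
    assert batch_size % num_skips == 0
    assert num_skips <= 2 * skip_window
    span = 2 * skip_window + 1  # left context + center + right context
    batch = np.empty(batch_size, dtype=np.int32)
    labels = np.empty((batch_size, 1), dtype=np.int32)
    index = start
    buffer = deque(maxlen=span)
    # fill the first window
    for _ in range(span):
        buffer.append(data[index])
        index = (index + 1) % len(data)
    for i in range(batch_size // num_skips):
        # positions inside the window that may not be picked (again);
        # the center is listed first so a word is never paired with itself
        used = [skip_window]
        for j in range(num_skips):
            pos = skip_window
            while pos in used:
                pos = int(rng.integers(0, span))  # random position in [0, span)
            used.append(pos)
            batch[i * num_skips + j] = buffer[skip_window]  # center word
            labels[i * num_skips + j, 0] = buffer[pos]      # one context word
        # slide the window one word to the right
        buffer.append(data[index])
        index = (index + 1) % len(data)
    return batch, labels, index


data = list(range(8))  # toy "word ids" 0..7
rng = np.random.default_rng(0)
batch, labels, next_index = generate_batch_demo(
    data, 0, batch_size=8, num_skips=2, skip_window=1, rng=rng
)
pairs = list(zip(batch.tolist(), labels[:, 0].tolist()))
print(pairs)
```

With skip_window=1 and num_skips=2 the window is three words wide and both neighbours of every center word are used, so only the order within each pair of labels depends on the random draws. With a larger skip_window and a smaller num_skips, the inner loop would instead sample a random subset of the neighbours.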

0 answers:

No answers yet