我对这里的 skip_window 感到困惑：显然，它并不等于 batch_size。另外，我不明白第二个（嵌套的）for 循环的作用：既然 skip_window 是一个标量，嵌套的 for 循环是如何把 target 放进 targets_to_avoid 列表中的？有人能概括一下这个函数的功能吗？
from collections import deque
def generate_batch(batch_size, num_skips, skip_window):
    """Generate one skip-gram training batch from the module-level ``data``.

    A sliding window of ``span = 2 * skip_window + 1`` consecutive entries of
    ``data`` is kept in ``buffer``.  The entry at the centre of the window
    (buffer index ``skip_window``) is the input word; the other
    ``2 * skip_window`` entries are its context words.  For every window
    position, ``num_skips`` distinct context positions are drawn at random and
    each (centre, context) pair becomes one row of the batch.  The window then
    slides forward by one entry of ``data``.

    Args:
        batch_size: Number of (input, label) pairs to produce.  Must be a
            multiple of ``num_skips``.
        num_skips: Number of pairs generated per centre word, i.e. how many
            distinct context positions are sampled per window.  Must satisfy
            ``1 <= num_skips <= 2 * skip_window``.
        skip_window: Number of entries considered on each side of the centre
            word.  It is a window radius (and the buffer index of the centre
            word), so it is unrelated to ``batch_size``.

    Returns:
        ``(batch, labels)``: ``batch`` is an int32 array of shape
        ``(batch_size,)`` holding centre values taken from ``data``;
        ``labels`` is an int32 array of shape ``(batch_size, 1)`` holding the
        matching context values.

    Raises:
        ValueError: if the size constraints on ``batch_size``, ``num_skips``
            and ``skip_window`` are violated.

    Side effects:
        Reads the module-level sequence ``data`` and advances the module-level
        cursor ``data_index`` so the next call continues with the first centre
        word this call did not use (wrapping around the end of ``data``).
    """
    global data_index
    if num_skips <= 0:
        raise ValueError("num_skips must be positive")
    if batch_size % num_skips != 0:
        raise ValueError("batch_size must be a multiple of num_skips")
    if num_skips > 2 * skip_window:
        raise ValueError("num_skips must not exceed 2 * skip_window")

    # NOTE(review): assumes data holds integer ids that fit in int32 — confirm.
    batch = np.empty(shape=(batch_size,), dtype=np.int32)
    labels = np.empty(shape=(batch_size, 1), dtype=np.int32)
    span = 2 * skip_window + 1  # [skip_window left][centre][skip_window right]
    # With maxlen set, appending to a full deque drops the oldest entry,
    # which is what makes it a sliding window.
    buffer = deque(maxlen=span)
    for _ in range(span):
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)

    for i in range(batch_size // num_skips):
        # skip_window is used here as an *index* into buffer: the centre slot.
        # targets_to_avoid lists buffer positions already used for this centre
        # word (starting with the centre itself), so the num_skips labels
        # drawn below are all different context words.
        target = skip_window
        targets_to_avoid = [skip_window]
        for j in range(num_skips):
            # Rejection sampling over buffer positions; terminates because at
            # least num_skips <= 2 * skip_window unused positions exist.
            while target in targets_to_avoid:
                target = np.random.randint(0, span)
            targets_to_avoid.append(target)
            batch[i * num_skips + j] = buffer[skip_window]
            labels[i * num_skips + j, 0] = buffer[target]
        # Slide the window one entry to the right.
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)

    # The final append loaded a window whose centre has not been emitted yet.
    # Rewind the cursor by span so the next call rebuilds that window instead
    # of silently skipping span entries of data.
    data_index = (data_index + len(data) - span) % len(data)
    return batch, labels