第一个模块minibatch。
import numpy as np
import tensorflow as tf
import Utils.neighbor_samplers as samplers
class Minibatch:
def __init__(self, embedding, batch_size):
self.embedding = embedding
self.batch_idx = [i for i in range(self.embedding.shape[0])]
self.batch_size = batch_size
self.iters = 0
self.batch_permutation = np.random.permutation(self.batch_idx) #List
self.next_batch = []
self.max_iters = embedding.shape[0] // self.batch_size
def next_batch(self):
self.start_idx = self.iters * self.batch_size
self.iters += 1
self.end_idx = self.start_idx + self.batch_size
#next_batch = tf.nn.embedding_lookup(embedding, [i for i in range(start_idx, end_idx)])
self.next_idx = self.batch_permutation[self.start_idx : self.end_idx]
#next_neighbors = tf.nn.embedding_lookup(self.embedding, next_idx)
def shuffle(self):
self.batch_permutation = np.random.permutation(self.batch_idx)
self.batch_num = 0
和第二个模块模型。
def train(self):
batch = minibatch.Minibatch(self.normal_embedding, self.batch_size)
for epoch in range(self.epoch):
batch.shuffle()
print('Epoch : %04d' %(epoch + 1))
for iter in range(batch.max_iters):
if iter % 100 == 0 and iter != 0:
print('%d iters done' %(iter))
next_idx = batch.next_batch
使用最后一行next_idx = batch.next_batch,我想让下一批索引在嵌入中查找。但它一直显示空列表。