无法获得下一批嵌入索引

时间:2018-06-18 06:01:09

标签: python class tensorflow

第一个模块minibatch。

import numpy as np
import tensorflow as tf
import Utils.neighbor_samplers as samplers

class Minibatch:

def __init__(self, embedding, batch_size):

    self.embedding = embedding
    self.batch_idx = [i for i in range(self.embedding.shape[0])]
    self.batch_size = batch_size
    self.iters = 0

    self.batch_permutation = np.random.permutation(self.batch_idx) #List
    self.next_batch = []
    self.max_iters = embedding.shape[0] // self.batch_size


def next_batch(self):
    self.start_idx = self.iters * self.batch_size
    self.iters += 1
    self.end_idx = self.start_idx + self.batch_size

    #next_batch = tf.nn.embedding_lookup(embedding, [i for i in range(start_idx, end_idx)])
    self.next_idx = self.batch_permutation[self.start_idx : self.end_idx]
    #next_neighbors = tf.nn.embedding_lookup(self.embedding, next_idx)



def shuffle(self):

    self.batch_permutation = np.random.permutation(self.batch_idx)
    self.batch_num = 0

和第二个模块模型。

    def train(self):
    batch = minibatch.Minibatch(self.normal_embedding, self.batch_size)

    for epoch in range(self.epoch):
        batch.shuffle()

        print('Epoch : %04d' %(epoch + 1))

        for iter in range(batch.max_iters):

            if iter % 100 == 0 and iter != 0:
                print('%d iters done' %(iter))


            next_idx = batch.next_batch

使用最后一行next_idx = batch.next_batch,我想让下一批索引在嵌入中查找。但它一直显示空列表。

0 个答案:

没有答案