TensorFlow batch training

Time: 2017-08-05 03:14:27

Tags: tensorflow

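I have a code snippet that trains a TensorFlow model in batches. getChunk returns BATCH_SIZE rows of training data and returns False at the end of the file. I expect the log entries to be in order: for example, if the total dataset size is 10 and the batch size is 5, I expect the log file to read 1,1,2,2,3,3 and so on. However, the entries are mixed up, and I see orders like 1,2,3,1,3,2... What am I doing wrong?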

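for i in range(1, N_EPOCHS + 1):
    # fresh line generator for each epoch
    gen = getLine(TMP_DATA_PATH + "/" + "data_train.txt")
    tmp_X, tmp_Y = getChunk(gen)

    # getChunk returns (False, False) at end of file; test identity, because
    # "array != False" is elementwise and its truth value in a while is ambiguous
    while tmp_X is not False:
        sess.run(optimizer, feed_dict={X: tmp_X, Y: tmp_Y})

        _, acc_train, loss_train = sess.run([pred_softmax, accuracy, loss], feed_dict={X: tmp_X, Y: tmp_Y})

        logEntry(str(i) + " " + str(loss_train) + " " + str(acc_train))
        tmp_X, tmp_Y = getChunk(gen)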
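To check what order the loop structure itself should produce, here is a minimal TensorFlow-free sketch. The session calls are replaced by a hypothetical fake_train_step, the file by a hypothetical in-memory list of 10 rows, and batches are cut with itertools.islice. It also tests the end-of-data sentinel by identity (is not None), since comparing a NumPy array with != False is elementwise and NumPy raises ValueError when such an array is used as a while condition.

```python
from itertools import islice

N_EPOCHS = 3
BATCH_SIZE = 5
ROWS = list(range(10))  # hypothetical stand-in for the 10 lines of data_train.txt


def get_chunk(gen, batch_size):
    """Return the next full batch, or None when fewer than batch_size rows remain."""
    chunk = list(islice(gen, batch_size))
    return chunk if len(chunk) == batch_size else None


def fake_train_step(batch):
    """Hypothetical stand-in for the two sess.run calls: returns (loss, accuracy)."""
    return float(sum(batch)), 0.0


def run():
    log = []
    for epoch in range(1, N_EPOCHS + 1):
        gen = iter(ROWS)  # a fresh iterator per epoch, like getLine(...)
        batch = get_chunk(gen, BATCH_SIZE)
        while batch is not None:  # identity test, not an elementwise comparison
            loss, acc = fake_train_step(batch)
            log.append(f"{epoch} {loss} {acc}")
            batch = get_chunk(gen, BATCH_SIZE)
    return log


if __name__ == "__main__":
    for line in run():
        print(line)
```

With 10 rows and a batch size of 5, each epoch logs exactly two entries, so a single run of this loop writes the epoch column as 1,1,2,2,3,3.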

data_train.txt is a CSV file in which each line holds a sequence of vectors followed by a one-hot encoded label:

[[1,2,3,4],[-1,2,3,4],........[-5,2,3,1]],[0,0,0,1]
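Because a line in this layout is itself a valid Python literal (a list of vectors, a comma, then a list), one option is to parse the whole line at once with ast.literal_eval instead of slicing off a fixed 30 characters. This is a sketch that assumes every line has exactly that two-part layout; parse_line is a hypothetical helper, not part of the original code.

```python
import ast


def parse_line(line):
    """Split one line of data_train.txt into (sequence_of_vectors, one_hot_label).

    Assumes the layout "[[...],[...],...],[0,0,...,1]": ast.literal_eval reads the
    two comma-separated lists as a 2-tuple and, unlike eval, accepts only literals.
    """
    x, y = ast.literal_eval(line.strip())
    return x, y


if __name__ == "__main__":
    x, y = parse_line("[[1,2,3,4],[-1,2,3,4],[-5,2,3,1]],[0,0,0,1]")
    print(x)  # [[1, 2, 3, 4], [-1, 2, 3, 4], [-5, 2, 3, 1]]
    print(y)  # [0, 0, 0, 1]
```

This does not depend on the label's width, which changes with the number of classes.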

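import ast
import numpy as np

def getLine(fileName):
    with open(fileName) as file:
        for i in file:
            yield i.rstrip()

def getChunk(genLine):
    # Take the next BATCH_SIZE lines; next() raises StopIteration at end of file
    # (a final batch shorter than BATCH_SIZE is discarded)
    try:
        chunk = [next(genLine) for i in range(BATCH_SIZE)]
    except StopIteration:
        # only end of file returns (False, False); parse errors are no longer hidden
        return False, False

    X = []
    Y = []
    for c in chunk:
        # assumes the last 30 characters are the one-hot label, the rest the vectors
        y = c[-30:]
        x = c[:-30]

        # literal_eval parses Python literals only, unlike eval
        X.append(ast.literal_eval(x))
        Y.append(ast.literal_eval(y))

    X = np.asarray(X)
    return X, Y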
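An alternative structure for getLine/getChunk, sketched under the same file-format assumption: a single generator, iter_batches (a hypothetical name), yields (X, Y) NumPy arrays and simply stops at end of file, so the training loop can be a plain for loop with no sentinel value. It also keeps a final short batch by default, whereas the original list comprehension discards it when the line count is not a multiple of BATCH_SIZE; drop_last=True restores that behavior.

```python
import ast
from itertools import islice

import numpy as np


def iter_batches(lines, batch_size, drop_last=False):
    """Yield (X, Y) NumPy arrays of up to batch_size rows from an iterable of lines.

    Each line is assumed to look like "[[...],...,[...]],[0,...,1]".
    """
    it = iter(lines)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk or (drop_last and len(chunk) < batch_size):
            return  # end of data: the generator simply stops
        pairs = [ast.literal_eval(line.strip()) for line in chunk]
        X = np.asarray([x for x, _ in pairs])
        Y = np.asarray([y for _, y in pairs])
        yield X, Y


if __name__ == "__main__":
    rows = ["[[1,2],[3,4]],[0,1]"] * 7  # hypothetical sample lines
    for X, Y in iter_batches(rows, batch_size=5):
        print(X.shape, Y.shape)
```

With a real file, the open file object can be passed directly: with open(path) as f: for X, Y in iter_batches(f, BATCH_SIZE): ...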

def logEntry(TMP_STRING):
    LOG_FILE.write(TMP_STRING)
    LOG_FILE.write("\n")
    LOG_FILE.flush()
    print(TMP_STRING)
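One possible source of out-of-order lines, not confirmed by the question, is more than one run or process appending to the same LOG_FILE, since their lines then interleave even if each run is sequential. As a hedged diagnostic sketch, a hypothetical log_entry variant can tag each line with the writing process's ID and a batch counter so interleaving becomes visible; its parameters are assumptions, not part of the original code.

```python
import io
import os


def log_entry(log_file, epoch, batch, loss, acc):
    """Write one tab-separated line tagged with the writing process's PID, then flush."""
    line = f"{os.getpid()}\t{epoch}\t{batch}\t{loss}\t{acc}"
    log_file.write(line + "\n")
    log_file.flush()  # push the line out immediately, as the original logEntry does
    print(line)
    return line


if __name__ == "__main__":
    buf = io.StringIO()  # stands in for LOG_FILE
    log_entry(buf, 1, 0, 0.69, 0.5)
```

Grouping the resulting file by its first column then shows whether all lines come from a single process.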

0 answers