I have a code snippet that trains a TensorFlow model in batches. getChunk returns BATCH_SIZE lines of training data, and returns False at the end of the file. I expect the log entries to be sequential: for example, with a total dataset size of 10 and a batch size of 5, I expect the log file to read 1,1,2,2,3,3, and so on. However, the entries are interleaved, and I see orderings like 1,2,3,1,3,2... What am I doing wrong?
for i in range(1, N_EPOCHS + 1):
    gen = getLine(TMP_DATA_PATH + "/" + "data_train.txt")
    tmp_X, tmp_Y = getChunk(gen)
    # compare with "is not": "tmp_X != False" on a NumPy array is
    # element-wise and raises an ambiguous-truth-value error
    while tmp_X is not False:
        sess.run(optimizer, feed_dict={X: tmp_X, Y: tmp_Y})
        _, acc_train, loss_train = sess.run([pred_softmax, accuracy, loss],
                                            feed_dict={X: tmp_X, Y: tmp_Y})
        logEntry(str(i) + " " + str(loss_train) + " " + str(acc_train))
        tmp_X, tmp_Y = getChunk(gen)
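To make the expected ordering concrete, here is a minimal sketch (no TensorFlow, with the dataset size and batch size from the example above as assumed stand-ins) of what a sequential log should contain: one entry per batch, grouped by epoch.

```python
# Minimal sketch of the intended log order: with 10 samples and a
# batch size of 5, each epoch yields 2 batches, so a sequential log
# should read 1,1,2,2,3,3,...
N_EPOCHS = 3
DATA_SIZE = 10
BATCH_SIZE = 5

log = []
for epoch in range(1, N_EPOCHS + 1):
    for _ in range(DATA_SIZE // BATCH_SIZE):
        log.append(epoch)

print(log)  # → [1, 1, 2, 2, 3, 3]
```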
data_train.txt is a CSV-style file in which each line is a series of vectors followed by a one-hot encoded label:
[[1,2,3,4],[-1,2,3,4],........[-5,2,3,1]],[0,0,0,1]
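As an aside, lines in this format can be parsed without `eval`. The sketch below uses a hypothetical shortened sample line and `ast.literal_eval`, which only accepts Python literals and is a safer drop-in for literal data:

```python
import ast

# Hypothetical shortened sample line in the data_train.txt format above
line = "[[1,2,3,4],[-1,2,3,4],[-5,2,3,1]],[0,0,0,1]"

# Wrapping the whole line in brackets turns it into a single list
# literal: [vectors, label]
vectors, label = ast.literal_eval("[" + line + "]")
print(vectors)  # → [[1, 2, 3, 4], [-1, 2, 3, 4], [-5, 2, 3, 1]]
print(label)    # → [0, 0, 0, 1]
```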
def getLine(fileName):
    with open(fileName) as file:
        for i in file:
            yield i.rstrip()
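Note that this generator reads the file lazily, and each call to getLine builds a fresh generator that restarts from the first line — which is why the training loop creates a new one per epoch. A quick self-contained demo (the temp file and its contents are stand-ins):

```python
import os
import tempfile

def getLine(fileName):
    with open(fileName) as file:
        for i in file:
            yield i.rstrip()

# Write a tiny stand-in data file
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("first\nsecond\n")
    path = f.name

first_pass = list(getLine(path))   # consumes the whole file
fresh = next(getLine(path))        # a new generator starts over
os.remove(path)

print(first_pass, fresh)  # → ['first', 'second'] first
```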
def getChunk(genLine):
    try:
        chunk = [next(genLine) for i in range(BATCH_SIZE)]
        assert len(chunk) == BATCH_SIZE
        X = []
        Y = []
        for c in chunk:
            y = c[-30:]   # trailing characters hold the one-hot label
            x = c[:-30]
            x_list = eval(x)
            y = eval(y)
            X.append(x_list)
            Y.append(y)
        X = np.asarray(X)
        return X, Y
    # catch only end-of-data conditions; a bare "except:" would also
    # silently swallow genuine parsing bugs
    except (StopIteration, AssertionError):
        return False, False
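The chunking idiom above, stripped to its core: pulling BATCH_SIZE items with `next()` inside a list comprehension raises StopIteration when the source runs dry, which the except clause converts into a False sentinel. A minimal sketch with a hypothetical `take` helper:

```python
# Pull n items from a generator; signal exhaustion with False,
# mirroring getChunk's (False, False) sentinel.
def take(gen, n):
    try:
        return [next(gen) for _ in range(n)]
    except StopIteration:
        return False

gen = iter(["a", "b", "c"])
full = take(gen, 2)
short = take(gen, 2)   # only one item left, so the batch fails

print(full)   # → ['a', 'b']
print(short)  # → False
```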
def logEntry(TMP_STRING):
    LOG_FILE.write(TMP_STRING)
    LOG_FILE.write("\n")
    LOG_FILE.flush()
    print(TMP_STRING)
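For reference, logEntry appends one newline-terminated record per call and flushes immediately, so within a single process the file order matches the call order. A demo using an in-memory `io.StringIO` as a stand-in for the real LOG_FILE handle:

```python
import io

LOG_FILE = io.StringIO()  # stand-in for the real open log file

def logEntry(TMP_STRING):
    LOG_FILE.write(TMP_STRING)
    LOG_FILE.write("\n")
    LOG_FILE.flush()
    print(TMP_STRING)

logEntry("1 0.5 0.9")
logEntry("1 0.4 0.92")

print(repr(LOG_FILE.getvalue()))  # → '1 0.5 0.9\n1 0.4 0.92\n'
```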