以w a s d和相应的屏幕图像的组合形式收集游戏输入后,当我尝试平衡数据时,会出现一些问题。原始代码只有3个输入,只有w,a或d。我把它扩展到了9种可能性,比如aw,sd或者nokeys。平衡数据的一部分是具有相同长度的所有输入向量。但这是它似乎出错的地方。原始代码已注释掉。
平衡代码:
# balance_data.py
import numpy as np
import pandas as pd
from collections import Counter
from random import shuffle
import sys
train_data = np.load('training_data-1.npy')
df = pd.DataFrame(train_data)
print(df.head())
print(Counter(df[1].apply(str)))
##lefts = []
##rights = []
##forwards = []
##
##shuffle(train_data)
##
##for data in train_data:
## img = data[0]
## choice = data[1]
##
## if choice == [1,0,0]:
## lefts.append([img,choice])
## elif choice == [0,1,0]:
## forwards.append([img,choice])
## elif choice == [0,0,1]:
## rights.append([img,choice])
## else:
## print('no matches')
##
##
##forwards = forwards[:len(lefts)][:len(rights)]
##lefts = lefts[:len(forwards)]
##rights = rights[:len(forwards)]
##
##final_data = forwards + lefts + rights
##shuffle(final_data)
w = []
a = []
d = []
s = []
wa = []
wd = []
sd = []
sa = []
nk = []
shuffle(train_data)
for data in train_data:
img = data[0]
choice = data[1]
print(choice)
if choice == [0,1,0,0]:
w.append([img,choice])
elif choice == [1,0,0,0]:
a.append([img,choice])
elif choice == [0,0,1,0]:
d.append([img,choice])
elif choice == [0,0,0,1]:
s.append([img,choice])
elif choice == [1,1,0,0]:
wa.append([img,choice])
elif choice == [0,1,1,0]:
wd.append([img,choice])
elif choice == [0,0,1,1]:
sd.append([img,choice])
elif choice == [1,0,0,1]:
sa.append([img,choice])
elif choice == [0,0,0,0]:
nk.append([img,choice])
else:
print('no matches')
min_length = 10000
print (len(w))
print (len(a))
print (len(d))
print (len(s))
print (len(wa))
print (len(wd))
print (len(sd))
print (len(sa))
print (len(nk))
if len(w) < min_length:
min_length = len(w)
if len(a) < min_length:
min_length = len(a)
if len(d) < min_length:
min_length = len(d)
if len(s) < min_length:
min_length = len(s)
if len(wa) < min_length:
min_length = len(wa)
if len(wd) < min_length:
min_length = len(wd)
if len(sd) < min_length:
min_length = len(sd)
if len(sa) < min_length:
min_length = len(sa)
w = w[min_length]
a = a[min_length]
d = d[min_length]
s = s[min_length]
wa = wa[min_length]
wd = wd[min_length]
sd = sd[min_length]
sa = sa[min_length]
nk = nk[min_length]
final_data = w + a + d + s + wa + wd + sd + sa + nk
shuffle(final_data)
np.save('training_data-1-balanced.npy', final_data)
矢量长度和误差后。
9715
920
510
554
887
1069
132
128
6085
Traceback (most recent call last):
File "C:\Users\StefBrands\Documents\GitHub\pygta5 - Copy\balance_data.py", line 115, in <module>
sa = sa[min_length]
IndexError: list index out of range
所以现在主要有两件事: 我在某个地方犯了错误,可能是的:) 2.有更好的平衡方式吗?
答案 0 :(得分:1)
您没有考虑列表长度与其最大索引之间的差异 - 例如,列表[0, 5, 1]
的长度为3,但最大索引为2.因此,您应该减少计算min_length
乘以1。
我们可以大大提高计算效率。从if if len(w) < min_length...
到final_data = ...
的行可以替换为以下内容:
key_lists = (w, a, d, s, wa, wd, sd, sa, nk)
min_length = min(len(x)-1 for x in key_lists)
final_data = sum(x[min_length] for x in key_lists)
我们创建一个包含每个键的每个列表的元组。然后我们可以使用生成器表达式来查找min_length
,然后再次对值进行求和。这样做的好处是,如果添加了一个额外的键组合,我们可以将其列表变量附加到key_lists
。