This is my code:
for i,val in enumerate(DS3Y_pred_trans):
    if val < 1.5:
        DS3Y_pred_trans[i] = 1
    else:
        DS3Y_pred_trans[i] = 2
The values in the list are less than 1.5, but the output is all 2s.
What am I missing?
Here is the whole code.
import numpy as np  # the code below calls np.genfromtxt, np.cov, np.dot, ...
DS3X_train = np.genfromtxt('train.csv', dtype=float, delimiter=',')
print DS3X_train
DS3Y_train = np.genfromtxt('train_labels.csv', dtype=int, delimiter=',' )
print DS3Y_train
DS3X_test = np.genfromtxt('test.csv', dtype=float, delimiter=',')
print DS3X_test
DS3Y_test = np.genfromtxt('test_labels.csv', dtype=int, delimiter=',' )
print DS3Y_test
DS3X_train_trans = zip(*DS3X_train)
cov_train = np.cov(DS3X_train_trans)
U, s, V = np.linalg.svd(cov_train, full_matrices=True)
u = U[:,:-1]
u_trans = zip(*u)
DS3X_train_reduced = np.dot(u_trans,DS3X_train_trans)
b = np.ones((3,2000))
b[1:,:] = DS3X_train_reduced
print "\n"
DS3X_train_reduced = b
DS3X_train_reduced_trans = zip(*DS3X_train_reduced)
temp = np.dot(DS3X_train_reduced,DS3X_train_reduced_trans)
try:
    inv_temp = np.linalg.inv(temp)
except np.linalg.LinAlgError:
    pass
else:
    psue_inv = np.dot(inv_temp,DS3X_train_reduced)
print psue_inv.shape
weight = np.dot(psue_inv,DS3Y_train)
weight_trans = zip(weight)
print weight_trans
DS3X_test_trans = zip(*DS3X_test)
DS3X_test_reduced = np.dot(u_trans,DS3X_test_trans)
b = np.ones((3,400))
b[1:,:] = DS3X_test_reduced
print "\n"
print b
DS3X_test_reduced = b
print DS3X_test_reduced.shape
DS3X_test_reduced_trans = zip(*DS3X_test_reduced)
DS3Y_pred = np.dot(DS3X_test_reduced_trans,weight_trans)
print DS3Y_pred
print DS3Y_pred.shape
DS3Y_pred_trans = zip(DS3Y_pred)
print repr(DS3Y_pred_trans[0])
for i,val in enumerate(DS3Y_pred_trans):
    if val < 1.5:
        DS3Y_pred_trans[i] = 1
    else:
        DS3Y_pred_trans[i] = 2
print DS3Y_pred
# now regression using indicator variables and graph plotting
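Answer 0 (score: 5)
Your values are not numbers. In Python 2, numbers sort before all other objects, so a non-numeric val is never less than 1.5 and the comparison val < 1.5 is always False.
You probably have strings: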
>>> '1.0' < 1.5
False
>>> 1.0 < 1.5
True
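The silent False shown above is Python 2 behavior. As a point of comparison, Python 3 refuses to order a str (or a tuple) against a float and raises TypeError, which makes this kind of bug visible immediately. A minimal sketch, written for Python 3; the helper name `is_below` is hypothetical and not part of the question's code:

```python
def is_below(val, threshold=1.5):
    """Return val < threshold, or None when the two cannot be ordered."""
    try:
        return val < threshold
    except TypeError:
        # Python 3 raises here for str-vs-float or tuple-vs-float comparisons.
        return None


# A real number, a string that looks like a number, and a one-element tuple.
for candidate in (1.0, '1.0', (1.0,)):
    print(repr(candidate), '->', is_below(candidate))
```

Under Python 3 only the real float is compared; the string and the tuple both come back as None, whereas Python 2 would quietly answer False for them.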
If so, convert the values to floats first:
for i, val in enumerate(DS3Y_pred_trans):
    if float(val) < 1.5:
        DS3Y_pred_trans[i] = 1
    else:
        DS3Y_pred_trans[i] = 2
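It may be that you are storing some other kind of object in the list; you need to look closely at what the list actually contains and adjust the code accordingly, or fix how the list is created in the first place.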
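One way to take that closer look is to print the type and repr of the first few elements. The sketch below uses made-up sample values rather than the question's data, and the helper `describe` is hypothetical. It also shows one thing that may be worth checking in the posted code: calling zip on a single sequence, as in zip(DS3Y_pred), wraps every element in a one-element tuple rather than transposing anything.

```python
def describe(values, limit=3):
    """Return (type name, repr) pairs for the first few elements."""
    return [(type(v).__name__, repr(v)) for v in values[:limit]]


sample_pred = [1.2, 1.7, 0.9]        # made-up stand-in for the predictions
wrapped = list(zip(sample_pred))     # zip with one argument yields 1-tuples

print(describe(sample_pred))
print(describe(wrapped))

# Unwrapping each 1-tuple recovers plain floats that compare as expected.
unwrapped = [item[0] for item in wrapped]
labels = [1 if val < 1.5 else 2 for val in unwrapped]
print(labels)
```

If the elements turn out to be tuples like these, float(val) will not work on them directly; unwrapping them, or not calling zip at all, is the simpler fix.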
Since you are replacing every value anyway, you could use a list comprehension:
DS3Y_pred_trans = [1 if float(val) < 1.5 else 2 for val in DS3Y_pred_trans]
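Since the predictions come out of np.dot as a NumPy array, another option is to threshold the array directly instead of looping over a list. A minimal sketch, assuming the predictions are a numeric array; `sample_pred` holds made-up values, not the question's data:

```python
import numpy as np

sample_pred = np.array([[1.2], [1.7], [0.9]])  # made-up column of predictions

# np.where picks 1 where the condition holds and 2 elsewhere, element by element.
labels = np.where(sample_pred < 1.5, 1, 2)

print(labels.ravel())
```

This keeps the values as numbers throughout, so the mixed-type comparison never comes up.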