我正在尝试使用weighted_cross_entropy_with_logits来执行具有四个功能的标准逻辑回归(pos_weight = 1)。似乎该代码仅对一项功能有效,而对多项功能无效。你能帮我找出我做错了什么吗?
# -*- coding: utf-8 -*-
import h5py
import random
import tensorflow as tf
import numpy as np
from sklearn.model_selection import KFold
from imblearn.over_sampling import SMOTE
import time
time_start=time.time()
from matplotlib import pyplot as plt
tf.reset_default_graph()
# Parametes
clss = 1 # clss = 1 for binary classification
batch_size = 32
seed = 1 # random seed (kept unchanged for reproducibility)
#eps_beta = np.finfo(float).eps # epsilon to avoid num over 0
#itermax = 10**(4) # the upper limit of iterations
training_epochs = 10000
learning_rate = 0.01
tf.set_random_seed(seed)
# Step_1: read in data from .mat file
data_path = "test10.mat"# read in data test10.mat
test = h5py.File(data_path)
x = test['x'][:].T
y = test['y'][:].T
x = x.astype(np.float32) # features (float32)
y = y.astype(np.float32) # label (float32)
n_samples, n_features = x.shape
# tf Graph Input
inputs = tf.placeholder(tf.float32, [None, n_features]) # X is n_features matrix
outputs = tf.placeholder(tf.float32, [None, clss]) # y is loanStatus
weights = tf.placeholder(tf.float32, [n_features, clss])
lamb = tf.placeholder(tf.float32,[])
# Set model weights
beta = tf.Variable(tf.truncated_normal([n_features, clss])) # beta (W) -- decision variables
beta0 = tf.Variable(tf.truncated_normal([clss])) # beta0 (b) -- intercept term
pred = tf.matmul(inputs,beta)+beta0
# cost function
cost_fun = tf.nn.weighted_cross_entropy_with_logits(
targets = outputs,
logits = pred,
pos_weight = 1
)
# Using Adam with learning rate lr= to minimize cost_fun
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost_fun)# should be smaller when itermax increase
#initialize the necessary variables, in this case, beta and beta0
init= tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
for epoch in range(training_epochs):
idx = random.sample(range(x.shape[0]),batch_size)
xs = x[idx,:]
ys = y[idx]
_, c = sess.run([optimizer, cost_fun],feed_dict={inputs:xs,outputs:ys})
beta0Value = sess.run(beta0)
betaValue = sess.run(beta)
print("beta", betaValue)
print("beta0", beta0Value)
数据已包含在https://drive.google.com/open?id=1WitKP4q_9VnyMohbl3BfQUJlDYXqKoVs
中