TensorFlow: linear regression with multiple inputs returns NaN

Date: 2017-04-12 22:13:16

Tags: python numpy tensorflow

This is my first attempt at TensorFlow: I'm building a linear regression model with multiple inputs.

The problem is that the result is always NaN. I suspect this is because I'm a complete newbie at matrix operations with numpy and tensorflow (matlab background, hehe).

Here is the code:

import numpy as np
import tensorflow as tf

N_INP = 2
N_OUT = 1

# Model params
w = tf.Variable(tf.zeros([1, N_INP]), name='w')
b = tf.Variable(tf.zeros([1, N_INP]), name='b')

# Model input and output
x = tf.placeholder(tf.float32, [None, N_INP], name='x')
y = tf.placeholder(tf.float32, [None, N_OUT], name='y')
linear_model = tf.reduce_sum(x * w + b, axis=1, name='out')

# Loss as sum(error^2)
loss = tf.reduce_sum(tf.square(linear_model - y), name='loss')

# Create optimizer
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss, name='train')

# Define training data
w_real = np.array([-1, 4])
b_real = np.array([1, -5])
x_train = np.array([[1, 2, 3, 4], [0, 0.5, 1, 1.5]]).T
y_train = np.sum(x_train * w_real + b_real, 1)[np.newaxis].T
print('Real X:\n', x_train)
print('Real Y:\n', y_train)

# Create session and init parameters
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Training loop
train_data = {x: x_train, y: y_train}
for i in range(1000):
    sess.run(train, train_data)

# Eval solution
w_est, b_est, curr_loss, y_pred = sess.run([w, b, loss, linear_model], train_data)
print("w: %s b: %s loss: %s" % (w_est, b_est, curr_loss))
print("y_pred: %s" % (y_pred,))

Here is the output:

Real X:
 [[ 1.   0. ]
 [ 2.   0.5]
 [ 3.   1. ]
 [ 4.   1.5]]
Real Y:
 [[-5.]
 [-4.]
 [-3.]
 [-2.]]

w: [[ nan  nan]] b: [[ nan  nan]] loss: nan
y_pred: [ nan  nan  nan  nan]

1 answer:

Answer 0 (score: 0)

You need to add keep_dims=True to the definition of linear_model. That is,

linear_model = tf.reduce_sum(x * w + b, axis=1, name='out', keep_dims=True)

The reason is that otherwise the result is "flattened" to shape [None] instead of [None, 1], so you cannot subtract y (shape [None, 1]) from it element-wise: broadcasting turns linear_model - y into an [N, N] matrix of every prediction minus every target instead of a column of residuals.
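To see why the missing keep_dims produces NaN rather than just a wrong fit, here is a minimal NumPy sketch (not TensorFlow itself) that replays the question's model with the same data, learning rate 0.01, 1000 steps and float32 values. The helper name `train`, its `keepdims` flag and the hand-derived gradients for loss = sum((pred - y)**2) are assumptions made for illustration. Without keepdims, each prediction is pulled toward all four targets, which roughly quadruples the curvature of the loss; by a rough estimate that makes 0.01 too large a step, so the weights oscillate with growing amplitude until they overflow to inf and then NaN.

```python
import numpy as np

# Training data from the question, stored as float32 to mirror tf.float32.
w_real = np.array([-1, 4], dtype=np.float32)
b_real = np.array([1, -5], dtype=np.float32)
x_train = np.array([[1, 2, 3, 4], [0, 0.5, 1, 1.5]], dtype=np.float32).T  # (4, 2)
y_train = np.sum(x_train * w_real + b_real, 1)[np.newaxis].T              # (4, 1)


def train(keepdims, lr=0.01, steps=1000):
    """Hypothetical helper: plain gradient descent on sum((pred - y)**2),
    mimicking the question's graph with hand-derived gradients."""
    w = np.zeros((1, 2), dtype=np.float32)
    b = np.zeros((1, 2), dtype=np.float32)
    # Silence overflow/invalid warnings so the diverging run reaches NaN quietly.
    with np.errstate(over="ignore", invalid="ignore"):
        for _ in range(steps):
            # Same role as tf.reduce_sum(x * w + b, axis=1, keep_dims=keepdims):
            # shape (4, 1) with keepdims, (4,) without.
            pred = np.sum(x_train * w + b, axis=1, keepdims=keepdims)
            # (4, 1) residuals, or a (4, 4) matrix err[j, i] = pred[i] - y[j].
            err = pred - y_train
            # d(loss)/d(pred[i]): one term per row with keepdims, four without.
            g = 2 * err[:, 0] if keepdims else 2 * err.sum(axis=0)
            # pred[i] = sum_k (x[i, k] * w[k] + b[k]), so by the chain rule:
            w = w - lr * (g @ x_train)[np.newaxis]
            b = b - lr * g.sum()
        pred = np.sum(x_train * w + b, axis=1, keepdims=keepdims)
        loss = np.sum(np.square(pred - y_train))
    return w, b, loss, pred


if __name__ == "__main__":
    flat = np.sum(x_train, axis=1)                 # (4,), like reduce_sum without keep_dims
    kept = np.sum(x_train, axis=1, keepdims=True)  # (4, 1)
    print((flat - y_train).shape)  # (4, 4)
    print((kept - y_train).shape)  # (4, 1)
    for keepdims in (False, True):
        w, b, loss, pred = train(keepdims)
        print("keepdims=%s w=%s b=%s loss=%s" % (keepdims, w, b, loss))
```

Even with the fix, the learned w and b need not equal w_real and b_real: in this training set the second input is exactly (x1 - 1) / 2, and the two entries of b only ever enter as a sum, so many parameter combinations reach zero loss. Newer TensorFlow releases spell the argument keepdims (keep_dims was deprecated).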