我的csv数据在这里:https://storage.googleapis.com/download.tensorflow.org/data/abalone_train.csv 我想根据其他专栏预测“年龄”。培训代码在这里:
import pandas as pd
import numpy as np
# Make numpy values easier to read.
np.set_printoptions(precision=3, suppress=True)
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing
abalone_train = pd.read_csv("https://storage.googleapis.com/download.tensorflow.org/data/abalone_train.csv", header=None,
names=["Length", "Diameter", "Height", "Whole weight", "Shucked weight","Viscera weight", "Shell weight", "Age"])
abalone_train.head()
abalone_features = abalone_train.copy()
abalone_labels = abalone_features.pop('Age')
abalone_features = np.array(abalone_features)
abalone_features
abalone_model = tf.keras.Sequential([
layers.Dense(64),
layers.Dense(1)
])
abalone_model.compile(loss = tf.losses.MeanSquaredError(),optimizer = tf.optimizers.Adam())
abalone_model.fit(abalone_features, abalone_labels, epochs=10)
输出:
Epoch 1/10 104/104 [=============================]-0s 1ms / step- 损失:63.1474时代2/10 104/104 [=============================]-0s 924us / step-损失:11.8933纪元3/10 104/104 [=============================]-0s 920us / step-损失:8.4037时期 4/10 104/104 [==============================]-0s 885us / step-损耗: 7.9656时代5/10 104/104 [=============================]-0s 900us / step-损耗:7.5481时代6/10 104/104 [=============================]-0s 908us / step-损失:7.2339时代 7/10 104/104 [==============================]-0s 926us / step-损耗: 6.9871时代8/10 104/104 [=============================]-0s 919us / step-损失:6.7886时代9/10 104/104 [=============================]-0s 956us / step-损失:6.6482时期 10/10 104/104 [==============================]-0s 953us / step-损失: 6.5404
现在,我要上传另一个具有空白“年龄”列的csv文件,并查看预测,但我被卡住了。我上了一些课,但直到“时代”为止,所有的课。在“时代”阶段之后,如何导入“空白年龄” csv文件并查看“年龄预测”?
答案 0 :(得分:1)
根据文档(https://www.tensorflow.org/api_docs/python/tf/keras/Sequential#predict),Sequential
对象具有predict
方法。
输入数据可以是:
tf.data
数据集您可以使用abalone_model.predict(YourData)
,其中YourData
是上述数据类型之一。当然,您可以在自己的训练数据上使用predict()
,这可能会导致过拟合。尝试使用脱节的验证或测试集(如果提供)或拆分可用的数据集。
在这里,您可以找到一个很好的例子来解决诸如您面临的回归问题:https://www.tensorflow.org/tutorials/keras/regression