Tensorflow 2坐标分类器

时间:2019-05-16 13:32:49

标签: python tensorflow machine-learning keras classification

我是尝试机器学习的新手。我看到了这个仓库https://github.com/jbp261/Optimal-Classification-Model-of-BLE-RSSI-Dataset,并想复制一个类似的实验。

因此,我有2个接收器,并且想要对Rssi的给定值最接近的接收器进行分类。我捕获了一些训练数据,并定义了区域0(信标1附近)和区域1(信标2附近)。

我用keras建立了一个模型(也可以使用RandomForest尝试,但效果很好),但是即使以0.8的准确度评估基础训练数据,我也会得到50%的错误预测。

batch_size = 100

#reading the input samples and separating the input from the outputs
dataframe = pd.read_csv("C:\aaa\Log.csv")
labels = dataframe.pop('result')

#creating the dataset from the data
ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
ds = ds.batch(batch_size)

feature_columns = []
headers = dataframe.columns.tolist()

# numeric cols
for header in headers:
  temp = feature_column.numeric_column(header)
  #feature_columns.append(feature_column.bucketized_column(temp, boundaries=[-70, -60, -50, -40 , -30])) tried also this
  feature_columns.append(temp)

feature_layer = tf.keras.layers.DenseFeatures(feature_columns)

model = tf.keras.Sequential([
  feature_layer,
  layers.Dense(128, activation='relu'),
  layers.Dense(128, activation='relu'),
  layers.Dense(2, activation='sigmoid')
])

model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

model.fit(ds, epochs=20)


test_ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
test_ds = test_ds.batch(batch_size)

loss, accuracy = model.evaluate(test_ds)
print("Accuracy", accuracy)

2 个答案:

答案 0 :(得分:2)

model.fit()中的

添加一些验证(简单的方法是validation_split=0.5或您想要拆分的任何百分比。)这将获取您的一些数据,将其与训练数据分开,并且仅在时期结束后使用来查看网络如何处理从未有过的数据。这样,您将看到丢失,准确性, validation_loss和validation_accuracy。后两者更好地反映了模型在实际使用中的性能。

一旦开始使用该指标,就可以查看您是否过拟合,或者您对网络所做的更改实际上是否有帮助。

答案 1 :(得分:0)

我认为您希望在回归值内获得2个输出。

请尝试使用relu作为激​​活,mean_squared_error作为丢失。

model = tf.keras.Sequential([
  feature_layer,
  layers.Dense(128, activation='relu'),
  layers.Dense(128, activation='relu'),
  layers.Dense(2, activation='relu')
])

model.compile(optimizer='adam',
              loss='mean_squared_error',
              metrics=['accuracy'])