我正在尝试创建一个简单的NN,该文件在tfrecords文件夹中读取。每条记录都有一个1024值的“ mean_rgb”向量和一个类别标签。我正在尝试创建一个简单的前馈NN,以基于此特征向量来学习类别。
({'mean_rgb': <tf.Tensor: id=23, shape=(64, 1024), dtype=float32, numpy=
array([[ 0.9243997 , 0.28990048, -0.4130672 , ..., -0.096692 ,
0.27225342, 0.13346168],
[ 0.5853526 , 0.67050666, -0.24683481, ..., -0.6999033 ,
-0.4100128 , -0.00349384],
[ 0.49572858, 0.5231492 , -0.53445834, ..., 0.0449002 ,
0.10582132, -0.37333965],
...,
[ 0.5776026 , -0.07128889, -0.61762846, ..., 0.22194198,
0.61441416, -0.27355513],
[-0.01848815, 0.20132884, 1.1023484 , ..., 0.06496283,
0.29560333, 0.09157721],
[-0.25877073, -1.9552246 , 0.10309827, ..., 0.22032814,
-0.6812989 , -0.23649289]], dtype=float32)>}
但是,出现以下错误:
InvalidArgumentError(请参阅上面的回溯):重塑的输入是 张量具有65536个值,但请求的形状具有64
我不确定为什么它将输入视为64x1024 = 65536向量而不是(64,1024)向量。当我在生成器中打印下一个项目时,我得到
import UIKit
import CoreBluetooth
//let svcLight = CBUUID.itit(string: "24958294582945")
class BLEViewController: UIViewController , CBCentralManagerDelegate, CBPeripheralDelegate{
func centralManagerDidUpdateState(_ central: CBCentralManager) {
//scan for peripherals if "on" state change
if central.state == CBManagerState.poweredOn{
//not concerned with any services for now, and we are not passing in any options
central.scanForPeripherals(withServices: nil, options: nil)
print("scanning...")
//check for other states, add if else statements
}
}
//handle the callback for when something is found
//diddiscover was the autocomplete word for a peripheral
func centralManager(_ central: CBCentralManager, didDiscover peripheral: CBPeripheral, advertisementData: [String : Any], rssi RSSI: NSNumber) {
if peripheral.name?.contains("POR 1007BT") == true {
//even if the peripheral doesn't have a name, we will get some info for it
print (peripheral.name ?? "no name")
centralManager.stopScan()
print(advertisementData)
//connect to the peripheral now
central.connect(peripheral, options: nil)
//store a local copy of the peripheral in the property
myPeripheral = peripheral
}
}
//so our central can begin scanning again
//if peripheral disconnects for whatever reason, it will immidiately start scanning
func centralManager(_ central: CBCentralManager, didDisconnectPeripheral peripheral: CBPeripheral, error: Error?) {
central.scanForPeripherals(withServices: nil, options: nil)
}
//callback for connecting to a central, didconnect was the autocomplete word
//discovr services
func centralManager(_ central: CBCentralManager, didConnect peripheral: CBPeripheral) {
print("connected \(peripheral.name)")
peripheral.discoverServices(nil)
peripheral.delegate = self
}
//callback for diddiscover services: auto complete: diddiscover
//for each of the services in my peripheral, print out the UUID
//may not need this funciton, need to check
func peripheral(_ peripheral: CBPeripheral, didDiscoverServices error: Error?) {
//optionally binding it
if let services = peripheral.services {
for svc in services {
print(svc.uuid.uuidString)
}
}
}
//! is used to unwrap it so anywhere in the code, it doesn't need to be unwrapped as an optional
var centralManager : CBCentralManager!
//keep a reference/store our peripheral
var myPeripheral : CBPeripheral?
override func viewDidLoad() {
super.viewDidLoad()
// Do any additional setup after loading the view.
//create instance of CBManager
//pass in self as the delegate to handle any callbacks
centralManager = CBCentralManager.init(delegate: self, queue: nil)
}
具有正确的(64、1024)形状
答案 0 :(得分:0)
问题在于features_columns的工作方式,例如,我遇到了类似的问题,我通过重塑解决了问题,这是我的代码的一部分,可以帮助您理解:
定义features_column:
feature_columns = {
'images': tf.feature_column.numeric_column('images', self.shape),
}
然后为模型创建输入:
with tf.name_scope('input'):
feature_columns = list(self._features_columns().values())
input_layer = tf.feature_column.input_layer(
features=features, feature_columns=feature_columns)
input_layer = tf.reshape(
input_layer,
shape=(-1, self.parameters.size, self.parameters.size,
self.parameters.channels))
如果要注意我必须重塑张量的最后一部分,则-1是让Tensorflow找出批量大小
答案 1 :(得分:-1)
我认为问题在于feature_columns = [tf.feature_column.numeric_column(k) for k in ['mean_rgb']]
假定该列是标量-实际上它是1024向量。我必须将shape=1024
添加到numeric_column调用中。还必须删除现有检查点保存的模型。