结合使用TF Estimator和TFRecord生成器

时间:2018-08-24 18:57:57

标签: tensorflow tensorflow-datasets tensorflow-estimator

我正在尝试创建一个简单的NN,该文件在tfrecords文件夹中读取。每条记录都有一个1024值的“ mean_rgb”向量和一个类别标签。我正在尝试创建一个简单的前馈NN,以基于此特征向量来学习类别。

({'mean_rgb': <tf.Tensor: id=23, shape=(64, 1024), dtype=float32, numpy=
array([[ 0.9243997 ,  0.28990048, -0.4130672 , ..., -0.096692  ,
         0.27225342,  0.13346168],
       [ 0.5853526 ,  0.67050666, -0.24683481, ..., -0.6999033 ,
        -0.4100128 , -0.00349384],
       [ 0.49572858,  0.5231492 , -0.53445834, ...,  0.0449002 ,
         0.10582132, -0.37333965],
       ...,
       [ 0.5776026 , -0.07128889, -0.61762846, ...,  0.22194198,
         0.61441416, -0.27355513],
       [-0.01848815,  0.20132884,  1.1023484 , ...,  0.06496283,
         0.29560333,  0.09157721],
       [-0.25877073, -1.9552246 ,  0.10309827, ...,  0.22032814,
        -0.6812989 , -0.23649289]], dtype=float32)>}

但是,出现以下错误:

  

InvalidArgumentError(请参阅上面的回溯):重塑的输入是   张量具有65536个值,但请求的形状具有64

我不确定为什么它将输入视为64x1024 = 65536向量而不是(64,1024)向量。当我在生成器中打印下一个项目时,我得到

import UIKit
import CoreBluetooth

//let  svcLight = CBUUID.itit(string: "24958294582945")

class BLEViewController: UIViewController , CBCentralManagerDelegate, CBPeripheralDelegate{
    func centralManagerDidUpdateState(_ central: CBCentralManager) {
        //scan for peripherals if "on" state change
        if central.state == CBManagerState.poweredOn{
            //not concerned with any services for now, and we are not passing in any options
            central.scanForPeripherals(withServices: nil, options: nil)
            print("scanning...")
            //check for other states, add if else statements
        }
    }

    //handle the callback for when something is found
    //diddiscover was the autocomplete word for a peripheral
    func centralManager(_ central: CBCentralManager, didDiscover peripheral: CBPeripheral, advertisementData: [String : Any], rssi RSSI: NSNumber) {

        if peripheral.name?.contains("POR 1007BT") == true {
        //even if the peripheral doesn't have a name, we will get some info for it
            print (peripheral.name ?? "no name")
            centralManager.stopScan()
            print(advertisementData)
            //connect to the peripheral now
            central.connect(peripheral, options: nil)
            //store a local copy of the peripheral in the property
            myPeripheral = peripheral
    }
}
    //so our central can begin scanning again
    //if peripheral disconnects for whatever reason, it will immidiately start scanning
    func centralManager(_ central: CBCentralManager, didDisconnectPeripheral peripheral: CBPeripheral, error: Error?) {
        central.scanForPeripherals(withServices: nil, options: nil)
    }

    //callback for connecting to a central, didconnect was the autocomplete word
    //discovr services
    func centralManager(_ central: CBCentralManager, didConnect peripheral: CBPeripheral) {
        print("connected \(peripheral.name)")
        peripheral.discoverServices(nil)
        peripheral.delegate = self

    }

    //callback for diddiscover services: auto complete: diddiscover
    //for each of the services in my peripheral, print out the UUID
    //may not need this funciton, need to check

    func peripheral(_ peripheral: CBPeripheral, didDiscoverServices error: Error?) {
        //optionally binding it
        if let services  = peripheral.services {
            for svc in services {
                print(svc.uuid.uuidString)
            }
        }
    }


    //! is used to unwrap it so anywhere in the code, it doesn't need to be unwrapped as an optional
    var centralManager : CBCentralManager!
    //keep a reference/store our peripheral
    var myPeripheral : CBPeripheral?

    override func viewDidLoad() {
        super.viewDidLoad()

        // Do any additional setup after loading the view.
   //create instance of CBManager
        //pass in self as the delegate to handle any callbacks
        centralManager = CBCentralManager.init(delegate: self, queue: nil)


    }

具有正确的(64、1024)形状

2 个答案:

答案 0 :(得分:0)

问题在于features_columns的工作方式,例如,我遇到了类似的问题,我通过重塑解决了问题,这是我的代码的一部分,可以帮助您理解:

定义features_column:

feature_columns = {
        'images': tf.feature_column.numeric_column('images', self.shape),
    }

然后为模型创建输入:

        with tf.name_scope('input'):
            feature_columns = list(self._features_columns().values())
            input_layer = tf.feature_column.input_layer(
                features=features, feature_columns=feature_columns)

            input_layer = tf.reshape(
                input_layer,
                shape=(-1, self.parameters.size, self.parameters.size,
                       self.parameters.channels))

如果要注意我必须重塑张量的最后一部分,则-1是让Tensorflow找出批量大小

答案 1 :(得分:-1)

我认为问题在于feature_columns = [tf.feature_column.numeric_column(k) for k in ['mean_rgb']]假定该列是标量-实际上它是1024向量。我必须将shape=1024添加到numeric_column调用中。还必须删除现有检查点保存的模型。