我当前在我的iOS应用中使用Tensorflow的Swift版本。 我的模型工作正常,但是将数据复制到第一个Tensor时遇到了问题,因此我可以使用神经网络检测事物。
我咨询了the testsuite inside the repository,他们的代码工作如下:
他们正在使用一些扩展名:
extension Array {
/// Creates a new array from the bytes of the given unsafe data.
///
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
/// `MemoryLayout<Element>.stride`.
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
init?(unsafeData: Data) {
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
let elements = unsafeData.withUnsafeBytes {
UnsafeBufferPointer<Element>(
start: $0,
count: unsafeData.count / MemoryLayout<Element>.stride
)
}
self.init(elements)
}
}
extension Data {
/// Creates a new buffer by copying the buffer pointer of the given array.
///
/// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
/// for bit with no indirection or reference-counting operations; otherwise, reinterpreting
/// data from the resulting buffer has undefined behavior.
/// - Parameter array: An array with elements of type `T`.
init<T>(copyingBufferOf array: [T]) {
self = array.withUnsafeBufferPointer(Data.init)
}
}
创建包含数据的数组,以及包含该数组的Data对象:
static let inputData = Data(copyingBufferOf: [Float32(1.0), Float32(3.0)])
然后,他们将inputData
复制到神经网络中。
我试图修改他们的代码以将图像加载到[1,28,28,1]张量中。 图片看起来像这样:
[[[[Float32(254.0)],
[Float32(255.0)],
[Float32(254.0)],
[Float32(250.0)],
[Float32(252.0)],
[Float32(255.0)],
[Float32(255.0)],
[Float32(255.0)],
[Float32(255.0)],
[Float32(254.0)],
[Float32(214.0)],
[Float32(160.0)],
[Float32(130.0)],
[Float32(124.0)],
[Float32(129.0)],
...
你明白了。
但是,如果我尝试将其与图像数据一起转换为Data / init Data,我就只能得到8个字节:
private func createTestData() -> Data {
return Data(copyingBufferOf:
[[[[Float32(254.0)],
[Float32(255.0)],
[Float32(254.0)],
...
测试中的代码也一样,但是对于它们来说,这没问题(2 * Float32 = 8字节)。 对我来说,这太小了(应该是28 * 28 * 4 = 3136字节)!
答案 0 :(得分:1)
Swift Array
是一个固定大小的结构,具有指向实际元素存储的(不透明的)指针。 withUnsafeBufferPointer()
方法使用指向该元素存储的缓冲区指针来调用给定的闭包。在[Float]
数组的情况下,它是指向浮点值的内存地址的指针。那就是为什么
array.withUnsafeBufferPointer(Data.init)
工作以获取代表浮点数的Data
值。
如果您将嵌套数组(例如类型[[Float]]
)传递给withUnsafeBufferPointer()
方法,则使用指向内部数组Array
结构的指针来调用闭包。因此,元素类型现在不是Float
而是[Float]
–在警告的意义上不是“琐碎的类型”
/// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
/// for bit with no indirection or reference-counting operations; otherwise, reinterpreting
/// data from the resulting buffer has undefined behavior.
您需要做的是将嵌套数组展平为一个简单数组,然后从该简单数组创建一个Data
值。