ND4J阵列&他们的形状:将数据放入列表中

时间:2017-12-15 20:43:08

标签: java scala deeplearning4j nd4j

请考虑以下代码,该代码使用ND4J library创建更简单版本的the "moons" test data set

val n = 100
val n1: Int = n/2
val n2: Int = n-n1
val outerX = Nd4j.getExecutioner.execAndReturn(new Cos(Nd4j.linspace(0, Math.PI, n1)))
val outerY = Nd4j.getExecutioner.execAndReturn(new Sin(Nd4j.linspace(0, Math.PI, n1)))
val innerX = Nd4j.getExecutioner.execAndReturn(new Cos(Nd4j.linspace(0, Math.PI, n2))).mul(-1).add(1)
val innerY = Nd4j.getExecutioner.execAndReturn(new Sin(Nd4j.linspace(0, Math.PI, n2))).mul(-1).add(1)
val X: INDArray = Nd4j.vstack(
  Nd4j.concat(1, outerX, innerX), // 1 x n
  Nd4j.concat(1, outerY, innerY)  // 1 x n
) // 2 x n
val y: INDArray = Nd4j.hstack(
  Nd4j.zeros(n1), // 1 x n1
  Nd4j.ones(n2)   // 1 x n2
) // 1 x n
println(s"# y shape: ${y.shape().toList}")                        // 1x100
println(s"# y data length: ${y.data().length()}")                 // 100
println(s"# X shape: ${X.shape().toList}")                        // 2x100
println(s"# X row 0 shape: ${X.getRow(0).shape().toList}")        // 1x100
println(s"# X row 1 shape: ${X.getRow(1).shape().toList}")        // 1x100
println(s"# X row 0 data length: ${X.getRow(0).data().length()}") // 200    <- !
println(s"# X row 1 data length: ${X.getRow(1).data().length()}") // 100

在第二行到最后一行,X.getRow(0).data().length()令人惊讶地是200而不是100.在检查时,这是因为data()返回的结构包含整个矩阵,即两行连接。

如何将X矩阵的实际第一行变为Java(或Scala)List?我可以只采用200元素“第一行”中的前100项,但这看起来并不优雅。

2 个答案:

答案 0 :(得分:2)

.data()给你一排直线。 请参阅:http://nd4j.org/tensor

数组的形状只是底层数据缓冲区的视图。 我通常不建议你在没有充分理由的情况下做你想做的事情。所有数据都存储在堆中。那份副本很贵。

在堆上做任何数学都是不好的。这里唯一的用例是集成。我建议尽可能直接在阵列上操作。从序列化到索引的所有内容都是为您处理的。

如果真的需要它进行某种集成,请使用番石榴,你可以在一行中完成: Doubles.asList(arr.data()DUP()asDouble());

其中arr是你操作的ndarray。

答案 1 :(得分:0)

是的,事实证明.data()与ND4J并不是真正重要的东西。对于我想做的事情,这有点可耻:编写不真正依赖ND4J的单元测试及其内部处理数据的方式。

作为此问题的另一个示例,请考虑以下代码:

import org.nd4j.linalg.factory.Nd4j

object foo extends App {

  val x = Nd4j.create(Array[Double](1,2, 3,4, 5,6), Array(3,2))
  // 1,2
  // 3,4
  // 5,6
  println(x)
  val xArr = x.data().asDouble().toList
  // 1,2,  3,4,  5,6 - row-wise
  println(xArr)

  val w = Nd4j.create(Array[Double](10,20,30, 40,50,60), Array(2,3))
  // 10,20,30
  // 40,50,60
  println(w)
  val wArr = w.data().asDouble().toList
  // 10,20,30,  40,50,60 - row-wise
  println(wArr)

  val wx = w.mmul(x)
  /*
   *  10,20,30   1,2     10*1+20*3+30*5  10*2+20*4+30*6      220  280
   *  40,50,60   3,4  =  40*1+50*3+60*5  40*2+50*4+60*6  =   490  640
   *             5,6
   */
  println(wx)
  val wxArr = wx.data().asDouble().toList
  // 220, 490,  280, 640 - column-wise
  println(wxArr)
  val wxTArr = wx.transpose().data().asDouble().toList
  // 220, 490,  280, 640 - still column-wise
  println(wxTArr)
  val wxTIArr = wx.transposei().data().asDouble().toList
  // 220, 490,  280, 640 - still column-wise
  println(wxTIArr)

}

如您所见,ND4J基本上在内部完成了它想要的工作,并且当您使用.data()时,它将简单地为您提供其内部表示;此表示不会因任何转置或您要求执行的其他任何操作而改变,因为这些实际上并不会移动基础数据。

这一切都很好,但是我要做的基本上是:列出普通双打的Scala列表;把它交给我的自定义库;要求图书馆做事;获取其输出并将其转换为另一个Scala双打列表;验证这些双打是否符合我的预期。相反,我要做的是将期望的内容放入ND4J阵列中,以便可以将其与实际输出正确地进行比较,因此我的测试现在依赖于ND4J,这是我库的内部技术选择。

无论如何,这是一个相对较小的投诉,应该避免使用.data(),如果您使用的是ND4J,请在整个过程中使用它(即使您认为不太优雅)。