您可以有效地在TensorFlow JS中复制tf.random_crop吗?

时间:2019-04-10 12:27:16

标签: javascript tensorflow tensorflow.js

TensorFlow JS中没有实现随机裁剪,但是可以复制吗?我的想法是将Tensor.slice()与从tf.randomUniform生成的张量一起用作参数,但它仅接受“数字”。因此在我看来,为了使随机裁剪工作正常进行,我必须在每次迭代中使用新生成的随机数(例如来自Math.random())作为切片参数来重构计算图的该部分。还是有另一种方法?

这是我的代码。我的理解是,内部函数只会创建一次随机偏移rx和ry,并且我需要一个Tensorflow操作来在每次迭代中连续获取随机值。

export function jitter (d) {
  const inner = (tImage) => {
    const tShp = tImage.shape;
    const cropShape = [
      tShp[0], tShp[1]-d,
      tShp[2]-d, tShp[3]];
    const rx = Math.floor(Math.random() * d + 0.5);
    const ry = Math.floor(Math.random() * d + 0.5);
    const crop = tImage.slice(
      [0, rx, ry, 0],
      [cropShape[0], cropShape[1], cropShape[2], cropShape[3]]);
  }

  return inner;
}

Link to doc for Tensor.slice()

1 个答案:

答案 0 :(得分:0)

slice将允许切片或裁剪部分输入。另一方面,如果要避免重复使用切片,使用gatherND将允许切片多次。但是应该给出切片的索引。下面,函数g根据随机坐标生成索引,并尝试计算将包含在农作物中的所有z * z元素的索引。

const g = (r, s, z, n) => {
  const arr = []
  for (let i = 0; i < n; i++) {
    const c = Math.floor(Math.random() * r)
    const d = Math.floor(Math.random() * s)
    const p = Array.from({length: z}, (_, k) => k + c)
    const q = Array.from({length: z}, (_, k) => k + d)
    arr.push(p.map( e => q.map(f => ([e, f]))).flat())
  }
  return arr
}

const n = 3
const crop = 4
const hsize = 2 // maximum of the height where to start cropping
const wsize = 2 // maximum of the width where to start cropping
// hsize = length_of_height_dimension - crop_size_over_height
// wsize = length_of_width_dimension - crop_size_over_width
const indices = tf.tensor( g(hsize, wsize, crop, n)).toInt()
const input = tf.tensor(Array.from({length: 64 * 3}, (_, k) => k +1), [8, 8, 3]);
tf.gatherND(input, indices).reshape([n, crop, crop, 3]).print()
<html>
  <head>
    <!-- Load TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"> </script>
  </head>

  <body>
  </body>
</html>