鉴于形状为x
的4D张量(batch_size, batch_size, seq_len, feature_dim)
,我希望能够沿对角线入口检索矩阵,即我需要一种方法来获取所有x[diag_entry, diag_entry, :, :]
切片值range(batch_size)
产生形状为(batch_size, seq_len, feature_dim)
的张量。但是,由于我在Keras工作,range(batch_size)
可能会有所不同,因此我无法显式地遍历batch_size
。 Tensorflow是否具有支持此类操作的功能?