如何修改张量值?

时间:2019-10-20 20:27:38

标签: python tensorflow

我有要修改的二维张量

这是我的原始张量

target = tf.Tensor(
[[-0.11913221  0.10740653 -0.11313857  0.0947545 ]
 [-0.08016778  0.08989932 -0.04723025  0.09613379]
 [-0.08598217  0.09504095 -0.05100587  0.10170317]
 [-0.11362994  0.10152908 -0.11400419  0.09055013]
 [-0.08566315  0.0948316  -0.0573874   0.09701383]
 [-0.0454545   0.06643758  0.03581601  0.07577533]
 [-0.08061063  0.08883315 -0.04477514  0.09522746]
 [-0.10650193  0.10237528 -0.10827965  0.09186891]
 [-0.10352179  0.10333493 -0.091276    0.09608406]
 [-0.10635372  0.10243753 -0.10126923  0.09358902]
 [-0.06562068  0.05653733  0.02287578  0.07902215]
 [-0.05279381  0.06863255  0.00381979  0.07298659]
 [-0.05986387  0.0574185   0.00537424  0.07158188]
 [-0.04212753  0.06873286  0.03713646  0.07518272]
 [-0.05582039  0.06819632 -0.00719201  0.07523153]
 [-0.05361114  0.06135845  0.00325142  0.06868324]
 [-0.0516073   0.06817734 -0.00397537  0.07227116]
 [-0.07735051  0.07971332 -0.0285136   0.08493245]
 [-0.0682629   0.07416652 -0.01796982  0.07863645]
 [-0.06170752  0.05759309  0.0143391   0.0743294 ]
 [-0.06184721  0.05632517  0.026482    0.07929536]
 [-0.13982043  0.10956189 -0.13085063  0.09942309]
 [-0.13597223  0.10945709 -0.12371519  0.09820569]
 [-0.09833615  0.10151923 -0.08566514  0.09523689]
 [-0.05372259  0.06107685  0.03173799  0.07874461]
 [-0.0419498   0.07148563  0.04023236  0.07559083]
 [-0.0629972   0.06186401  0.00474169  0.07343359]
 [-0.08224992  0.08429107 -0.03304049  0.08904412]
 [-0.06368244  0.06590354  0.00620057  0.07675664]
 [-0.0835162   0.09322319 -0.07526941  0.09122212]
 [-0.08421545  0.09513094 -0.06602468  0.09551296]
 [-0.05821873  0.0586473   0.0036582   0.07073924]
 [-0.14736125  0.1123023  -0.13142368  0.10306048]
 [-0.0446386   0.06116525  0.03132816  0.06727339]
 [-0.07087782  0.07238103 -0.01442777  0.0776275 ]
 [-0.11170405  0.1033025  -0.11235781  0.09162579]
 [-0.0577084   0.05986657  0.03374245  0.08077951]
 [-0.04407909  0.06878516  0.03870909  0.07593159]
 [-0.04852748  0.06432013  0.03479851  0.07601137]
 [-0.06502739  0.05591577  0.01520929  0.07542811]
 [-0.08267248  0.08766897 -0.04166481  0.09204446]
 [-0.0922101   0.09705735 -0.08632483  0.09220427]
 [-0.08511391  0.09397344 -0.06466756  0.09294012]
 [-0.14137845  0.1104335  -0.12682468  0.09961776]
 [-0.05779795  0.05793206  0.02900699  0.07875173]
 [-0.05267341  0.06256577  0.03344044  0.07881397]
 [-0.06395801  0.0699858  -0.01378324  0.07634839]
 [-0.05504322  0.05956101  0.03019704  0.07883701]
 [-0.07256803  0.06274231  0.00485652  0.07637607]
 [-0.13035618  0.10879815 -0.12050705  0.09703324]
 [-0.07034346  0.06409847  0.00414656  0.07681213]
 [-0.15281731  0.1131851  -0.13400695  0.10459089]
 [-0.05138929  0.06004792  0.03254107  0.07282083]
 [-0.05973424  0.06679939  0.00474964  0.07575827]
 [-0.12728928  0.11161849 -0.11404566  0.10005538]
 [-0.06549592  0.07298904 -0.0131294   0.07808644]
 [-0.08454984  0.08645502 -0.03881929  0.08918599]
 [-0.06490561  0.07068416 -0.00567248  0.07783496]
 [-0.112358    0.10129648 -0.11178349  0.08972972]
 [-0.11124412  0.10365634 -0.10590465  0.09341211]
 [-0.08787578  0.095176   -0.08083484  0.0917344 ]
 [-0.06601885  0.05607153  0.02268048  0.07903323]
 [-0.07518905  0.0782318  -0.02215605  0.08435896]
 [-0.1372289   0.11121203 -0.12439059  0.10107265]], shape=(64, 4), dtype=float32)

使用对应行和列的新值更新原始张量

values = [-1.7363266  -1.60792916  3.65811157 -1.94871653  1.20506907  1.32829045
 -2.73125213  0.29886797 -0.46393473 -1.23520107 -2.60709122 -2.23281132
 -0.69956029 -2.2418921   3.22370838 -2.00896171 -2.01805908  2.82528791
 -2.20533027 -2.69834789  0.66163399  3.3812018   0.40844984 -0.46674331
 -0.58725082 -2.34600192 -1.67365202 -2.1870339   1.34777899  0.26781824
 -1.58523572 -2.61302884 -1.57072259 -2.32469599 -2.27428221 -1.9625542
 -0.36708335 -2.24177788 -2.52221747 -0.27126677 -2.06894214 -2.17454067
 -2.12836831 -0.41872739 -1.88022913  0.1116449  -2.44375356 -2.59824813
 -1.92523038 -1.89574978 -1.51718945 -1.02749129  1.22673353 -2.26721576
 -0.95453949  1.27187676 -1.94634709 -1.77064807 -1.56859298 -2.22134502
 -1.82816027  2.0710269   0.9413091  -2.10814866]

在行和列

rows = [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63]
columns = [1 1 2 0 3 1 2 2 2 3 1 2 2 1 3 1 2 3 2 3 3 0 0 0 1 0 0 0 2 2 2 0 3 3 2 1 2
 1 2 0 1 0 3 3 2 3 3 1 1 1 3 2 1 1 2 1 2 1 1 2 1 1 3 0]

但是我在下面出现TypeError

'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment

当我使用以下语法更新张量时

target[[rows], [columns]] = values

我的理解是张量是不可变的对象,但是我如何更新它呢?

2 个答案:

答案 0 :(得分:0)

我不得不将张量转换为numpy数组。我正在使用tensorflow 2.0,所以我使用了numpy()并进行编辑。

target_np = target.numpy()
target_np[[rows], [columns]] = values

答案 1 :(得分:-1)

您可以使用tf.tensor_scatter_nd_update。为了使这种更新生效,尽管target必须是tf.Variable,例如

target = tf.Variable(...)

然后,使用您定义的rowscolumnsvalues,稀疏更新将如下所示:

indices = tf.stack([rows, columns], axis=1)
target_new = tf.tensor_scatter_nd_update(target, indices, values)