我有要修改的二维张量
这是我的原始张量
target = tf.Tensor(
[[-0.11913221 0.10740653 -0.11313857 0.0947545 ]
[-0.08016778 0.08989932 -0.04723025 0.09613379]
[-0.08598217 0.09504095 -0.05100587 0.10170317]
[-0.11362994 0.10152908 -0.11400419 0.09055013]
[-0.08566315 0.0948316 -0.0573874 0.09701383]
[-0.0454545 0.06643758 0.03581601 0.07577533]
[-0.08061063 0.08883315 -0.04477514 0.09522746]
[-0.10650193 0.10237528 -0.10827965 0.09186891]
[-0.10352179 0.10333493 -0.091276 0.09608406]
[-0.10635372 0.10243753 -0.10126923 0.09358902]
[-0.06562068 0.05653733 0.02287578 0.07902215]
[-0.05279381 0.06863255 0.00381979 0.07298659]
[-0.05986387 0.0574185 0.00537424 0.07158188]
[-0.04212753 0.06873286 0.03713646 0.07518272]
[-0.05582039 0.06819632 -0.00719201 0.07523153]
[-0.05361114 0.06135845 0.00325142 0.06868324]
[-0.0516073 0.06817734 -0.00397537 0.07227116]
[-0.07735051 0.07971332 -0.0285136 0.08493245]
[-0.0682629 0.07416652 -0.01796982 0.07863645]
[-0.06170752 0.05759309 0.0143391 0.0743294 ]
[-0.06184721 0.05632517 0.026482 0.07929536]
[-0.13982043 0.10956189 -0.13085063 0.09942309]
[-0.13597223 0.10945709 -0.12371519 0.09820569]
[-0.09833615 0.10151923 -0.08566514 0.09523689]
[-0.05372259 0.06107685 0.03173799 0.07874461]
[-0.0419498 0.07148563 0.04023236 0.07559083]
[-0.0629972 0.06186401 0.00474169 0.07343359]
[-0.08224992 0.08429107 -0.03304049 0.08904412]
[-0.06368244 0.06590354 0.00620057 0.07675664]
[-0.0835162 0.09322319 -0.07526941 0.09122212]
[-0.08421545 0.09513094 -0.06602468 0.09551296]
[-0.05821873 0.0586473 0.0036582 0.07073924]
[-0.14736125 0.1123023 -0.13142368 0.10306048]
[-0.0446386 0.06116525 0.03132816 0.06727339]
[-0.07087782 0.07238103 -0.01442777 0.0776275 ]
[-0.11170405 0.1033025 -0.11235781 0.09162579]
[-0.0577084 0.05986657 0.03374245 0.08077951]
[-0.04407909 0.06878516 0.03870909 0.07593159]
[-0.04852748 0.06432013 0.03479851 0.07601137]
[-0.06502739 0.05591577 0.01520929 0.07542811]
[-0.08267248 0.08766897 -0.04166481 0.09204446]
[-0.0922101 0.09705735 -0.08632483 0.09220427]
[-0.08511391 0.09397344 -0.06466756 0.09294012]
[-0.14137845 0.1104335 -0.12682468 0.09961776]
[-0.05779795 0.05793206 0.02900699 0.07875173]
[-0.05267341 0.06256577 0.03344044 0.07881397]
[-0.06395801 0.0699858 -0.01378324 0.07634839]
[-0.05504322 0.05956101 0.03019704 0.07883701]
[-0.07256803 0.06274231 0.00485652 0.07637607]
[-0.13035618 0.10879815 -0.12050705 0.09703324]
[-0.07034346 0.06409847 0.00414656 0.07681213]
[-0.15281731 0.1131851 -0.13400695 0.10459089]
[-0.05138929 0.06004792 0.03254107 0.07282083]
[-0.05973424 0.06679939 0.00474964 0.07575827]
[-0.12728928 0.11161849 -0.11404566 0.10005538]
[-0.06549592 0.07298904 -0.0131294 0.07808644]
[-0.08454984 0.08645502 -0.03881929 0.08918599]
[-0.06490561 0.07068416 -0.00567248 0.07783496]
[-0.112358 0.10129648 -0.11178349 0.08972972]
[-0.11124412 0.10365634 -0.10590465 0.09341211]
[-0.08787578 0.095176 -0.08083484 0.0917344 ]
[-0.06601885 0.05607153 0.02268048 0.07903323]
[-0.07518905 0.0782318 -0.02215605 0.08435896]
[-0.1372289 0.11121203 -0.12439059 0.10107265]], shape=(64, 4), dtype=float32)
使用对应行和列的新值更新原始张量
values = [-1.7363266 -1.60792916 3.65811157 -1.94871653 1.20506907 1.32829045
-2.73125213 0.29886797 -0.46393473 -1.23520107 -2.60709122 -2.23281132
-0.69956029 -2.2418921 3.22370838 -2.00896171 -2.01805908 2.82528791
-2.20533027 -2.69834789 0.66163399 3.3812018 0.40844984 -0.46674331
-0.58725082 -2.34600192 -1.67365202 -2.1870339 1.34777899 0.26781824
-1.58523572 -2.61302884 -1.57072259 -2.32469599 -2.27428221 -1.9625542
-0.36708335 -2.24177788 -2.52221747 -0.27126677 -2.06894214 -2.17454067
-2.12836831 -0.41872739 -1.88022913 0.1116449 -2.44375356 -2.59824813
-1.92523038 -1.89574978 -1.51718945 -1.02749129 1.22673353 -2.26721576
-0.95453949 1.27187676 -1.94634709 -1.77064807 -1.56859298 -2.22134502
-1.82816027 2.0710269 0.9413091 -2.10814866]
在行和列
rows = [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63]
columns = [1 1 2 0 3 1 2 2 2 3 1 2 2 1 3 1 2 3 2 3 3 0 0 0 1 0 0 0 2 2 2 0 3 3 2 1 2
1 2 0 1 0 3 3 2 3 3 1 1 1 3 2 1 1 2 1 2 1 1 2 1 1 3 0]
但是我在下面出现TypeError
'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment
当我使用以下语法更新张量时
target[[rows], [columns]] = values
我的理解是张量是不可变的对象,但是我如何更新它呢?
答案 0 :(得分:0)
我不得不将张量转换为numpy数组。我正在使用tensorflow 2.0,所以我使用了numpy()并进行编辑。
target_np = target.numpy()
target_np[[rows], [columns]] = values
答案 1 :(得分:-1)
您可以使用tf.tensor_scatter_nd_update
。为了使这种更新生效,尽管target
必须是tf.Variable
,例如
target = tf.Variable(...)
然后,使用您定义的rows
,columns
和values
,稀疏更新将如下所示:
indices = tf.stack([rows, columns], axis=1)
target_new = tf.tensor_scatter_nd_update(target, indices, values)