Slicing a tensor based on a mask tensor in TensorFlow

Date: 2019-06-22 21:05:57

Tags: python tensorflow slice

Here is my reproducible code:

tf_ent = tf.Variable([   [9.96,    8.65,    0.99,    0.1 ],
                         [0.7,     8.33,    0.1  ,   0.1   ],
                         [0.9,     0.1,     6,       7.33],
                         [6.60,    0.1,     3,       5.5 ],
                         [9.49,    0.2,     0.2,     0.2   ],
                         [0.4,     8.45,    0.2,     0.2 ],
                         [0.3,     0.2,     5.82,    8.28]])

tf_ent_var = tf.constant([True, False, False, False, False, True, False])

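I want to keep the rows of tf_ent whose corresponding entry in tf_ent_var is True, and replace every other row with the minimum value of the whole matrix.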

So the expected output would look like this:

                    [[9.96,    8.65,    0.99,    0.1 ],
                     [0.1,     0.1,     0.1,     0.1 ],
                     [0.1,     0.1,     0.1,     0.1 ],
                     [0.1,     0.1,     0.1,     0.1 ],
                     [0.1,     0.1,     0.1,     0.1 ],
                     [0.4,     8.45,    0.2,     0.2 ],
                     [0.1,     0.1,     0.1,     0.1 ]]

Any ideas how I can do this?

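I tried getting the indices from the mask tensor and then using tf.gather, but the indices I get look like [[0], [5]], which makes sense, because it returns the coordinates of the True elements of a vector.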
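The index route can work too. A single-argument tf.where returns a matrix of coordinates with shape (number of True entries, rank of the input), which is why a 1-D mask yields [[0], [5]] rather than a flat list. The sketch below uses NumPy as a stand-in for the TensorFlow tensors (an assumption; it is not TensorFlow code): flatten the coordinates to a row-index vector, start from a matrix filled with the global minimum, and copy the selected rows back in.

```python
import numpy as np

# Data from the question, as NumPy arrays (stand-ins for tf_ent / tf_ent_var).
ent = np.array([[9.96, 8.65, 0.99, 0.1],
                [0.7, 8.33, 0.1, 0.1],
                [0.9, 0.1, 6.0, 7.33],
                [6.60, 0.1, 3.0, 5.5],
                [9.49, 0.2, 0.2, 0.2],
                [0.4, 8.45, 0.2, 0.2],
                [0.3, 0.2, 5.82, 8.28]], dtype=np.float32)
mask = np.array([True, False, False, False, False, True, False])

# NumPy analogue of tf.where(mask): one coordinate row per True entry, shape (2, 1).
coords = np.argwhere(mask)
# Flatten the coordinates into a plain row-index vector [0, 5].
rows = coords[:, 0]

# Start from a matrix filled with the global minimum...
out = np.full_like(ent, ent.min())
# ...then copy back only the rows selected by the mask (gather, then scatter).
out[rows] = ent[rows]

print(coords.tolist())  # [[0], [5]]
print(out)
```

Writing the gathered rows back is the awkward part in TensorFlow, because tensors cannot be assigned into like NumPy arrays; that is one reason the tf.where-based answers are shorter.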

3 answers:

Answer 0 (score: 2)

Edit: for TensorFlow 1.x, use:

val = tf.math.reduce_min(tf_ent)
tf.where(tf_ent_var, tf_ent, tf.zeros_like(tf_ent) + val)

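Unfortunately, the 1.x broadcasting rules are not a subset of the 2.0 rules (which are the same as NumPy's); they are just different. TensorFlow is not the best library when it comes to version compatibility.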
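To see why the 2.0 version needs an extra axis: under NumPy-style broadcasting, shapes are aligned from the trailing dimension, so a (7,) condition is matched against the 4 columns of a (7, 4) tensor and the shapes are incompatible. In TensorFlow 1.x, tf.where instead treats a 1-D condition as a per-row selector. A minimal sketch, using NumPy as a stand-in for TF 2.x (an assumption, based on the statement above that the 2.0 rules match NumPy's):

```python
import numpy as np

ent = np.arange(28, dtype=np.float32).reshape(7, 4)  # any (7, 4) matrix works here
mask = np.array([True, False, False, False, False, True, False])  # shape (7,)

# Shapes are aligned from the right: (7,) vs (7, 4) compares 7 with 4, so this fails.
try:
    np.where(mask, ent, ent.min())
    flat_mask_works = True
except ValueError:
    flat_mask_works = False

# Shape (7, 1) lines up with (7, 4): the size-1 axis is stretched across the columns.
result = np.where(mask[:, np.newaxis], ent, ent.min())

print(flat_mask_works)  # False
print(result.shape)     # (7, 4)
```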


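The basic idea is to use tf.where, but in 2.0 you first need to turn tf_ent_var into a tensor of shape (7, 1), so that TensorFlow knows to broadcast it along the second axis rather than the first. So: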

val = tf.math.reduce_min(tf_ent)
tf.where(tf_ent_var[:, tf.newaxis], tf_ent, val)

Of course, you could also reshape it to (-1, 1), but I think slicing with tf.newaxis is shorter and clearer.
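Both spellings give the same (7, 1) condition. The sketch below, again using NumPy as a stand-in for the TF 2.x tensors (an assumption), checks that the two forms are interchangeable and that the selection reproduces the expected output from the question.

```python
import numpy as np

ent = np.array([[9.96, 8.65, 0.99, 0.1],
                [0.7, 8.33, 0.1, 0.1],
                [0.9, 0.1, 6.0, 7.33],
                [6.60, 0.1, 3.0, 5.5],
                [9.49, 0.2, 0.2, 0.2],
                [0.4, 8.45, 0.2, 0.2],
                [0.3, 0.2, 5.82, 8.28]], dtype=np.float32)
mask = np.array([True, False, False, False, False, True, False])

val = ent.min()  # global minimum of the matrix

cond_newaxis = mask[:, np.newaxis]  # slicing with a new axis -> shape (7, 1)
cond_reshape = mask.reshape(-1, 1)  # explicit reshape        -> shape (7, 1)
assert cond_newaxis.shape == cond_reshape.shape == (7, 1)
assert np.array_equal(cond_newaxis, cond_reshape)

# Keep rows where the mask is True; elsewhere use the broadcast scalar minimum.
out = np.where(cond_newaxis, ent, val)
print(out)
```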


Here is the interactive Python session I used for troubleshooting with 1.13.1:

Python 3.7.3 (v3.7.3:ef4ec6ed12, Mar 25 2019, 16:52:21) 
[Clang 6.0 (clang-600.0.57)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> sess = tf.InteractiveSession()
2019-06-22 15:51:09.210852: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
>>> tf_ent = tf.Variable([   [9.96,    8.65,    0.99,    0.1 ],
...                          [0.7,     8.33,    0.1  ,   0.1   ],
...                          [0.9,     0.1,     6,       7.33],
...                          [6.60,    0.1,     3,       5.5 ],
...                          [9.49,    0.2,     0.2,     0.2   ],
...                          [0.4,     8.45,    0.2,     0.2 ],
...                          [0.3,     0.2,     5.82,    8.28]])
WARNING:tensorflow:From /Users/REDACTED/Documents/test/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
>>> 
>>> tf_ent_var = tf.constant([True, False, False, False, False, True, False])
>>> init = tf.global_variables_initializer()
>>> sess.run(init)
>>> val = tf.math.reduce_min(tf_ent)
>>> tf.where(tf_ent_var, tf_ent, tf.zeros_like(tf_ent) + val)
<tf.Tensor 'Select:0' shape=(7, 4) dtype=float32>
>>> _.eval()
array([[9.96, 8.65, 0.99, 0.1 ],
       [0.1 , 0.1 , 0.1 , 0.1 ],
       [0.1 , 0.1 , 0.1 , 0.1 ],
       [0.1 , 0.1 , 0.1 , 0.1 ],
       [0.1 , 0.1 , 0.1 , 0.1 ],
       [0.4 , 8.45, 0.2 , 0.2 ],
       [0.1 , 0.1 , 0.1 , 0.1 ]], dtype=float32)
>>> tf.__version__
'1.13.1'

Answer 1 (score: 2)

min_mat = tf.broadcast_to(tf.reduce_min(tf_ent), tf_ent.shape)
output = tf.where(tf_ent_var, tf_ent, min_mat)
sess.run(output)
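This answer relies on the 1.x behavior of tf.where, where a 1-D condition selects whole rows; tf.broadcast_to only expands the scalar minimum to the full (7, 4) shape. Under NumPy-style broadcasting (as in TF 2.x, per the answer above) the condition still needs a trailing axis. A sketch of the same idea with NumPy as a stand-in (an assumption; it is not the answerer's code):

```python
import numpy as np

ent = np.array([[9.96, 8.65, 0.99, 0.1],
                [0.7, 8.33, 0.1, 0.1],
                [0.9, 0.1, 6.0, 7.33],
                [6.60, 0.1, 3.0, 5.5],
                [9.49, 0.2, 0.2, 0.2],
                [0.4, 8.45, 0.2, 0.2],
                [0.3, 0.2, 5.82, 8.28]], dtype=np.float32)
mask = np.array([True, False, False, False, False, True, False])

# Expand the scalar minimum to the full (7, 4) shape, like tf.broadcast_to.
# In NumPy this is a read-only view, not a copy.
min_mat = np.broadcast_to(ent.min(), ent.shape)

# NumPy aligns trailing axes, so the 1-D mask must become (7, 1) here.
output = np.where(mask[:, np.newaxis], ent, min_mat)
print(output)
```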

Answer 2 (score: 1)

Here is my implementation using tf.concat() and if-else statements. It is not as elegant as the other answers, but it works:

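import tensorflow as tf
tf.enable_eager_execution()  # TF 1.x only; eager execution is the default in 2.x

def slice_tensor_based_on_mask(tf_ent, tf_ent_var):
    n_rows, n_cols = int(tf_ent.shape[0]), int(tf_ent.shape[1])
    # Dummy first row so tf.concat has something to append to; dropped at the end.
    res = tf.fill([1, n_cols], 0.0)
    # A single row filled with the global minimum of tf_ent.
    min_value_tensor = tf.fill([1, n_cols], tf.reduce_min(tf_ent))

    for i in range(n_rows):
        if tf_ent_var[i].numpy():  # mask is True: keep the original row
            res = tf.concat([res, tf_ent[i:i+1]], 0)
        else:  # mask is False: append the row of minimums
            res = tf.concat([res, min_value_tensor], 0)
    return res[1:]  # drop the dummy row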

tf_ent = tf.Variable([[9.96,    8.65,    0.99,   0.1 ],
                     [0.7,     8.33,    0.1  ,   0.1 ],
                     [0.9,     0.1,     6,       7.33],
                     [6.60,    0.1,     3,       5.5 ],
                     [9.49,    0.2,     0.2,     0.2 ],
                     [0.4,     8.45,    0.2,     0.2 ],
                     [0.3,     0.2,     5.82,    8.28]])

tf_ent_var = tf.constant([True, False, False, False, False, True, False])
print(slice_tensor_based_on_mask(tf_ent, tf_ent_var))
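The loop above calls tf.concat once per row, copying the growing result each time, so the total work grows quadratically with the number of rows. A minimal sketch of the same row-by-row logic, with NumPy as a stand-in and a hypothetical helper name (mask_rows_loop), collects the rows in a Python list and stacks them once; this also removes the need for the dummy first row. It then checks the loop against the vectorized np.where form.

```python
import numpy as np

def mask_rows_loop(ent, mask):
    """Hypothetical helper: keep rows where mask is True, else fill with ent.min()."""
    min_row = np.full(ent.shape[1], ent.min(), dtype=ent.dtype)
    # Pick each row in Python, but defer all copying to a single stack call.
    rows = [ent[i] if mask[i] else min_row for i in range(ent.shape[0])]
    return np.stack(rows)

ent = np.array([[9.96, 8.65, 0.99, 0.1],
                [0.7, 8.33, 0.1, 0.1],
                [0.9, 0.1, 6.0, 7.33],
                [6.60, 0.1, 3.0, 5.5],
                [9.49, 0.2, 0.2, 0.2],
                [0.4, 8.45, 0.2, 0.2],
                [0.3, 0.2, 5.82, 8.28]], dtype=np.float32)
mask = np.array([True, False, False, False, False, True, False])

looped = mask_rows_loop(ent, mask)
vectorized = np.where(mask[:, np.newaxis], ent, ent.min())
print(np.array_equal(looped, vectorized))  # True
```

The vectorized form is still preferable when it is available; the list-and-stack pattern is mainly useful when the per-row decision cannot be expressed as a single mask.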