这是我的可复制代码:
tf_ent = tf.Variable([ [9.96, 8.65, 0.99, 0.1 ],
[0.7, 8.33, 0.1 , 0.1 ],
[0.9, 0.1, 6, 7.33],
[6.60, 0.1, 3, 5.5 ],
[9.49, 0.2, 0.2, 0.2 ],
[0.4, 8.45, 0.2, 0.2 ],
[0.3, 0.2, 5.82, 8.28]])
tf_ent_var = tf.constant([True, False, False, False, False, True, False])
我想保留tf_ent
中对应索引为True的tf_ent_var
中的行,并在整个矩阵中将其余行减为最少。
所以预期的输出将是这样的:
[[9.96, 8.65, 0.99, 0.1 ],
[0.1, 0.1, 0.1 , 0.1 ],
[0.1, 0.1, 0.1, 0.1 ],
[0.1, 0.1, 0.1, 0.1 ],
[0.1, 0.1, 0.1, 0.1 ],
[0.4, 8.45, 0.2, 0.2 ],
[0.1, 0.1, 0.1, 0.1 ]]
任何想法我该怎么做?
我试图从掩盖的张量中获取索引,然后使用tf.gather来完成这一操作,但是我获得的索引就像[[0], [6]]
一样,这很有意义,因为它给出了一个矢量的索引。
答案 0 :(得分:2)
编辑:对于tensorflow 1.x,使用:
val = tf.math.reduce_min(tf_ent)
tf.where(tf_ent_var, tf_ent, tf.zeros_like(tf_ent) + val)
不幸的是,广播规则不是2.0规则的子集(与numpy相同),但是“只是不同”。关于版本兼容性,Tensorflow并不是最好的库。
基本思想是使用tf.where
,但是您首先需要将tf_ent_var
更改为形状为(7, 1)
的张量,以便tensorflow知道在第二个轴上广播它在第一个轴上。所以:
val = tf.math.reduce_min(tf_ent)
tf.where(tf_ent_var[:, tf.newaxis], tf_ent, val)
当然,您也可以将其重塑为(-1, 1)
,但我认为与tf.newaxis
的切片更短,更清晰。
这是我与1.13.1的Python交互式会话,用于进行故障排除。
Python 3.7.3 (v3.7.3:ef4ec6ed12, Mar 25 2019, 16:52:21)
[Clang 6.0 (clang-600.0.57)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> sess = tf.InteractiveSession()
2019-06-22 15:51:09.210852: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
>>> tf_ent = tf.Variable([ [9.96, 8.65, 0.99, 0.1 ],
... [0.7, 8.33, 0.1 , 0.1 ],
... [0.9, 0.1, 6, 7.33],
... [6.60, 0.1, 3, 5.5 ],
... [9.49, 0.2, 0.2, 0.2 ],
... [0.4, 8.45, 0.2, 0.2 ],
... [0.3, 0.2, 5.82, 8.28]])
WARNING:tensorflow:From /Users/REDACTED/Documents/test/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
>>>
>>> tf_ent_var = tf.constant([True, False, False, False, False, True, False])
>>> init = tf.global_variables_initializer()
>>> sess.run(init)
>>> val = tf.math.reduce_min(tf_ent)
>>> tf.where(tf_ent_var, tf_ent, tf.zeros_like(tf_ent) + val)
<tf.Tensor 'Select:0' shape=(7, 4) dtype=float32>
>>> _.eval()
array([[9.96, 8.65, 0.99, 0.1 ],
[0.1 , 0.1 , 0.1 , 0.1 ],
[0.1 , 0.1 , 0.1 , 0.1 ],
[0.1 , 0.1 , 0.1 , 0.1 ],
[0.1 , 0.1 , 0.1 , 0.1 ],
[0.4 , 8.45, 0.2 , 0.2 ],
[0.1 , 0.1 , 0.1 , 0.1 ]], dtype=float32)
>>> tf.__version__
'1.13.1'
答案 1 :(得分:2)
min_mat = tf.broadcast_to(tf.reduce_min(tf_ent), tf_ent.shape)
output = tf.where(tf_ent_var, tf_ent, min_mat)
sess.run(output)
答案 2 :(得分:1)
这是我使用import pandas as pd
from bs4 import BeautifulSoup
from requests import get
link = 'https://www.sec.gov/Archives/edgar/data/789019/000156459018019062/msft-20180630.xml'
r = get(link)
str = r.text
soup = BeautifulSoup(str, 'lxml')
tags = soup.find_all()
df = pd.DataFrame(columns=['field','period','value'])
for tag in tags:
if ('us-gaap:' in tag.name # only want gaap-related tags
and tag.text.isdigit()): # only want values, no commentary
#a = re.match("^C_"+ re.escape(cik) + "_[0-9]", tag['contextref'])
name = tag.name.split('gaap:')[1]
cref = tag['contextref'][-8:-4]
value = tag.text
df = df.append({'field': name, 'period': cref, 'value': value}, ignore_index=True)
print(df)
和tf.concat()
语句的实现。它不如其他人的回答优雅,但可以正常工作:
if-else
输出:
import tensorflow as tf
tf.enable_eager_execution()
def slice_tensor_based_on_mask(tf_ent, tf_ent_var):
res = tf.fill([1, 4], 0.0)
min_value_tensor = tf.fill([1,int(tf_ent.shape[1])], tf.reduce_min(tf_ent))
for i in range(int(tf_ent.shape[0])):
if tf_ent_var[i:i+1].numpy()[0]: # true value in tf_ent_var
res = tf.concat([res, tf_ent[i:i+1]], 0)
else:
res = tf.concat([res, min_value_tensor], 0)
return res[1:]
tf_ent = tf.Variable([[9.96, 8.65, 0.99, 0.1 ],
[0.7, 8.33, 0.1 , 0.1 ],
[0.9, 0.1, 6, 7.33],
[6.60, 0.1, 3, 5.5 ],
[9.49, 0.2, 0.2, 0.2 ],
[0.4, 8.45, 0.2, 0.2 ],
[0.3, 0.2, 5.82, 8.28]])
tf_ent_var = tf.constant([True, False, False, False, False, True, False])
print(slice_tensor_based_on_mask(tf_ent, tf_ent_var))