Error related to tf.sparse_tensor_dense_matmul()

Time: 2019-01-08 07:28:49

Tags: tensorflow

Example: the following snippet raises an error that I have not been able to fix.

import tensorflow as tf
import numpy as np
import scipy.sparse as sp
from scipy.sparse import csc_matrix

def sparse_to_tuple(sparse_mx):
    """Convert sparse matrix to tuple representation."""
    def to_tuple(mx):
        if not sp.isspmatrix_coo(mx):
            mx = mx.tocoo()
        coords = np.vstack((mx.row, mx.col)).transpose()
        values = mx.data
        shape = mx.shape
        return coords, values, shape

    if isinstance(sparse_mx, list):
        for i in range(len(sparse_mx)):
            sparse_mx[i] = to_tuple(sparse_mx[i])
    else:
        sparse_mx = to_tuple(sparse_mx)

    return sparse_mx

A = csc_matrix(np.diag(np.array([1, 2, 3])))
A = sparse_to_tuple(A)

B = tf.ones([3, 2], tf.float32)
C = tf.sparse_placeholder(tf.float32, shape=(3, 3))
D = tf.sparse_tensor_dense_matmul(C, B)

# if this line and the last line are both commented out, the script runs
H = tf.matmul(tf.sparse_tensor_dense_matmul(C, B), tf.transpose(B, [1, 0]))

with tf.Session() as sess:
    out = sess.run(D, {C: A})
    print('----------------')
    print(out)
    print('----------------')
    print(sess.run(H, {C: A}))
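For reference, the values the script is expected to print can be worked out without TensorFlow. The NumPy-only sketch below is an illustration, not part of the original question. It builds the dense equivalent of A, a COO-style (coords, values, shape) tuple analogous to what sparse_to_tuple returns (using np.nonzero instead of scipy's tocoo, an assumption that only matches the scipy ordering for simple cases like this diagonal matrix), and the products that D and H should evaluate to.

```python
import numpy as np

# Dense equivalent of the sparse matrix A in the question: diag(1, 2, 3).
A_dense = np.diag(np.array([1, 2, 3], dtype=np.float32))

# COO-style (coords, values, shape) tuple, analogous to sparse_to_tuple's output.
# np.nonzero returns entries in row-major order; for a diagonal matrix this
# coincides with the order scipy produces.
rows, cols = np.nonzero(A_dense)
coords = np.stack([rows, cols], axis=1)
values = A_dense[rows, cols]
shape = A_dense.shape

B = np.ones((3, 2), dtype=np.float32)

D = A_dense @ B  # what tf.sparse_tensor_dense_matmul(C, B) should produce
H = D @ B.T      # what tf.matmul(D, tf.transpose(B, [1, 0])) should produce

print(coords.tolist())  # [[0, 0], [1, 1], [2, 2]]
print(D.tolist())       # [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
print(H.tolist())       # [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0], [6.0, 6.0, 6.0]]
```

So D should be [[1, 1], [2, 2], [3, 3]] and H should be a 3x3 matrix whose rows are all 2s, 4s and 6s respectively.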

Error:

ValueError                                Traceback (most recent call last)

<ipython-input-83-02386eb1014c> in <module>()
     32 
     33 with tf.Session() as sess:
---> 34     out = sess.run(D, {C: A})
     35     print('----------------')
     36     print(out)
     

/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)

    898     collected into this argument and passed back.
    899 
--> 900     Args:
    901       fetches: A single graph element, a list of graph elements,
    902         or a dictionary whose values are graph elements or lists of graph
     

/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)

   1111                 ' is not compatible with Tensor type ' + str(subfeed_dtype) +
   1112                 '. Try explicitly setting the type of the feed tensor'
-> 1113                 ' to a larger type (e.g. int64).')
   1114 
   1115           is_tensor_handle_feed = isinstance(subfeed_val,
     

ValueError: Tensor Tensor("Const_132:0", shape=(2,), dtype=int64) may not be fed.
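The tensor named in the error has shape (2,) and dtype int64. That matches the dense_shape component of a rank-2 sparse tensor, which TensorFlow stores as a 1-D int64 tensor with one entry per dimension. Because shape=(3, 3) is fully specified in tf.sparse_placeholder, that component is presumably created as a constant (hence the "Const" name) rather than a placeholder, so the rejected feed is likely the shape element of the tuple A. This is an inference from the message, not a confirmed diagnosis. A NumPy-only sketch of that shape component:

```python
import numpy as np

# Third element of the (coords, values, shape) tuple for the 3x3 matrix A.
shape = (3, 3)

# A SparseTensor's dense_shape is a 1-D int64 tensor with one entry per dimension,
# so for a rank-2 matrix it has shape (2,), like the tensor named in the error.
dense_shape = np.asarray(shape, dtype=np.int64)
print(dense_shape.shape, dense_shape.dtype)  # (2,) int64
```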

System information
tensorflow-gpu = 1.9.0
python = Python 3.6.5 :: Anaconda, Inc.
Ubuntu 16.04
