KeyError: 0 when training with StellarGraph generator.flow

Posted: 2020-07-09 13:49:30

Tags: python networkx

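I am trying to do node classification with the GCN class of the StellarGraph library. I load node features from node_feat.csv, node labels from node_targets.csv and edge features from edge_data.csv, following the node-classification procedure in https://stellargraph.readthedocs.io/en/stable/demos/node-classification/gcn-node-classification.html.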

!wget -O node_feat.csv https://github.com/pranavn91/blockchain/blob/master/tx2009partvertices_new.csv
!wget -O node_targets.csv https://github.com/pranavn91/blockchain/blob/master/tx2009partvertices.csv
!wget -O edge_data.csv https://github.com/pranavn91/blockchain/blob/master/tx2009partedges.csv
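A GitHub `/blob/` URL points to the HTML page that displays a file, so `wget` on such a link saves an HTML page rather than the CSV itself; the raw file is served from `raw.githubusercontent.com`. A minimal sketch, assuming the usual `https://github.com/<user>/<repo>/blob/<branch>/<path>` layout; `blob_to_raw` is a hypothetical helper, not part of any library:

```python
from urllib.parse import urlparse


def blob_to_raw(url: str) -> str:
    """Hypothetical helper: map github.com/<user>/<repo>/blob/<branch>/<path>
    to raw.githubusercontent.com/<user>/<repo>/<branch>/<path>."""
    parsed = urlparse(url)
    if parsed.netloc != "github.com":
        raise ValueError(f"not a github.com URL: {url}")
    parts = parsed.path.strip("/").split("/")
    # Expected segments: user, repo, "blob", branch, then the file path.
    if len(parts) < 5 or parts[2] != "blob":
        raise ValueError(f"not a /blob/ URL: {url}")
    user, repo, _, branch, *path = parts
    return "https://raw.githubusercontent.com/" + "/".join([user, repo, branch, *path])


print(blob_to_raw("https://github.com/pranavn91/blockchain/blob/master/tx2009partedges.csv"))
# https://raw.githubusercontent.com/pranavn91/blockchain/master/tx2009partedges.csv
```

The converted URL can then be passed to `wget -O`. Branch names that themselves contain `/` are not handled by this sketch.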
    

Then I import the libraries:

from stellargraph import StellarDiGraph as sg
import pandas as pd
import os
#import stellargraph as sg
from stellargraph.mapper import FullBatchNodeGenerator
from stellargraph.layer import GCN

from tensorflow.keras import layers, optimizers, losses, metrics, Model
from sklearn import preprocessing, model_selection
from IPython.display import display, HTML
import matplotlib.pyplot as plt
%matplotlib inline

Then I create a StellarDiGraph (directed graph):

trans2009 = sg(
    {"users": node_feat}, {"transfer_btc": edge_data}
)
print(trans2009.info())
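StellarGraph takes node IDs from the index of each node DataFrame and, by default, reads edge endpoints from the `source` and `target` columns of the edge DataFrame, so every edge endpoint should appear in the node index. A pandas-only sketch for checking this before building the graph; the CSV contents and the column names `tx_id`, `from` and `to` are hypothetical, since the real headers are not shown:

```python
import io

import pandas as pd

# Hypothetical CSV contents standing in for node_feat.csv and edge_data.csv.
NODES_CSV = "tx_id,f1,f2\n101,0.5,1.0\n102,0.1,0.3\n103,0.9,0.2\n"
EDGES_CSV = "from,to,amount\n101,102,2.5\n102,103,1.0\n"

# index_col turns the ID column into the index, which StellarGraph uses as node IDs.
node_feat = pd.read_csv(io.StringIO(NODES_CSV), index_col="tx_id")
# Rename the endpoint columns to the default "source"/"target" names.
edge_data = pd.read_csv(io.StringIO(EDGES_CSV)).rename(
    columns={"from": "source", "to": "target"}
)


def dangling_endpoints(nodes: pd.DataFrame, edges: pd.DataFrame) -> set:
    """Return edge endpoints that are missing from the node index."""
    endpoints = pd.concat([edges["source"], edges["target"]])
    return set(endpoints[~endpoints.isin(nodes.index)])


print(dangling_endpoints(node_feat, edge_data))  # set()
```

If the set is empty, `sg({"users": node_feat}, {"transfer_btc": edge_data})` receives node and edge IDs that agree.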

Next I split the dataset:

train_subjects, test_subjects = model_selection.train_test_split(
    node_targets, train_size=40, test_size=None
)
val_subjects, test_subjects = model_selection.train_test_split(
    test_subjects, train_size=6, test_size=None
)
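In the linked demo, the validation and test sets are both carved out of the held-out `test_subjects`, so the training, validation and test nodes stay disjoint. A sketch with toy labels indexed by hypothetical node IDs, checking that property:

```python
import pandas as pd
from sklearn import model_selection

# Toy labels for 60 nodes with hypothetical IDs 1000..1059.
node_targets = pd.DataFrame({"label": [0, 1] * 30}, index=range(1000, 1060))

# First split: 40 training nodes, the remaining 20 held out.
train_subjects, test_subjects = model_selection.train_test_split(
    node_targets, train_size=40, test_size=None, random_state=0
)
# Second split: take the validation nodes from the held-out part only.
val_subjects, test_subjects = model_selection.train_test_split(
    test_subjects, train_size=6, test_size=None, random_state=0
)

print(len(train_subjects), len(val_subjects), len(test_subjects))  # 40 6 14
```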

target_encoding = preprocessing.LabelBinarizer()

train_targets = target_encoding.fit_transform(train_subjects["label"].astype(str))
val_targets = target_encoding.transform(val_subjects["label"].astype(str))
test_targets = target_encoding.transform(test_subjects["label"].astype(str))

generator = FullBatchNodeGenerator(trans2009, method="gcn")
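`LabelBinarizer` learns its class list from the training labels only. With exactly two classes it returns a single 0/1 column rather than one column per class, which affects how the targets are used later. A small sketch with toy labels:

```python
import numpy as np
from sklearn import preprocessing

target_encoding = preprocessing.LabelBinarizer()

# Two classes: a single column of 0/1 values.
binary = target_encoding.fit_transform(np.array(["0", "1", "1", "0"]))
print(binary.shape)  # (4, 1)

# Three or more classes: one-hot encoding, one column per class in sorted order.
multi = target_encoding.fit_transform(np.array(["a", "b", "c", "a"]))
print(multi.shape)  # (4, 3)
print(target_encoding.classes_)  # ['a' 'b' 'c']
```

If the model follows the linked demo, its output layer is sized by `train_targets.shape[1]` with a `softmax` activation. For a binary label that yields one softmax unit, which always outputs 1, so a two-class label would need a `sigmoid` output and a binary cross-entropy loss instead.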

But at the following step:

train_gen = generator.flow(train_subjects.index, train_targets, shuffle=True)

I get KeyError: 0:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-31-74331d853607> in <module>
----> 1 train_gen = generator.flow(train_subjects.index, train_targets, shuffle=True)

/opt/conda/lib/python3.7/site-packages/stellargraph/mapper/sampled_node_generators.py in flow(self, node_ids, targets, shuffle, seed)
    139             expected_node_type = None
    140 
--> 141         node_ilocs = self.graph.node_ids_to_ilocs(node_ids)
    142         node_types = self.graph.node_type(node_ilocs, use_ilocs=True)
    143         invalid = node_ilocs[node_types != expected_node_type]

/opt/conda/lib/python3.7/site-packages/stellargraph/core/graph.py in node_ids_to_ilocs(self, nodes)
   1211             Numpy array containing the indices for the requested nodes.
   1212         """
-> 1213         return self._nodes.ids.to_iloc(nodes, strict=True)
   1214 
   1215     def node_ilocs_to_ids(self, node_ilocs):

/opt/conda/lib/python3.7/site-packages/stellargraph/core/element_data.py in to_iloc(self, ids, smaller_type, strict)
     95         internal_ids = self._index.get_indexer(ids)
     96         if strict:
---> 97             self.require_valid(ids, internal_ids)
     98 
     99         # reduce the storage required (especially useful if this is going to be stored rather than

/opt/conda/lib/python3.7/site-packages/stellargraph/core/element_data.py in require_valid(self, query_ids, ilocs)
     75 
     76             if len(missing_values) == 1:
---> 77                 raise KeyError(missing_values[0])
     78 
     79             raise KeyError(missing_values)

KeyError: 0
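The last frames show the lookup. `to_iloc` calls `self._index.get_indexer(ids)`, and with `strict=True` it calls `require_valid`, which raises `KeyError` for the missing IDs. pandas' `Index.get_indexer` returns -1 for labels that are not in the index, so `KeyError: 0` means that 0 is not a node ID in the graph. One plausible cause, assuming `node_targets` was read with `pd.read_csv` without `index_col` (the loading code is not shown): its index is the default 0, 1, 2, ..., while the graph's node IDs come from the index of `node_feat`. A pandas-only sketch reproducing this; `find_missing` is a hypothetical stand-in for StellarGraph's internal check, and the IDs and the `tx_id` column are made up:

```python
import pandas as pd

# Graph node IDs come from the index of the node-feature DataFrame (hypothetical IDs).
graph_index = pd.Index([101, 102, 103, 104])

# Labels loaded without index_col: the IDs sit in a column and the index is 0..n-1.
node_targets = pd.DataFrame({"tx_id": [101, 102, 103, 104], "label": [0, 1, 0, 1]})


def find_missing(index: pd.Index, ids) -> list:
    """Hypothetical stand-in for the strict lookup: IDs whose position is -1."""
    ilocs = index.get_indexer(ids)
    return [i for i, pos in zip(ids, ilocs) if pos == -1]


print(find_missing(graph_index, node_targets.index))  # [0, 1, 2, 3]

# Index the labels by the node-ID column so .index holds real node IDs.
node_targets = node_targets.set_index("tx_id")
print(find_missing(graph_index, node_targets.index))  # []
```

Under that assumption, either re-indexing the labels this way before splitting or passing the ID column (for example `train_subjects["tx_id"]`) to `generator.flow` instead of `train_subjects.index` gives it IDs that exist in the graph.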

How can I fix this?
