在Tensorflow上使用DNNClassifier预测多个值

时间:2018-07-04 18:42:24

标签: python tensorflow multilabel-classification

我是Tensoflow的新手,我在Tensorflow上使用DNNClassifier训练了一个模型。 目前,我能够使用函数Estimator.predict()来预测标签。 但是,该函数只返回1个标签。

我想使用函数predict()中的相同输入来预测至少5个标签,而不是1个。 如何在张量流中实现?

See picture of my dataset

估算者将受让人作为标签返回。

以下是预测的输出:

{'class_ids': array([190]),
 'classes': array(['mak77'], dtype=object),
 'logits': array([-0.2421599 , -0.12319885,  0.15174948, -0.15026578,  0.89157784,
        -0.23324066, -0.01014094, -0.18928407, -0.09564508, -0.09692501,
         0.27798116, -0.18627472,  0.30694136, -0.03852478,  0.27085435,
        -0.28440273, -0.05254484,  0.97408783,  0.79990137,  0.92376447,
        -0.13129678, -0.10517468,  0.06036109, -0.08028498,  0.23511052,
        -0.21042491, -0.3250528 , -0.09714904, -0.2318278 ,  0.34265715,
        -0.1475605 , -0.30490822, -0.27597314, -0.31579942,  0.23443258,
        -0.24326894,  0.78562534, -0.1337908 , -0.33254987, -0.13175169,
         0.13258094,  0.3983282 , -0.236716  , -0.26265702, -0.3695963 ,
         0.30973643, -0.30874044,  0.18243848, -0.15242837, -0.3009109 ,
        -0.25969234, -0.21650666, -0.27924046, -0.07056297, -0.21142481,
        -0.15906394, -0.306207  , -0.2627802 , -0.31472245, -0.19378348,
        -0.05995677,  1.022069  ,  0.636706  , -0.07286941, -0.18499285,
        -0.28038606,  0.1368906 ,  0.15198329,  0.71752954,  1.1104813 ,
        -0.26558146,  0.8238711 , -0.19233458, -0.1854143 ,  0.31759885,
        -0.146381  , -0.18593065, -0.13675475, -0.20170444, -0.16712724,
        -0.00164723,  0.14585523, -0.18056622,  0.72360873, -0.41427153,
        -0.00764041, -0.05752933,  0.20944679, -0.17798054,  0.12586626,
         0.1259037 , -0.22287199,  0.0444079 ,  0.3401235 ,  0.03906745,
        -0.2371698 , -0.16577421, -0.2032205 , -0.28764653, -0.12376721,
        -0.30089054, -0.075834  , -0.19585209, -0.1570221 ,  0.55330455,
         0.17516677, -0.00671412, -0.11628791, -0.25115073, -0.07640148,
         0.53846574, -0.25847486, -0.14987963,  0.15814571,  0.02292341,
        -0.04512586, -0.16207342,  0.19349058,  0.39997375,  0.3301786 ,
         0.844493  ,  0.09324032, -0.2139478 , -0.13260707, -0.26083636,
        -0.34126705, -0.23246536, -0.1741175 ,  0.45213443, -0.1172404 ,
         0.4537478 ,  0.2179842 , -0.2546497 , -0.14152235, -0.1866959 ,
        -0.24136472, -0.15965518, -0.08740373, -0.21120709, -0.19131401,
        -0.02312465, -0.1424037 , -0.12424263, -0.36425227,  0.45957452,
        -0.15135665, -0.14188159, -0.05672614,  0.4294385 , -0.09361711,
        -0.19219732, -0.21372089, -0.32322058, -0.08584573, -0.30583653,
         0.8455917 , -0.3032573 , -0.10099987, -0.26707715,  0.9182271 ,
        -0.08267958, -0.23941115,  0.726688  ,  0.8229299 , -0.20243265,
        -0.22727305, -0.21914873, -0.25882468, -0.10637616, -0.00753736,
        -0.13777769, -0.23496273,  0.5319726 , -0.09427259, -0.21627167,
         0.65407616, -0.17718542, -0.3488192 ,  0.26468042,  0.7410902 ,
        -0.16991816, -0.23623988, -0.2834435 , -0.30871546,  1.0743666 ,
        -0.16253519, -0.18778808,  0.09737102,  0.3328103 , -0.2579007 ,
         1.189676  , -0.14929953,  0.03344507, -0.08095035, -0.09273945,
         0.32589382,  0.7084502 ,  0.1370207 , -0.14819698,  0.02922156,
        -0.09053307, -0.40498376, -0.06966598, -0.11482047, -0.01499982,
        -0.22453265, -0.11453746,  0.439895  , -0.0995437 , -0.08674948,
        -0.22953579,  0.8025714 ,  0.10079399, -0.18149713, -0.20940983,
        -0.19616881, -0.23122519, -0.11368135,  0.17096575,  0.22107399,
        -0.12206952, -0.10031576,  0.36283356,  0.3261691 , -0.07967824,
        -0.1598283 , -0.13299195, -0.14447328,  0.7961639 ,  0.24629065,
        -0.1107186 , -0.2409806 , -0.07271755, -0.1496983 , -0.0646622 ,
         0.68419445, -0.6437435 , -0.16265479,  0.14343943, -0.04606699,
        -0.15428844, -0.13722235,  0.39974225,  0.7410692 , -0.06881508,
        -0.22289805,  0.09856481, -0.20454887, -0.12839113, -0.15118025,
         0.7624941 , -0.05466021, -0.11759382, -0.14394847, -0.27613777,
        -0.27737805, -0.1069457 , -0.16521774, -0.24430685,  1.0053204 ,
         0.16546829], dtype=float32),
 'probabilities': array([0.00275953, 0.00310813, 0.00409173, 0.00302513, 0.00857454,
        0.00278425, 0.00348016, 0.00290937, 0.00319496, 0.00319087,
        0.00464226, 0.00291814, 0.00477866, 0.00338277, 0.00460929,
        0.00264539, 0.00333567, 0.00931203, 0.00782341, 0.00885502,
        0.00308306, 0.00316466, 0.00373438, 0.00324441, 0.00444745,
        0.00284851, 0.00254001, 0.00319016, 0.00278819, 0.00495242,
        0.00303332, 0.00259169, 0.00266778, 0.00256362, 0.00444443,
        0.00275647, 0.00771252, 0.00307538, 0.00252104, 0.00308166,
        0.00401405, 0.00523595, 0.00277459, 0.00270354, 0.00242935,
        0.00479204, 0.00258178, 0.00421925, 0.00301859, 0.00260207,
        0.00271157, 0.00283123, 0.00265908, 0.00327611, 0.00284566,
        0.00299863, 0.00258833, 0.00270321, 0.00256638, 0.00289631,
        0.00331104, 0.00976973, 0.0066454 , 0.00326856, 0.00292188,
        0.00265603, 0.00403139, 0.00409269, 0.00720481, 0.01067283,
        0.00269565, 0.0080132 , 0.00290051, 0.00292065, 0.00482986,
        0.0030369 , 0.00291914, 0.00306628, 0.00287345, 0.00297455,
        0.00350985, 0.00406769, 0.00293484, 0.00724874, 0.00232321,
        0.00348887, 0.00331909, 0.00433476, 0.00294244, 0.00398719,
        0.00398734, 0.00281327, 0.00367527, 0.00493989, 0.0036557 ,
        0.00277333, 0.00297858, 0.0028691 , 0.00263682, 0.00310636,
        0.00260213, 0.00325889, 0.00289032, 0.00300476, 0.00611365,
        0.00418868, 0.00349211, 0.00312968, 0.00273483, 0.00325704,
        0.0060236 , 0.00271487, 0.0030263 , 0.00411799, 0.00359715,
        0.00336051, 0.00298962, 0.00426614, 0.00524457, 0.00489101,
        0.00818017, 0.0038592 , 0.00283849, 0.00307902, 0.00270847,
        0.00249915, 0.00278641, 0.00295383, 0.00552539, 0.0031267 ,
        0.00553431, 0.00437193, 0.00272528, 0.00305169, 0.00291691,
        0.00276172, 0.00299686, 0.0032214 , 0.00284628, 0.00290347,
        0.00343527, 0.00304901, 0.00310489, 0.00244237, 0.00556665,
        0.00302183, 0.0030506 , 0.00332176, 0.0054014 , 0.00320145,
        0.0029009 , 0.00283913, 0.00254467, 0.00322642, 0.00258929,
        0.00818916, 0.00259598, 0.0031779 , 0.00269162, 0.00880612,
        0.00323665, 0.00276712, 0.0072711 , 0.00800567, 0.00287136,
        0.00280092, 0.00282376, 0.00271392, 0.00316086, 0.00348923,
        0.00306314, 0.00277946, 0.00598462, 0.00319935, 0.0028319 ,
        0.00676184, 0.00294478, 0.00248035, 0.00458092, 0.00737658,
        0.00296626, 0.00277591, 0.00264792, 0.00258184, 0.01029426,
        0.00298824, 0.00291372, 0.00387517, 0.0049039 , 0.00271643,
        0.01155243, 0.00302805, 0.0036352 , 0.00324226, 0.00320426,
        0.00487009, 0.00713969, 0.00403191, 0.00303139, 0.00361988,
        0.00321133, 0.00234488, 0.00327905, 0.00313428, 0.00346329,
        0.0028086 , 0.00313517, 0.00545818, 0.00318253, 0.00322351,
        0.00279459, 0.00784433, 0.00388846, 0.00293211, 0.0028514 ,
        0.00288941, 0.00278987, 0.00313785, 0.00417112, 0.00438546,
        0.00311164, 0.00318007, 0.00505336, 0.00487144, 0.00324638,
        0.00299634, 0.00307784, 0.0030427 , 0.00779423, 0.00449745,
        0.00314716, 0.00276278, 0.00326906, 0.00302685, 0.0032955 ,
        0.0069686 , 0.00184684, 0.00298788, 0.00405787, 0.00335735,
        0.00301298, 0.00306485, 0.00524336, 0.00737642, 0.00328184,
        0.0028132 , 0.0038798 , 0.00286529, 0.00309203, 0.00302236,
        0.00753617, 0.00332863, 0.0031256 , 0.0030443 , 0.00266734,
        0.00266403, 0.00315906, 0.00298023, 0.00275361, 0.00960746,
        0.00414826], dtype=float32)}

0 个答案:

没有答案