How to save prediction data at each epoch in Keras

Asked: 2020-01-24 20:31:44

Tags: python tensorflow keras

I have been trying to get the predictions from each epoch into a dataframe or CSV file, but I haven't been able to.

With this code I can print the predictions at each epoch:

import tensorflow as tf
import keras

# define your custom callback for prediction
class PredictionCallback(tf.keras.callbacks.Callback):    
  def on_epoch_end(self, epoch, logs={}):
    y_pred = self.model.predict(self.validation_data[0])
    print('prediction: {} at epoch: {}'.format(y_pred, epoch))

# ...
autoencoder.fit(x_train, x_train,
                epochs=100,
                batch_size=35,
                shuffle=True,
                validation_data=(x_test, x_test),
                callbacks=[PredictionCallback()])

The per-epoch output I get (I've copied only 3 epochs here to illustrate):

Train on 360 samples, validate on 90 samples
Epoch 1/100
360/360 [==============================] - 0s 452us/step - loss: 0.1030 - val_loss: 0.0932
prediction: [[0.34389612 0.37337315 0.47786626 0.48220542 0.62026477]
 [0.35700512 0.38369307 0.48124415 0.4824584  0.6131428 ]
 [0.35337454 0.3812024  0.4795871  0.48330936 0.61319363]
 [0.13948059 0.18983081 0.42813146 0.4534687  0.7879183 ]
 [0.20307824 0.24314904 0.46807638 0.43815055 0.7736514 ]
 [0.1998192  0.24222535 0.46196112 0.44424677 0.7657877 ]
 [0.34509167 0.37455976 0.47768456 0.48285195 0.61832213]
 [0.13414457 0.18191731 0.43387797 0.44261545 0.80697894]
 [0.21318719 0.2532636  0.4681181  0.44286066 0.7608668 ]
 [0.36793184 0.39044768 0.48761055 0.47812417 0.6167267 ]
 [0.35455436 0.38038698 0.48340294 0.4788863  0.6217733 ]
 [0.21094766 0.24981478 0.47118074 0.4379853  0.7698178 ]
 [0.12418425 0.17579657 0.416779   0.4576224  0.79258084]
 [0.14085323 0.18823138 0.43772322 0.44212955 0.80335504]
 [0.19756079 0.24132124 0.45835423 0.44768292 0.76152813]
 [0.14671937 0.1961593  0.43356085 0.4511346  0.7865843 ]
 [0.20308566 0.24256581 0.46960002 0.4362489  0.77659965]
 [0.19673023 0.24016443 0.45915073 0.44627088 0.7642524 ]
 [0.35020012 0.3767052  0.48281422 0.4781311  0.6255002 ]
 [0.20725232 0.24907625 0.46367356 0.44566134 0.75949323]
 [0.35835826 0.38374668 0.48361555 0.47992104 0.6177386 ]
 [0.20402372 0.24567202 0.46404788 0.44365358 0.76446176]
 [0.12747633 0.1777477  0.42284876 0.45214012 0.798106  ]
 [0.36679193 0.3891655  0.48810983 0.47710776 0.6193952 ]
 [0.34052953 0.3713761  0.47563657 0.4838536  0.6185636 ]
 [0.35730538 0.38642126 0.47631833 0.48880112 0.59971195]
 [0.14134729 0.18990555 0.43429908 0.44679776 0.7964076 ]
 [0.13267383 0.18463475 0.42017156 0.4591754  0.78416073]
 [0.34687585 0.37267172 0.4848389  0.4744216  0.6347489 ]
 [0.1285612  0.18061405 0.41779673 0.45940813 0.78677225]
 [0.14394826 0.19354925 0.43210676 0.45125064 0.7882222 ]
 [0.13141504 0.18273345 0.42154515 0.45654637 0.7889507 ]
 [0.13619533 0.18744034 0.4240093  0.4565995  0.78552413]
 [0.202586   0.24309975 0.4669258  0.43935475 0.77202916]
 [0.33675078 0.36690074 0.47767183 0.4799488  0.6285275 ]
 [0.3630827  0.39190117 0.4759083  0.49129438 0.59152013]
 [0.19509727 0.23752019 0.4616915  0.44226924 0.7714783 ]
 [0.14365548 0.19170165 0.4366811  0.44523817 0.7971649 ]
 [0.19835302 0.24368668 0.4543566  0.45311975 0.7522273 ]
 [0.20800978 0.24826795 0.467655   0.44102174 0.7665388 ]
 [0.12523556 0.17708108 0.41663343 0.45856375 0.7904394 ]
 [0.3488555  0.37241757 0.4890314  0.4698138  0.6431198 ]
 [0.34667885 0.37659758 0.4765062  0.4848966  0.61327314]
 [0.3446629  0.37111464 0.48389086 0.47485104 0.63497293]
 [0.3433488  0.37114823 0.48137116 0.47757888 0.63005847]
 [0.20048887 0.24463984 0.45751095 0.4501779  0.7558526 ]
 [0.3637992  0.38934845 0.48230854 0.48342952 0.6077196 ]
 [0.35611618 0.38505888 0.4768707  0.48769307 0.6026441 ]
 [0.20313534 0.24394819 0.46616802 0.44056785 0.7698331 ]
 [0.12796089 0.17899826 0.42067355 0.45527738 0.79325414]
 [0.36706376 0.39267203 0.48159686 0.48543805 0.601877  ]
 [0.20752326 0.24767563 0.4679206  0.4404595  0.7676877 ]
 [0.3317272  0.36230478 0.4776036  0.47826046 0.63452303]
 [0.19272554 0.23683271 0.4571363  0.44682062 0.76560134]
 [0.13748333 0.18681785 0.43041503 0.4492419  0.7954247 ]
 [0.19966611 0.24347341 0.45834872 0.44872034 0.7586832 ]
 [0.33442613 0.36698973 0.47310665 0.48491305 0.6194912 ]
 [0.21322897 0.2527365  0.46954548 0.441091   0.76369417]
 [0.35778782 0.38809884 0.47384343 0.4921052  0.59249157]
 [0.196098   0.23905295 0.46035945 0.44443798 0.767516  ]
 [0.21373272 0.2530284  0.47007236 0.4406641  0.7641164 ]
 [0.14440677 0.19148055 0.4398795  0.4416542  0.80179286]
 [0.14438662 0.1904681  0.44288903 0.4378359  0.8071472 ]
 [0.36326277 0.39048672 0.47905794 0.4873635  0.5997318 ]
 [0.33750865 0.36796397 0.47692457 0.48116165 0.6256533 ]
 [0.13969895 0.18784368 0.43494454 0.44491437 0.8001623 ]
 [0.21023598 0.24946672 0.47026443 0.4387992  0.7689006 ]
 [0.21659935 0.2561568  0.46938002 0.4428502  0.75909597]
 [0.34917396 0.37839937 0.4774884  0.48451743 0.6127981 ]
 [0.3608024  0.3865621  0.48242915 0.48225665 0.6116648 ]
 [0.20789057 0.24737614 0.4696203  0.4385049  0.77058303]
 [0.20092338 0.2424278  0.46432766 0.44181022 0.7690537 ]
 [0.14367282 0.18989041 0.44224328 0.43820238 0.8070762 ]
 [0.34553868 0.37238124 0.48294276 0.4763538  0.6314702 ]
 [0.34945858 0.37790424 0.47901568 0.4826815  0.6164701 ]
 [0.19864962 0.24241135 0.45841193 0.4481441  0.7601801 ]
 [0.1464892  0.19345969 0.44086337 0.44170973 0.8004205 ]
 [0.2016359  0.24118432 0.46940258 0.4357999  0.77804554]
 [0.20230329 0.24341127 0.46538913 0.44114456 0.7693664 ]
 [0.20564258 0.24739474 0.46382207 0.4447102  0.76189613]
 [0.14310995 0.19170648 0.43482554 0.44725257 0.79461265]
 [0.20061046 0.24167252 0.46546033 0.44023833 0.77169263]
 [0.1395224  0.18776402 0.43458012 0.44526353 0.7997805 ]
 [0.3475588  0.3757964  0.47976196 0.48107868 0.62074584]
 [0.14598793 0.19323191 0.43988067 0.44264248 0.79940563]
 [0.20084766 0.24182764 0.46567947 0.44007844 0.7718169 ]
 [0.2142604  0.2540593  0.4688056  0.44249493 0.7608948 ]
 [0.1270327  0.17875427 0.41801888 0.45804867 0.78988326]
 [0.13120902 0.18291286 0.42024544 0.4580799  0.78683406]
 [0.14218849 0.19055101 0.4352023  0.44618642 0.7967429 ]] at epoch: 0
Epoch 2/100
360/360 [==============================] - 0s 47us/step - loss: 0.0979 - val_loss: 0.0893
prediction: [[0.35732222 0.3843055  0.4859484  0.48535848 0.614601  ]
 [0.36997944 0.39424476 0.48890743 0.48561424 0.6075097 ]
 [0.36640897 0.39180788 0.48731565 0.4864407  0.60757226]
 [0.15470544 0.2043908  0.4453513  0.45744687 0.7822104 ]
 [0.22013906 0.25799784 0.48291984 0.44272035 0.7675872 ]
 [0.21660301 0.25693047 0.47665784 0.44863018 0.75973165]
 [0.35839027 0.38539642 0.48566985 0.48600549 0.61266136]
 [0.14958984 0.19669074 0.4520371  0.4468543  0.8013898 ]
 [0.23011923 0.26790065 0.48231584 0.44729978 0.75475234]
 [0.38086855 0.40094742 0.49519378 0.48138314 0.61104286]
 [0.36780554 0.39116603 0.49130145 0.48215204 0.61605173]
 [0.22799313 0.26455355 0.48564622 0.44259876 0.76370716]
 [0.13870358 0.19014329 0.43467307 0.46143445 0.78697145]
 [0.15658155 0.20308846 0.4555388  0.44638947 0.7977196 ]
 [0.21404067 0.25584692 0.47287792 0.452028   0.7554476 ]
 [0.16231912 0.21085566 0.45056292 0.45516118 0.78085285]
 [0.22022358 0.2574654  0.48453245 0.44087476 0.77054095]
 [0.21333212 0.25478655 0.4738298  0.45062605 0.75818825]
 [0.3636607  0.38765424 0.49089357 0.48140967 0.6197581 ]
 [0.22405118 0.26369625 0.47799864 0.45000288 0.75340843]
 [0.3714857  0.39441207 0.49138355 0.4831366  0.6120605 ]
 [0.220752   0.26028338 0.47850764 0.44811335 0.7583585 ]
 [0.14232019 0.19224715 0.4408622  0.45610636 0.79250515]
 [0.3798129  0.39973748 0.49576357 0.4804052  0.61367524]
 [0.35395893 0.38231993 0.4837401  0.48695916 0.6129259 ]
 [0.36994076 0.39669037 0.48369867 0.49174997 0.59424824]
 [0.15691277 0.20465478 0.45181996 0.4509385  0.79073536]
 [0.14741346 0.19898406 0.4374162  0.46300626 0.7784644 ]
 [0.36057183 0.38382792 0.49314764 0.4778388  0.6289046 ]
 [0.14315149 0.19493252 0.43528277 0.463211   0.78111166]
 [0.15941921 0.2081998  0.4492206  0.45528746 0.7825019 ]
 [0.1462234  0.19714814 0.439036   0.46042946 0.7832836 ]
 [0.15118873 0.20190248 0.44123423 0.4604865  0.779825  ]
 [0.21957254 0.25790486 0.48171806 0.44389832 0.76595926]
 [0.35038683 0.37803873 0.4860064  0.4832129  0.6227732 ]
 [0.37542376 0.4019188  0.48302317 0.494155   0.5861449 ]
 [0.21179065 0.252222   0.47659698 0.44676298 0.7654203 ]
 [0.15938073 0.20652103 0.4541801  0.44941494 0.79148793]
 [0.21466365 0.25808823 0.46861535 0.45728454 0.74615365]
 [0.22492918 0.2629586  0.48211256 0.44553623 0.76043725]
 [0.13977218 0.19142461 0.4344142  0.462351   0.78481704]
 [0.36268726 0.3836893  0.49745688 0.47339272 0.6371733 ]
 [0.35981232 0.38729733 0.48434967 0.48798668 0.6076673 ]
 [0.35836428 0.38228482 0.49222758 0.47826657 0.6291283 ]
 [0.35698575 0.38226214 0.48964757 0.480896   0.62427586]
 [0.21700248 0.2591525  0.47186983 0.4544009  0.7497854 ]
 [0.37650543 0.3996746  0.4897336  0.4865505  0.60213363]
 [0.3688224  0.39539248 0.48432076 0.49068648 0.5971417 ]
 [0.21998331 0.25865078 0.4808128  0.44511697 0.7637404 ]
 [0.14272577 0.19344449 0.43848625 0.45915654 0.78763115]
 [0.37954012 0.4028109  0.48883012 0.4885005  0.5963464 ]
 [0.22458777 0.2624728  0.4825138  0.44493523 0.76161456]
 [0.3456758  0.37369645 0.4862094  0.4815521  0.6287333 ]
 [0.20922479 0.2514378  0.47193563 0.45117268 0.7595469 ]
 [0.15280128 0.20146582 0.4479947  0.45332208 0.7897587 ]
 [0.21615177 0.25797677 0.47275183 0.45302525 0.7525966 ]
 [0.3478763  0.37798136 0.48129317 0.4880186  0.61384493]
 [0.23026142 0.26744086 0.4838441  0.44556963 0.7575892 ]
 [0.37015066 0.3981576  0.48101804 0.49497786 0.5870967 ]
 [0.21280283 0.25375354 0.4751804  0.44883034 0.7614647 ]
 [0.23086849 0.26779807 0.48442724 0.4451209  0.75802517]
 [0.16032186 0.20641011 0.4575643  0.44590983 0.79614353]
 [0.16039735 0.2054488  0.46074235 0.44222015 0.8015151 ]
 [0.37576467 0.40064627 0.48631835 0.49036396 0.594242  ]
 [0.351057   0.37903073 0.4851778  0.48439184 0.61991966]
 [0.15533501 0.20266908 0.4527164  0.44906908 0.794526  ]
 [0.22735363 0.26426458 0.48479953 0.44333953 0.7628144 ]
 [0.23359635 0.27079648 0.4834756  0.44726843 0.7529824 ]
 [0.36239177 0.38913894 0.48535174 0.48756993 0.6072227 ]
 [0.3736729  0.39702362 0.48999622 0.48541307 0.6060443 ]
 [0.22492436 0.26214594 0.48420691 0.44309115 0.764485  ]
 [0.21782109 0.25719512 0.47909576 0.44626558 0.76299584]
 [0.15965301 0.2048631  0.46012452 0.4425694  0.8014504 ]
 [0.35923162 0.38352633 0.49123988 0.47968817 0.6256813 ]
 [0.3626297  0.38862965 0.48687848 0.48583895 0.6108204 ]
 [0.21529478 0.25704074 0.47299138 0.4524002  0.7541275 ]
 [0.16242692 0.20836812 0.45838898 0.4459951  0.7947432 ]
 [0.21876335 0.25609392 0.48440838 0.4404432  0.7719944 ]
 [0.21922132 0.25817454 0.4801142  0.44563147 0.7632961 ]
 [0.22234026 0.26196405 0.47816664 0.4491368  0.7557899 ]
 [0.15871033 0.2064518  0.45221338 0.45139268 0.7889187 ]
 [0.21754101 0.25646544 0.48029417 0.44475472 0.76562977]
 [0.15510672 0.20255497 0.4523122  0.44942728 0.7941356 ]
 [0.36078405 0.3865859  0.48770133 0.48432088 0.61502707]
 [0.16188675 0.20812565 0.4573929  0.44689578 0.7937268 ]
 [0.21768725 0.2565524  0.4804409  0.44464153 0.7657378 ]
 [0.23129413 0.26875257 0.48303518 0.4469082  0.754797  ]
 [0.14159614 0.19307613 0.43566018 0.4618925  0.7842368 ]
 [0.146      0.1973272  0.4377039  0.4618938  0.781173  ]
 [0.15777522 0.20529687 0.4526826  0.45036286 0.79106134]] at epoch: 1
Epoch 3/100
360/360 [==============================] - 0s 54us/step - loss: 0.0939 - val_loss: 0.0858
prediction: [[0.3701952  0.3945966  0.49327147 0.4881377  0.6095992 ]
 [0.3824011  0.40416604 0.4958423  0.48839453 0.60254055]
 [0.3788927  0.40178216 0.49431235 0.48919833 0.6026142 ]
 [0.17017335 0.21866179 0.46118838 0.46103063 0.7770189 ]
 [0.23705906 0.2723304  0.4964591  0.4468463  0.7620789 ]
 [0.23325539 0.2711235  0.49007675 0.45258096 0.7542403 ]
 [0.37114    0.3955961  0.49290463 0.48878393 0.60766464]
 [0.16535434 0.21122459 0.4687259  0.45068732 0.7962855 ]
 [0.24684757 0.2819891  0.49526238 0.4513008  0.74921185]
 [0.39323795 0.41081223 0.50204635 0.4842617  0.606022  ]

But how can I get this into a dataframe or CSV file, so I can later use it for visualization in my app with d3?

I tried the approach I found in the Keras documentation:

from keras.callbacks import CSVLogger

csv_logger = CSVLogger("model_history_log.csv", append=True)
autoencoder.fit(x_train, x_train,
                epochs=200,
                batch_size=35,
                shuffle=True,
                validation_data=(x_test, x_test),
                callbacks=[csv_logger])

But this doesn't give me the predictions at each epoch; it only saves the loss values for each epoch to the CSV.

2 Answers:

Answer 0 (score: 2)

Put the saving inside the callback function:

import pandas as pd

filename = 'training'

class PredictionCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs={}):
    y_pred = self.model.predict(self.validation_data[0])
    print('prediction: {} at epoch: {}'.format(y_pred, epoch))
    # Write this epoch's predictions to their own CSV, tagged with the epoch number
    pd.DataFrame(y_pred).assign(epoch=epoch).to_csv('{}_{}.csv'.format(filename, epoch))
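If you'd rather end up with a single file instead of one CSV per epoch, the per-epoch frames can be accumulated and concatenated once at the end. A minimal sketch of just that saving logic, using small hypothetical stand-in arrays in place of real `model.predict` output:

```python
import numpy as np
import pandas as pd

# Hypothetical per-epoch predictions (stand-ins for model.predict output):
# each row is one validation sample, each column one output unit.
preds_by_epoch = {
    0: np.array([[0.34, 0.37], [0.36, 0.38]]),
    1: np.array([[0.35, 0.38], [0.37, 0.39]]),
}

# Accumulate every epoch's predictions into one long-form DataFrame,
# tagging each block with its epoch (same idea as .assign(epoch=epoch) above).
frames = [pd.DataFrame(p).assign(epoch=e) for e, p in preds_by_epoch.items()]
history = pd.concat(frames, ignore_index=True)
history.to_csv('predictions_all_epochs.csv', index=False)
```

In a real callback you would append to `frames` inside `on_epoch_end` and do the `concat`/`to_csv` in `on_train_end`; the long-form `epoch` column makes the file easy to filter or facet in d3.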

Answer 1 (score: 0)

Expanding on Kenan's answer.

Since `validation_data` was deprecated per the TensorFlow callback documentation, you need to do the following to achieve the goal:

import pandas as pd
from tensorflow.keras.callbacks import Callback

filename = 'training'

class Metrics(Callback):
    def __init__(self, val_data):
        super().__init__()
        self.validation_data = val_data

    # the rest is Kenan's code
    def on_epoch_end(self, epoch, logs={}):
        y_pred = self.model.predict(self.validation_data[0])
        print('prediction: {} at epoch: {}'.format(y_pred, epoch))
        pd.DataFrame(y_pred).assign(epoch=epoch).to_csv('{}_{}.csv'.format(filename, epoch))

Alternatively, you can use the model's name as the filename. Since the model.name attribute is read-only, do the following:

model._name = 'SOME_NAME'

and then use model._name inside the Metrics class.
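A small sketch of just this name handling, using a stand-in object in place of a real Keras model (inside an actual `Callback`, `self.model` would be the real model; the `autoencoder_run1` name is only an example):

```python
# Stand-in for a Keras model: only the private _name attribute matters here.
class FakeModel:
    def __init__(self):
        self._name = 'SOME_NAME'

model = FakeModel()
model._name = 'autoencoder_run1'  # assignable, unlike the read-only .name property

# Build the per-epoch CSV filename from the model name, as in the Metrics class.
csv_name = '{}_{}.csv'.format(model._name, 3)
print(csv_name)
```

Note that `_name` is a private attribute of the Keras `Model` class, so this trick may break in future Keras versions.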