我一直在尝试获取数据帧或csv文件中的每个纪元,但我做不到。
使用此代码,我可以获得每个时期的输出:
import tensorflow as tf
import keras
# define your custom callback for prediction
class PredictionCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs={}):
y_pred = self.model.predict(self.validation_data[0])
print('prediction: {} at epoch: {}'.format(y_pred, epoch))
# ...
autoencoder.fit(x_train,x_train,
epochs=100,
batch_size=35,
shuffle=True,
validation_data=(x_test, x_test),
callbacks=[PredictionCallback()])
我作为na输出得到的纪元(这里我只复制了3个纪元来说明):
Train on 360 samples, validate on 90 samples
Epoch 1/100
360/360 [==============================] - 0s 452us/step - loss: 0.1030 - val_loss: 0.0932
prediction: [[0.34389612 0.37337315 0.47786626 0.48220542 0.62026477]
[0.35700512 0.38369307 0.48124415 0.4824584 0.6131428 ]
[0.35337454 0.3812024 0.4795871 0.48330936 0.61319363]
[0.13948059 0.18983081 0.42813146 0.4534687 0.7879183 ]
[0.20307824 0.24314904 0.46807638 0.43815055 0.7736514 ]
[0.1998192 0.24222535 0.46196112 0.44424677 0.7657877 ]
[0.34509167 0.37455976 0.47768456 0.48285195 0.61832213]
[0.13414457 0.18191731 0.43387797 0.44261545 0.80697894]
[0.21318719 0.2532636 0.4681181 0.44286066 0.7608668 ]
[0.36793184 0.39044768 0.48761055 0.47812417 0.6167267 ]
[0.35455436 0.38038698 0.48340294 0.4788863 0.6217733 ]
[0.21094766 0.24981478 0.47118074 0.4379853 0.7698178 ]
[0.12418425 0.17579657 0.416779 0.4576224 0.79258084]
[0.14085323 0.18823138 0.43772322 0.44212955 0.80335504]
[0.19756079 0.24132124 0.45835423 0.44768292 0.76152813]
[0.14671937 0.1961593 0.43356085 0.4511346 0.7865843 ]
[0.20308566 0.24256581 0.46960002 0.4362489 0.77659965]
[0.19673023 0.24016443 0.45915073 0.44627088 0.7642524 ]
[0.35020012 0.3767052 0.48281422 0.4781311 0.6255002 ]
[0.20725232 0.24907625 0.46367356 0.44566134 0.75949323]
[0.35835826 0.38374668 0.48361555 0.47992104 0.6177386 ]
[0.20402372 0.24567202 0.46404788 0.44365358 0.76446176]
[0.12747633 0.1777477 0.42284876 0.45214012 0.798106 ]
[0.36679193 0.3891655 0.48810983 0.47710776 0.6193952 ]
[0.34052953 0.3713761 0.47563657 0.4838536 0.6185636 ]
[0.35730538 0.38642126 0.47631833 0.48880112 0.59971195]
[0.14134729 0.18990555 0.43429908 0.44679776 0.7964076 ]
[0.13267383 0.18463475 0.42017156 0.4591754 0.78416073]
[0.34687585 0.37267172 0.4848389 0.4744216 0.6347489 ]
[0.1285612 0.18061405 0.41779673 0.45940813 0.78677225]
[0.14394826 0.19354925 0.43210676 0.45125064 0.7882222 ]
[0.13141504 0.18273345 0.42154515 0.45654637 0.7889507 ]
[0.13619533 0.18744034 0.4240093 0.4565995 0.78552413]
[0.202586 0.24309975 0.4669258 0.43935475 0.77202916]
[0.33675078 0.36690074 0.47767183 0.4799488 0.6285275 ]
[0.3630827 0.39190117 0.4759083 0.49129438 0.59152013]
[0.19509727 0.23752019 0.4616915 0.44226924 0.7714783 ]
[0.14365548 0.19170165 0.4366811 0.44523817 0.7971649 ]
[0.19835302 0.24368668 0.4543566 0.45311975 0.7522273 ]
[0.20800978 0.24826795 0.467655 0.44102174 0.7665388 ]
[0.12523556 0.17708108 0.41663343 0.45856375 0.7904394 ]
[0.3488555 0.37241757 0.4890314 0.4698138 0.6431198 ]
[0.34667885 0.37659758 0.4765062 0.4848966 0.61327314]
[0.3446629 0.37111464 0.48389086 0.47485104 0.63497293]
[0.3433488 0.37114823 0.48137116 0.47757888 0.63005847]
[0.20048887 0.24463984 0.45751095 0.4501779 0.7558526 ]
[0.3637992 0.38934845 0.48230854 0.48342952 0.6077196 ]
[0.35611618 0.38505888 0.4768707 0.48769307 0.6026441 ]
[0.20313534 0.24394819 0.46616802 0.44056785 0.7698331 ]
[0.12796089 0.17899826 0.42067355 0.45527738 0.79325414]
[0.36706376 0.39267203 0.48159686 0.48543805 0.601877 ]
[0.20752326 0.24767563 0.4679206 0.4404595 0.7676877 ]
[0.3317272 0.36230478 0.4776036 0.47826046 0.63452303]
[0.19272554 0.23683271 0.4571363 0.44682062 0.76560134]
[0.13748333 0.18681785 0.43041503 0.4492419 0.7954247 ]
[0.19966611 0.24347341 0.45834872 0.44872034 0.7586832 ]
[0.33442613 0.36698973 0.47310665 0.48491305 0.6194912 ]
[0.21322897 0.2527365 0.46954548 0.441091 0.76369417]
[0.35778782 0.38809884 0.47384343 0.4921052 0.59249157]
[0.196098 0.23905295 0.46035945 0.44443798 0.767516 ]
[0.21373272 0.2530284 0.47007236 0.4406641 0.7641164 ]
[0.14440677 0.19148055 0.4398795 0.4416542 0.80179286]
[0.14438662 0.1904681 0.44288903 0.4378359 0.8071472 ]
[0.36326277 0.39048672 0.47905794 0.4873635 0.5997318 ]
[0.33750865 0.36796397 0.47692457 0.48116165 0.6256533 ]
[0.13969895 0.18784368 0.43494454 0.44491437 0.8001623 ]
[0.21023598 0.24946672 0.47026443 0.4387992 0.7689006 ]
[0.21659935 0.2561568 0.46938002 0.4428502 0.75909597]
[0.34917396 0.37839937 0.4774884 0.48451743 0.6127981 ]
[0.3608024 0.3865621 0.48242915 0.48225665 0.6116648 ]
[0.20789057 0.24737614 0.4696203 0.4385049 0.77058303]
[0.20092338 0.2424278 0.46432766 0.44181022 0.7690537 ]
[0.14367282 0.18989041 0.44224328 0.43820238 0.8070762 ]
[0.34553868 0.37238124 0.48294276 0.4763538 0.6314702 ]
[0.34945858 0.37790424 0.47901568 0.4826815 0.6164701 ]
[0.19864962 0.24241135 0.45841193 0.4481441 0.7601801 ]
[0.1464892 0.19345969 0.44086337 0.44170973 0.8004205 ]
[0.2016359 0.24118432 0.46940258 0.4357999 0.77804554]
[0.20230329 0.24341127 0.46538913 0.44114456 0.7693664 ]
[0.20564258 0.24739474 0.46382207 0.4447102 0.76189613]
[0.14310995 0.19170648 0.43482554 0.44725257 0.79461265]
[0.20061046 0.24167252 0.46546033 0.44023833 0.77169263]
[0.1395224 0.18776402 0.43458012 0.44526353 0.7997805 ]
[0.3475588 0.3757964 0.47976196 0.48107868 0.62074584]
[0.14598793 0.19323191 0.43988067 0.44264248 0.79940563]
[0.20084766 0.24182764 0.46567947 0.44007844 0.7718169 ]
[0.2142604 0.2540593 0.4688056 0.44249493 0.7608948 ]
[0.1270327 0.17875427 0.41801888 0.45804867 0.78988326]
[0.13120902 0.18291286 0.42024544 0.4580799 0.78683406]
[0.14218849 0.19055101 0.4352023 0.44618642 0.7967429 ]] at epoch: 0
Epoch 2/100
360/360 [==============================] - 0s 47us/step - loss: 0.0979 - val_loss: 0.0893
prediction: [[0.35732222 0.3843055 0.4859484 0.48535848 0.614601 ]
[0.36997944 0.39424476 0.48890743 0.48561424 0.6075097 ]
[0.36640897 0.39180788 0.48731565 0.4864407 0.60757226]
[0.15470544 0.2043908 0.4453513 0.45744687 0.7822104 ]
[0.22013906 0.25799784 0.48291984 0.44272035 0.7675872 ]
[0.21660301 0.25693047 0.47665784 0.44863018 0.75973165]
[0.35839027 0.38539642 0.48566985 0.48600549 0.61266136]
[0.14958984 0.19669074 0.4520371 0.4468543 0.8013898 ]
[0.23011923 0.26790065 0.48231584 0.44729978 0.75475234]
[0.38086855 0.40094742 0.49519378 0.48138314 0.61104286]
[0.36780554 0.39116603 0.49130145 0.48215204 0.61605173]
[0.22799313 0.26455355 0.48564622 0.44259876 0.76370716]
[0.13870358 0.19014329 0.43467307 0.46143445 0.78697145]
[0.15658155 0.20308846 0.4555388 0.44638947 0.7977196 ]
[0.21404067 0.25584692 0.47287792 0.452028 0.7554476 ]
[0.16231912 0.21085566 0.45056292 0.45516118 0.78085285]
[0.22022358 0.2574654 0.48453245 0.44087476 0.77054095]
[0.21333212 0.25478655 0.4738298 0.45062605 0.75818825]
[0.3636607 0.38765424 0.49089357 0.48140967 0.6197581 ]
[0.22405118 0.26369625 0.47799864 0.45000288 0.75340843]
[0.3714857 0.39441207 0.49138355 0.4831366 0.6120605 ]
[0.220752 0.26028338 0.47850764 0.44811335 0.7583585 ]
[0.14232019 0.19224715 0.4408622 0.45610636 0.79250515]
[0.3798129 0.39973748 0.49576357 0.4804052 0.61367524]
[0.35395893 0.38231993 0.4837401 0.48695916 0.6129259 ]
[0.36994076 0.39669037 0.48369867 0.49174997 0.59424824]
[0.15691277 0.20465478 0.45181996 0.4509385 0.79073536]
[0.14741346 0.19898406 0.4374162 0.46300626 0.7784644 ]
[0.36057183 0.38382792 0.49314764 0.4778388 0.6289046 ]
[0.14315149 0.19493252 0.43528277 0.463211 0.78111166]
[0.15941921 0.2081998 0.4492206 0.45528746 0.7825019 ]
[0.1462234 0.19714814 0.439036 0.46042946 0.7832836 ]
[0.15118873 0.20190248 0.44123423 0.4604865 0.779825 ]
[0.21957254 0.25790486 0.48171806 0.44389832 0.76595926]
[0.35038683 0.37803873 0.4860064 0.4832129 0.6227732 ]
[0.37542376 0.4019188 0.48302317 0.494155 0.5861449 ]
[0.21179065 0.252222 0.47659698 0.44676298 0.7654203 ]
[0.15938073 0.20652103 0.4541801 0.44941494 0.79148793]
[0.21466365 0.25808823 0.46861535 0.45728454 0.74615365]
[0.22492918 0.2629586 0.48211256 0.44553623 0.76043725]
[0.13977218 0.19142461 0.4344142 0.462351 0.78481704]
[0.36268726 0.3836893 0.49745688 0.47339272 0.6371733 ]
[0.35981232 0.38729733 0.48434967 0.48798668 0.6076673 ]
[0.35836428 0.38228482 0.49222758 0.47826657 0.6291283 ]
[0.35698575 0.38226214 0.48964757 0.480896 0.62427586]
[0.21700248 0.2591525 0.47186983 0.4544009 0.7497854 ]
[0.37650543 0.3996746 0.4897336 0.4865505 0.60213363]
[0.3688224 0.39539248 0.48432076 0.49068648 0.5971417 ]
[0.21998331 0.25865078 0.4808128 0.44511697 0.7637404 ]
[0.14272577 0.19344449 0.43848625 0.45915654 0.78763115]
[0.37954012 0.4028109 0.48883012 0.4885005 0.5963464 ]
[0.22458777 0.2624728 0.4825138 0.44493523 0.76161456]
[0.3456758 0.37369645 0.4862094 0.4815521 0.6287333 ]
[0.20922479 0.2514378 0.47193563 0.45117268 0.7595469 ]
[0.15280128 0.20146582 0.4479947 0.45332208 0.7897587 ]
[0.21615177 0.25797677 0.47275183 0.45302525 0.7525966 ]
[0.3478763 0.37798136 0.48129317 0.4880186 0.61384493]
[0.23026142 0.26744086 0.4838441 0.44556963 0.7575892 ]
[0.37015066 0.3981576 0.48101804 0.49497786 0.5870967 ]
[0.21280283 0.25375354 0.4751804 0.44883034 0.7614647 ]
[0.23086849 0.26779807 0.48442724 0.4451209 0.75802517]
[0.16032186 0.20641011 0.4575643 0.44590983 0.79614353]
[0.16039735 0.2054488 0.46074235 0.44222015 0.8015151 ]
[0.37576467 0.40064627 0.48631835 0.49036396 0.594242 ]
[0.351057 0.37903073 0.4851778 0.48439184 0.61991966]
[0.15533501 0.20266908 0.4527164 0.44906908 0.794526 ]
[0.22735363 0.26426458 0.48479953 0.44333953 0.7628144 ]
[0.23359635 0.27079648 0.4834756 0.44726843 0.7529824 ]
[0.36239177 0.38913894 0.48535174 0.48756993 0.6072227 ]
[0.3736729 0.39702362 0.48999622 0.48541307 0.6060443 ]
[0.22492436 0.26214594 0.48420691 0.44309115 0.764485 ]
[0.21782109 0.25719512 0.47909576 0.44626558 0.76299584]
[0.15965301 0.2048631 0.46012452 0.4425694 0.8014504 ]
[0.35923162 0.38352633 0.49123988 0.47968817 0.6256813 ]
[0.3626297 0.38862965 0.48687848 0.48583895 0.6108204 ]
[0.21529478 0.25704074 0.47299138 0.4524002 0.7541275 ]
[0.16242692 0.20836812 0.45838898 0.4459951 0.7947432 ]
[0.21876335 0.25609392 0.48440838 0.4404432 0.7719944 ]
[0.21922132 0.25817454 0.4801142 0.44563147 0.7632961 ]
[0.22234026 0.26196405 0.47816664 0.4491368 0.7557899 ]
[0.15871033 0.2064518 0.45221338 0.45139268 0.7889187 ]
[0.21754101 0.25646544 0.48029417 0.44475472 0.76562977]
[0.15510672 0.20255497 0.4523122 0.44942728 0.7941356 ]
[0.36078405 0.3865859 0.48770133 0.48432088 0.61502707]
[0.16188675 0.20812565 0.4573929 0.44689578 0.7937268 ]
[0.21768725 0.2565524 0.4804409 0.44464153 0.7657378 ]
[0.23129413 0.26875257 0.48303518 0.4469082 0.754797 ]
[0.14159614 0.19307613 0.43566018 0.4618925 0.7842368 ]
[0.146 0.1973272 0.4377039 0.4618938 0.781173 ]
[0.15777522 0.20529687 0.4526826 0.45036286 0.79106134]] at epoch: 1
Epoch 3/100
360/360 [==============================] - 0s 54us/step - loss: 0.0939 - val_loss: 0.0858
prediction: [[0.3701952 0.3945966 0.49327147 0.4881377 0.6095992 ]
[0.3824011 0.40416604 0.4958423 0.48839453 0.60254055]
[0.3788927 0.40178216 0.49431235 0.48919833 0.6026142 ]
[0.17017335 0.21866179 0.46118838 0.46103063 0.7770189 ]
[0.23705906 0.2723304 0.4964591 0.4468463 0.7620789 ]
[0.23325539 0.2711235 0.49007675 0.45258096 0.7542403 ]
[0.37114 0.3955961 0.49290463 0.48878393 0.60766464]
[0.16535434 0.21122459 0.4687259 0.45068732 0.7962855 ]
[0.24684757 0.2819891 0.49526238 0.4513008 0.74921185]
[0.39323795 0.41081223 0.50204635 0.4842617 0.606022 ]
但是如何在数据框或csv文件中获取此信息,以便以后可以使用它通过d3在我的应用程序中进行可视化显示。
我尝试了在keras文档中找到的方法:
from keras.callbacks import CSVLogger
csv_logger = CSVLogger("model_history_log.csv", append=True)
autoencoder.fit(x_train,x_train,
epochs=200,
batch_size=35,
shuffle=True,
validation_data=(x_test, x_test),
callbacks=[csv_logger])
但这并不能为我提供每个时期的预测,而只是将csv中各个时期的损失值保存下来。
答案 0 :(得分:2)
将保存内容放入回调函数中
filename='training'
def on_epoch_end(self, epoch, logs={}):
y_pred = self.model.predict(self.validation_data[0])
print('prediction: {} at epoch: {}'.format(y_pred, epoch))
pd.DataFrame(y_pred).assign(epoch=epoch).to_csv('{}_{}.csv'.format(filename, epoch))
答案 1 :(得分:0)
扩展凯南的护腕。
由于根据Tensorflow回调文档弃用了validate_data,因此需要完成以下操作才能实现目标:
class Metrics(Callback):
def __init__(self, val_data):
super().__init__()
self.validation_data = val_data
"""the rest is Kenan's code"""
def on_epoch_end(self, epoch, logs={}):
y_pred = self.model.predict(self.validation_data[0])
print('prediction: {} at epoch: {}'.format(y_pred, epoch))
pd.DataFrame(y_pred).assign(epoch=epoch).to_csv('{}_{}.csv'.format(filename, epoch))
此外,您可以改用模型名称作为文件名。由于model.name变量是只读的,因此请执行以下操作:
model._name = 'SOME_NAME'
,然后在Metrics类中使用model._name。