我的代码如下:
def _optimal_threshold(metrics):
    """Return the last index label of *metrics* whose ``fbeta`` equals the column maximum.

    Equivalent to ``metrics.query('fbeta==fbeta.max()', engine='python').index[-1]``.
    Raises IndexError when no row matches (e.g. an empty frame).
    """
    best = (metrics['fbeta'] == metrics['fbeta'].max()).to_numpy()
    return metrics.index[best][-1]


def _safe_filename(name):
    """Replace characters Windows forbids in file names (<>:"/\\|?* and control chars) with '_'."""
    import re

    return re.sub(r'[<>:"/\\|?*\x00-\x1f]', '_', name)


def distplot_stripplot_buy_signal_3(y_train, y_dev, y_test, y, plot_data_stripplot_distplot, performance_metrics_3_test,
                                    performance_metrics_3_plus_day_train, performance_metrics_3_plus_day_test=None,
                                    path='D:\\DIGITAL_LIBRARY\\Elpis\\cu2003\\Eddie\\2\\'):
    """Plot predicted-probability distributions per ground-truth class and save them as a PDF.

    The upper panel overlays ``sns.distplot`` curves of ``y_test_pred`` for rows with
    ``y_test == 1`` and ``y_test == 0``. The lower panel is a ``sns.stripplot`` of
    ``y_test_pred`` grouped by ``y_test``. Both panels draw a line at the index label
    with the highest ``fbeta`` in ``performance_metrics_3_test`` (black) and in
    ``performance_metrics_3_plus_day_train`` (gray).

    Parameters
    ----------
    y_train, y_dev, y_test :
        Only their lengths are used to locate the first row after the test set in
        ``y``. ``y_train`` additionally supplies the contract (``y_train.name[1]``)
        and the day of its first timestamp, so it needs a DatetimeIndex.
    y :
        Full series with a DatetimeIndex. It must contain at least
        ``len(y_train) + len(y_dev) + len(y_test) + 301`` rows.
    plot_data_stripplot_distplot : DataFrame
        Needs the columns ``y_test`` (0/1 outcome) and ``y_test_pred``.
    performance_metrics_3_test, performance_metrics_3_plus_day_train : DataFrame
        Need an ``fbeta`` column. The index label of the best row is drawn as the threshold.
    performance_metrics_3_plus_day_test : DataFrame, optional
        Supplies ``start_time``, ``end_time`` and ``duration`` for the title. Earlier
        versions read it from the global namespace; that still happens when it is
        not passed.
    path : str, optional
        Directory the PDF is written to.

    Raises
    ------
    NameError
        If ``performance_metrics_3_plus_day_test`` is neither passed nor defined globally.

    NOTE(review): ``sns.distplot`` is deprecated in recent seaborn releases; consider
    ``sns.histplot(..., kde=True, stat='density')`` once the look has been checked.
    """
    import os

    if performance_metrics_3_plus_day_test is None:
        # Backward compatibility: the function used to pick this frame up from globals.
        try:
            performance_metrics_3_plus_day_test = globals()['performance_metrics_3_plus_day_test']
        except KeyError:
            raise NameError(
                "performance_metrics_3_plus_day_test was not passed and is not defined globally"
            ) from None

    # The forecasting horizon spans rows n_all .. n_all + horizon_steps of y.
    horizon_steps = 300
    n_train = y_train.shape[0]
    n_all = n_train + y_dev.shape[0] + y_test.shape[0]
    horizon_start = y.index[n_all]
    horizon_end = y.index[n_all + horizon_steps]
    start_forecasting_horizon = horizon_start.strftime('%Y-%m-%d %H-%M-%S')
    end_forecasting_horizon = horizon_end.strftime('%Y-%m-%d %H-%M-%S')
    timespan_forecasting_horizon = round((horizon_end - horizon_start).total_seconds())

    # Compute each optimal threshold once and reuse it in both panels.
    threshold_test = _optimal_threshold(performance_metrics_3_test)
    threshold_train = _optimal_threshold(performance_metrics_3_plus_day_train)
    test_label = 'optimal threshold using buy_signal_3 on the Test set'
    train_label = 'optimal threshold using buy_signal_3_plus_day on the Train set'
    panel_title = ('Distribution of Probability Scores by Class \n'
                   'What is the probability to Profit in the next 3-5 minutes -- if you buy now')
    data = plot_data_stripplot_distplot

    fig, axs = plt.subplots(2, 1, figsize=(18, 18))

    # Upper panel: density of predicted probabilities per ground-truth class.
    sns.distplot(data.query('y_test == 1').y_test_pred, color='#0571b0',
                 label='Price rose in the next 3 minutes', ax=axs[0])
    sns.distplot(data.query('y_test == 0').y_test_pred, color='#e31a1c',
                 label='Price did not increase in the next 3 minutes', ax=axs[0])
    axs[0].axvline(threshold_test, lw=4, color='black', label=test_label)
    axs[0].axvline(threshold_train, lw=4, color='gray', label=train_label)
    axs[0].legend()
    axs[0].set_title(panel_title, fontsize=18)
    axs[0].set_xlabel('Predicted Probability that the Price will rise in the next 3-5 minutes')

    # Lower panel: individual predictions grouped by ground-truth class.
    sns.stripplot(x='y_test', y='y_test_pred', data=data, ax=axs[1])
    axs[1].axhline(threshold_test, lw=4, color='black', label=test_label)
    axs[1].axhline(threshold_train, lw=4, color='gray', label=train_label)
    axs[1].set_title(panel_title, fontsize=18)
    axs[1].set_xlabel('Ground Truth Outcome \n 0 = The Price did not rise, 1 = Price rose')
    axs[1].set_ylabel('Predicted Probability of a Price \n rise in the next 3-5 minutes')
    axs[1].legend()

    # tight_layout recomputes the subplot spacing, so no subplots_adjust beforehand.
    fig.tight_layout()

    contract = y_train.name[1]
    day = y_train.index[0].day
    start_time = performance_metrics_3_plus_day_test.start_time.iloc[0]
    end_time = performance_metrics_3_plus_day_test.end_time.iloc[0]
    duration = performance_metrics_3_plus_day_test.duration.iloc[0]
    fig.suptitle('\n'.join([
        'Distplot and Stripplot of Probabilities by Class --Ground Truth: buy_signal_3',
        f'contract : {contract} -- day : {day}',
        f'forecasting horizon : from: {start_forecasting_horizon} - to : {end_forecasting_horizon} -- '
        f'timespan of forecasting horizon in seconds :{timespan_forecasting_horizon}',
        f'Number of timestamps in Train Set: {n_train}, Time Span of Test Set : {duration}',
        f'Start Time of Test Set: {start_time}, End Time of Test Set: {end_time}',
    ]), fontsize=20, y=1.1)

    # The previous file name contained "Ground Truth: buy_signal_3". On Windows/NTFS a
    # colon addresses an alternate data stream, so the PDF bytes most likely went into a
    # hidden stream of a file named "...--Ground Truth", and the visible file looked
    # empty. Strip every character Windows does not allow in file names.
    file_name = _safe_filename(
        'Distplot and Stripplot of Probabilities by Class --Ground Truth - buy_signal_3'
        f'-contract_{contract}_number_of_timesteps_in_train_set{n_train}.pdf')
    # bbox_inches='tight' keeps the suptitle (drawn above the axes at y=1.1) on the page;
    # without it the saved PDF clips the title. Save via `fig` so the right figure is written.
    fig.savefig(os.path.join(path, file_name), bbox_inches='tight')
    plt.show()
然后在 Jupyter Notebook 中调用该函数来绘制这张图:
# Draw the figure in the notebook and write it to a PDF file.
distplot_stripplot_buy_signal_3(
    y_train=y_train,
    y_dev=y_dev,
    y_test=y_test,
    y=y,
    plot_data_stripplot_distplot=plot_data_stripplot_distplot,
    performance_metrics_3_test=performance_metrics_3_test,
    performance_metrics_3_plus_day_train=performance_metrics_3_plus_day_train,
)
[![enter image description here][1]][1]
但是保存下来的文件是空的。
为什么会发生这种情况,我该如何纠正?
答案 0(得分:0)
可以按下面的方式修改代码。我不确定具体原因,但这样改之后问题就解决了:plt.savefig() 应该在 plt.show() 之前调用(或者先用 plt.gcf() 保留图形对象的引用,再对该对象调用 savefig)。
说明:调用 plt.show() 之后,pyplot 的"当前图形"通常会被关闭,因此之后再调用 plt.savefig() 等函数时,操作的是一个新的空白图形。
# Keep a reference to the figure so it is the one that gets saved, regardless of
# what pyplot considers the "current figure" later on.
fig1 = plt.gcf()
# savefig's first argument is the output path. The previous version passed the data
# (y_train, y_dev, ...) there, which is not a file name. Use a path without ':'
# on Windows. bbox_inches='tight' keeps a suptitle placed above the axes on the page.
fig1.savefig('distplot_stripplot_buy_signal_3.pdf', dpi=100, bbox_inches='tight')
# Show only after saving; plt.draw() after show() served no purpose and was removed.
plt.show()