x and y must have same first dimension, but have shapes (30,) and (1,)

Date: 2017-07-20 01:49:00

Tags: python numpy

My Python version: Python 3.6.1.

When I run the following program, I get an error.

I suspect the error is caused by the dimensions of training_cost, since training_cost is 1x3.

"""
multiple_eta
~~~~~~~~~~~~~~~
This program shows how different values for the learning rate affect
training.  In particular, we'll plot out how the cost changes using
three different values for eta.
"""

# Standard library
import json
import random
import sys
# My library
sys.path.append('../')
import mnist_loader
import network2
# Third-party libraries
import matplotlib.pyplot as plt
import numpy as np

# Constants
LEARNING_RATES = [0.025, 0.25, 2.5]
COLORS = ['#2A6EA6', '#FFCD33', '#FF7033']
NUM_EPOCHS = 30

def main():
    run_network()
    make_plot()

def run_network():
    """Train networks using three different values for the learning rate,
    and store the cost curves in the file ``multiple_eta.json``, where
    they can later be used by ``make_plot``.
    """
    # Make results more easily reproducible
    random.seed(12345678)
    np.random.seed(12345678)
    training_data, validation_data, test_data = mnist_loader.load_data_wrapper()
    results = []
    for eta in LEARNING_RATES:
        print ("\nTrain a network using eta = "+str(eta))
        net = network2.Network([784,30,10])
        results.append(net.SGD(training_data, NUM_EPOCHS, 10, eta, lmbda = 5.0,
                      evaluation_data=validation_data, monitor_training_cost=True))
    f = open("multiple_eta.json","w")
    json.dump(results,f)
    f.close()

def make_plot():
    f = open("multiple_eta.json", "r")
    results = json.load(f)
    f.close()
    fig = plt.figure()
    ax = fig.add_subplot(111)
    for eta, result, color in zip(LEARNING_RATES, results, COLORS):
        _,_,training_cost,_ = result
        print(training_cost)
    ax.plot(np.arange(NUM_EPOCHS), training_cost, "o-",label = "$\eta$ = "+str(eta),color = color)
    ax.set_xlim([0,NUM_EPOCHS])
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Cost')
    plt.legend(loc = 'upper right')
    plt.show()

if __name__ == "__main__":
    main()

Error message:

Train a network using eta = 0.025
Epoch 0 training complete
Cost on training data: 731.8195315067348

Train a network using eta = 0.25
Epoch 0 training complete
Cost on training data: 2526.705883226454

Train a network using eta = 2.5
Epoch 0 training complete
Cost on training data: 14014.828642157932
[731.8195315067348]
[2526.705883226454]
[14014.828642157932]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-11-b81c2d8fafc3> in <module>()
     64 
     65 if __name__ == "__main__":
---> 66     main()

<ipython-input-11-b81c2d8fafc3> in main()
     26 def main():
     27     run_network()
---> 28     make_plot()
     29 
     30 def run_network():

<ipython-input-11-b81c2d8fafc3> in make_plot()
     56         _,_,training_cost,_ = result
     57         print(training_cost)
---> 58     ax.plot(np.arange(NUM_EPOCHS), training_cost, "o-",label = "$\eta$ = "+str(eta),color = color)
     59     ax.set_xlim([0,NUM_EPOCHS])
     60     ax.set_xlabel('Epoch')

c:\users\ray\appdata\local\programs\python\python36\lib\site-packages\matplotlib\__init__.py in inner(ax, *args, **kwargs)
   1896                     warnings.warn(msg % (label_namer, func.__name__),
   1897                                   RuntimeWarning, stacklevel=2)
-> 1898             return func(ax, *args, **kwargs)
   1899         pre_doc = inner.__doc__
   1900         if pre_doc is None:

c:\users\ray\appdata\local\programs\python\python36\lib\site-packages\matplotlib\axes\_axes.py in plot(self, *args, **kwargs)
   1404         kwargs = cbook.normalize_kwargs(kwargs, _alias_map)
   1405 
-> 1406         for line in self._get_lines(*args, **kwargs):
   1407             self.add_line(line)
   1408             lines.append(line)

c:\users\ray\appdata\local\programs\python\python36\lib\site-packages\matplotlib\axes\_base.py in _grab_next_args(self, *args, **kwargs)
    405                 return
    406             if len(remaining) <= 3:
--> 407                 for seg in self._plot_args(remaining, kwargs):
    408                     yield seg
    409                 return

c:\users\ray\appdata\local\programs\python\python36\lib\site-packages\matplotlib\axes\_base.py in _plot_args(self, tup, kwargs)
    383             x, y = index_of(tup[-1])
    384 
--> 385         x, y = self._xy_from_xy(x, y)
    386 
    387         if self.command == 'plot':

c:\users\ray\appdata\local\programs\python\python36\lib\site-packages\matplotlib\axes\_base.py in _xy_from_xy(self, x, y)
    242         if x.shape[0] != y.shape[0]:
    243             raise ValueError("x and y must have same first dimension, but "
--> 244                              "have shapes {} and {}".format(x.shape, y.shape))
    245         if x.ndim > 2 or y.ndim > 2:
    246             raise ValueError("x and y can be no greater than 2-D, but have "

ValueError: x and y must have same first dimension, but have shapes (30,) and (1,)

The original code can be found here. How can I fix this?

2 answers:

Answer 0 (score: 0)

Tracing back through the stack, the problem appears to be in

ax.plot(np.arange(NUM_EPOCHS), training_cost,  ...

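NUM_EPOCHS is probably 30, since that is the shape of x in the error message. training_cost must then have shape (1,). For an ordinary x, y plot, both variables need the same number of points.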
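A minimal NumPy-only reproduction of the mismatch, without involving matplotlib. It builds the same x and y that ax.plot receives (using the last printed cost from the question) and compares their first dimensions, which is the check that raises the ValueError in the traceback.

```python
import numpy as np

NUM_EPOCHS = 30
# The value left in training_cost after the loop, copied from the question's output
training_cost = [14014.828642157932]

x = np.arange(NUM_EPOCHS)
y = np.asarray(training_cost)

# matplotlib compares these first dimensions before plotting
print(x.shape, y.shape)          # (30,) (1,)
print(x.shape[0] == y.shape[0])  # False
```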

So why is training_cost only a single item? Your prints show that you assign it inside the loop, and the last value before the loop exits is

[14014.828642157932]

Why are you trying to plot a single value against an axis of 30 x points?
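The point about the loop can be seen with plain Python. This sketch uses a hypothetical stand-in for the loaded JSON: only the third element of each entry (the training cost, values taken from the question's output) is real; the empty lists are placeholders for the other three unpacked values. Because ax.plot sits after the loop in the question's code, it only ever sees the value from the final iteration.

```python
# Hypothetical stand-in for json.load(f); the empty lists are placeholders
results = [
    [[], [], [731.8195315067348], []],
    [[], [], [2526.705883226454], []],
    [[], [], [14014.828642157932], []],
]

for result in results:
    _, _, training_cost, _ = result
    print(training_cost)

# After the loop the name still refers to the last iteration's value,
# which is what an unindented ax.plot call would receive
print(training_cost)  # [14014.828642157932]
```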

What does results = json.load(f) give you?
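Following that diagnosis, here is a minimal sketch of how make_plot could be adjusted. It assumes each entry in the JSON file is the 4-element list implied by the unpacking `_,_,training_cost,_ = result`. Two changes matter: ax.plot moves inside the loop so every eta gets its own curve, and the x values come from `len(training_cost)` instead of NUM_EPOCHS, so the plot still works when the file holds fewer epochs than 30 (the output above shows only epoch 0 was recorded). The `filename` parameter, the returned figure, and the `Agg` backend line are assumptions added for this sketch; the label uses a raw string so `\eta` is not treated as an escape sequence.

```python
import json

import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove it to get a window from plt.show()
import matplotlib.pyplot as plt

LEARNING_RATES = [0.025, 0.25, 2.5]
COLORS = ['#2A6EA6', '#FFCD33', '#FF7033']


def make_plot(filename="multiple_eta.json"):
    """Plot one training-cost curve per learning rate and return the figure."""
    with open(filename, "r") as f:
        results = json.load(f)
    fig, ax = plt.subplots()
    longest = 0
    for eta, result, color in zip(LEARNING_RATES, results, COLORS):
        _, _, training_cost, _ = result
        # One x value per recorded cost, so x and y always have the same length
        epochs = list(range(len(training_cost)))
        # Inside the loop, so each eta is drawn as its own line
        ax.plot(epochs, training_cost, "o-", label=r"$\eta$ = " + str(eta), color=color)
        longest = max(longest, len(training_cost))
    # Avoid identical low/high limits if every curve is empty
    ax.set_xlim([0, max(longest, 1)])
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Cost")
    ax.legend(loc="upper right")
    return fig
```

In the question's script, main() would then call make_plot() followed by plt.show(), after removing the `matplotlib.use("Agg")` line.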

Answer 1 (score: 0)

You need to turn np.arange(NUM_EPOCHS) into a matrix:

ax.plot(np.mat(np.arange(NUM_EPOCHS)), training_cost, "o-",label = "$\eta$ = "+str(eta),color = color)

If that doesn't solve it, you may need to transpose the matrix to match training_cost:

ax.plot(np.mat(np.arange(NUM_EPOCHS)).transpose(), training_cost, "o-",label = "$\eta$ = "+str(eta),color = color)
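A quick shape check of the two variants above, to compare what each one hands to ax.plot against the (1,) training_cost from the question. It is written with np.asmatrix because np.mat is only an alias for it and the alias was removed in NumPy 2.0.

```python
import numpy as np

NUM_EPOCHS = 30
training_cost = np.asarray([14014.828642157932])  # shape (1,), as in the question

# np.mat(...) is equivalent to np.asmatrix(...)
x_row = np.asmatrix(np.arange(NUM_EPOCHS))
x_col = x_row.transpose()

print(x_row.shape)          # (1, 30)
print(x_col.shape)          # (30, 1)
print(training_cost.shape)  # (1,)
# x_col.shape[0] (30) still differs from training_cost.shape[0] (1)
```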