TypeError:无法解包不可迭代的AxesSubplot对象

时间:2019-01-26 02:49:37

标签: mysql python-3.x matplotlib

我正在尝试使用python,mysql和matplotlib绘制股价数据图表。 该代码从mysql数据库获取数据。看起来可以成功完成。

数据将转换为列表,然后转换为numpy数组。

我对matplotlib完全陌生。任何帮助将不胜感激。 我尝试搜索错误消息,但还没有发现任何有用的信息。

错误似乎来自代码中的这一行。特别是在无花果关键字周围。

**fig,** ax1 = plt.subplot()

这是完整的代码:

import numpy as np
import csv
import matplotlib.pyplot as plt
from matplotlib import dates, ticker
import matplotlib
from mpl_finance import candlestick_ohlc
import pymysql
import pandas as pd
import sys
import os
from os import listdir
from os.path import isfile, join

ticker = '14D'
start_date = '20190107'
end_date = '20190111'
host = 'localhost'
user = 'user'
password = 'password'
db = 'trading'

def get_data_from_mysql(host, user, password, db, ticker, start_date, end_date):
    # Create empty lists
    date_data = []
    open_data = []
    high_data = []
    low_data = []
    close_data = []
    volume_data = []

    '''
    This function load a csv file to MySQL table according to
    the load_sql statement.
    SELECT * FROM asx WHERE Symbol = '14D' AND Date >= '2019-01-02' AND Date <='2019-01-11'
    '''
    query_sql = 'SELECT * FROM asx WHERE Symbol = %s AND Date >= %s AND Date <= %s'
    args = [ticker, start_date, end_date]
    print('You are in get_data_from_mysql')
    print(args)
    try:
        con = pymysql.connect(host=host,
                                user=user,
                                password=password,
                                db=db,
                                autocommit=True,
                                local_infile=1)
        print('Connected to DB: {}'.format(host))
        # Create cursor and execute Load SQL
        cursor = con.cursor()
        cursor.execute(query_sql, args)
        results = cursor.fetchall()
        # Print the results
        for row in results:
            date_data.append(row[1])
            open_data.append(row[2])
            high_data.append(row[3])
            low_data.append(row[4])
            close_data.append(row[5])
            volume_data.append(row[6])
        con.close()

    except Exception as e:
        print('Error: {}'.format(str(e)))
        sys.exit(1)

    # Convert the list items to string before printing them out
    print(' '.join(map(str, date_data)))
    print(' '.join(map(str, open_data)))
    print(' '.join(map(str, high_data)))
    print(' '.join(map(str, low_data)))
    print(' '.join(map(str, close_data)))
    print(' '.join(map(str, volume_data)))

    # Convert the list into numpy arrays
    open_val = np.array(open_data[1:], dtype=np.float64)
    high_val = np.array(high_data[1:], dtype=np.float64)
    low_val = np.array(low_data[1:], dtype=np.float64)
    close_val = np.array(close_data[1:], dtype=np.float64)

    # Matplotlib needs dates in floating numbers before they can be plotted
    data_dates = []
    for date in date_data[:1]:
        new_date = dates.date2num(date)
        data_dates.append(new_date)

    # Create a compacted dataset for the ohlc plot
    i = 0
    ohlc_data = []
    # While i is less that the number of items in the data_dates list
    while i < len(data_dates):
        stats_1_day = data_dates[i], open_val[i], high_val[i], low_val[i], close_val[i]
        print(stats_1_day)
        ohlc_data.append(stats_1_day)
        # Increment the counter
        i += 1

    # Day format for the x labels
    dayFormatter = dates.DateFormatter('%d-%b-%Y')

    # Create a figure and an axis
    fig, ax1 = plt.subplot()
    # Colorup is green, colordown is red. Alpha  = opacity of the candlestick (80pct opacity)
    candlestick_ohlc(ax1, ohlc_data, width=0.5, colorup='g', colordown='r', alpha='0.8')
    # Data verification check
    '''
    plt.plot(data_dates, open_val)
    plt.plot(data_dates, high_val)
    plt.plot(data_dates, low_val)
    plt.plot(data_dates, close_val)
    '''

    # Format the dates
    ax1.xaxis.set_major_formatter(dayFormatter)
    ax1.xaxis.set_major_locator(ticker.MaxNLocator(10)) # specify the number of ticks in the x axis
    plt.xticks(rotation=30) # rotates the x axis titles by 30 degrees
    plt.xlabel('Date')
    plt.ylabel('14D')
    plt.title('Historical Data for the period of 07-01-2019 to 11-01-2019')
    plt.tight_layout
    plt.show

get_data_from_mysql(host, user, password, db, ticker, start_date, end_date)

这是错误消息:

Traceback (most recent call last):
  File "SimpleJapaneseCandlestick_v0.4.py", line 128, in <module>
    get_data_from_mysql(host, user, password, db, ticker, start_date, end_date)
  File "SimpleJapaneseCandlestick_v0.4.py", line 107, in get_data_from_mysql
    fig, ax1 = plt.subplot()
TypeError: cannot unpack non-iterable AxesSubplot object

谢谢。

0 个答案:

没有答案