在方法参数不同时共享代码的继承最佳实践?

时间:2016-07-12 16:26:38

标签: python oop inheritance

我有一个AWS Redshift包装器类,它可以为我S3自动执行类似类型的加载,我最近将其调整为适用于Spark个工作需要一个清单,而需要一个稍微不同的COPY语句。除了这一种方法之外,所有其他代码都是可转移和可重用的。由于方法参数不同,PyCharm给了我一个警告,我想知道是否有最佳做法"这样做的方法。

class RedshiftLoader(PrettyStr):
    def __init__(self,
                 s3_credentials=config3.S3_INFO,
                 redshift_db_credentials=config3.REDSHIFT_POSTGRES_INFO_PROD,
                 table_name=None,
                 schema_name=None,
                 dev_db_credentials=config3.REDSHIFT_POSTGRES_INFO,
                 safe_load=False,
                 truncate=False):
...
def copy_to_db(self, database_credentials, copy_from, manifest=False):
        """
        Copies data from a file on S3 to a Redshift table.  Data must be
        properly formatted and in the right order, etc...

        :param database_credentials: A dictionary containing the host, port,
        database name, username, and password.  Keys must match example:

        REDSHIFT_POSTGRES_INFO = {
            'host': REDSHIFT_HOST,
            'port': REDSHIFT_PORT,
            'database': REDSHIFT_DATABASE_DEV,
            'user': REDSHIFT_USER,
            'password': REDSHIFT_PASS
        }
        :param copy_from: The location of the file on the S3 server.
        :param manifest: True if a manifest file is to be used in the copy
        step, False otherwise.

        :return: None
        """
        if not self.table_name:
            raise AttributeError('A table must be specified.')
        s3_access = self.s3_credentials['aws_access_key_id']
        s3_secret = self.s3_credentials['aws_secret_access_key']
        manifest = 'MANIFEST' if manifest else ''
        logger.info('Accessing {table}'.format(table=self.table_name))
        try:
            with ppg2.connect(**database_credentials) as conn:
                cur = conn.cursor()

                if self.truncate:
                    RedshiftLoader.truncate_table(self.table_name, cur)

                load = '''
                COPY {table}
                FROM '{copy_from}'
                CREDENTIALS 'aws_access_key_id={pub};aws_secret_access_key={priv}'
                DELIMITER '|'
                GZIP
                TRIMBLANKS
                TRUNCATECOLUMNS
                ACCEPTINVCHARS
                TIMEFORMAT 'auto'
                DATEFORMAT 'auto'
                {manifest}
                '''.format(
                    table=self.table_name,
                    copy_from=copy_from,
                    pub=s3_access,
                    priv=s3_secret,
                    manifest=manifest
                )
                logger.info('Copying to {table}'.format(
                    table=self.table_name))
                cur.execute(load)
                conn.commit()
                logger.info('Copy complete.')
        except ppg2.Error as e:
            logger.critical('Error occurred during load: {error}'.format(
                error=e
            ))
            raise

然后是子类:

class SparkRedshiftLoader(RedshiftLoader):
    def copy_to_db(self, database_credentials, copy_from):
        """
        Copies data from a file on S3 to a Redshift table.  Data must be
        properly formatted and in the right order, etc...

        :param database_credentials: A dictionary containing the host, port,
        database name, username, and password.  Keys must match example:

        REDSHIFT_POSTGRES_INFO = {
            'host': REDSHIFT_HOST,
            'port': REDSHIFT_PORT,
            'database': REDSHIFT_DATABASE_DEV,
            'user': REDSHIFT_USER,
            'password': REDSHIFT_PASS
        }
        :param copy_from: The location of the file on the S3 server.  Assumes
        that it is being passed an 's3n' version of the path (common in Spark
        and Hadoop) and will automatically convert to the proper format.

        :return: None
        """
        if not self.table_name:
            raise AttributeError('A table must be specified.')
        s3_access = self.s3_credentials['aws_access_key_id']
        s3_secret = self.s3_credentials['aws_secret_access_key']
        copy_from = copy_from.replace('s3n', 's3')
        logging.info('Accessing {table}'.format(table=self.table_name))
        try:
            with ppg2.connect(**database_credentials) as conn:
                cur = conn.cursor()

                if self.truncate:
                    SparkRedshiftLoader.truncate_table(self.table_name, cur)

                load = '''
                COPY {table}
                FROM '{copy_from}'
                CREDENTIALS 'aws_access_key_id={pub};aws_secret_access_key={priv}'
                DELIMITER '|'
                GZIP
                TRIMBLANKS
                TRUNCATECOLUMNS
                ACCEPTINVCHARS
                TIMEFORMAT 'auto'
                DATEFORMAT 'auto'
                CSV
                NULL 'null'
                '''.format(
                    table=self.table_name,
                    copy_from=copy_from,
                    pub=s3_access,
                    priv=s3_secret,
                )
                logging.info('Copying to {table}'.format(
                    table=self.table_name))
                cur.execute(load)
                conn.commit()
                logging.info('Copy complete.')
        except ppg2.Error as e:
            logging.info('Error occurred during load: {error}'.format(
                error=e
            ))
            raise

如您所见,子类将manifest作为参数删除,在第一个中找不到replace语句,并且COPY命令略有不同。

1 个答案:

答案 0 :(得分:1)

RedshiftLoader._copy_to_db定义为:

def _copy_to_db(self, database_credentials, copy_from, manifest):
        """
        Copies data from a file on S3 to a Redshift table.  Data must be
        properly formatted and in the right order, etc...

        :param database_credentials: A dictionary containing the host, port,
        database name, username, and password.  Keys must match example:

        REDSHIFT_POSTGRES_INFO = {
            'host': REDSHIFT_HOST,
            'port': REDSHIFT_PORT,
            'database': REDSHIFT_DATABASE_DEV,
            'user': REDSHIFT_USER,
            'password': REDSHIFT_PASS
        }
        :param copy_from: The location of the file on the S3 server.
        :param manifest: True if a manifest file is to be used in the copy
        step, False otherwise.

        :return: None
        """
        if not self.table_name:
            raise AttributeError('A table must be specified.')
        s3_access = self.s3_credentials['aws_access_key_id']
        s3_secret = self.s3_credentials['aws_secret_access_key']
        logger.info('Accessing {table}'.format(table=self.table_name))
        try:
            with ppg2.connect(**database_credentials) as conn:
                cur = conn.cursor()

                if self.truncate:
                    RedshiftLoader.truncate_table(self.table_name, cur)

                load = '''
                COPY {table}
                FROM '{copy_from}'
                CREDENTIALS 'aws_access_key_id={pub};aws_secret_access_key={priv}'
                DELIMITER '|'
                GZIP
                TRIMBLANKS
                TRUNCATECOLUMNS
                ACCEPTINVCHARS
                TIMEFORMAT 'auto'
                DATEFORMAT 'auto'
                {manifest}
                '''.format(
                    table=self.table_name,
                    copy_from=copy_from,
                    pub=s3_access,
                    priv=s3_secret,
                    manifest=manifest
                )
                logger.info('Copying to {table}'.format(
                    table=self.table_name))
                cur.execute(load)
                conn.commit()
                logger.info('Copy complete.')
        except ppg2.Error as e:
            logger.critical('Error occurred during load: {error}'.format(
                error=e
            ))
            raise

这与RedshiftLoader.copy_to_db之间的唯一区别是manifest没有默认值,并且在使用它之前不会修改其值。现在,在每个类中定义copy_to_db如下:

class RedshiftLoader(PrettyStr):

    def copy_to_db(self, database_credentials, copy_from, manifest=False):
        manifest = 'MANIFEST' if manifest else ''
        self._copy_to_db(database_credentials, copy_from, manifest)

class SparkRedshiftLoader(RedshiftLoader):

    def copy_to_db(self, database_credentials, copy_from):
        copy_from = copy_from.replace('s3n', 's3')
        self._copy_to_db(database_credentials, copy_from, "CSV NULL 'null'")

私有方法抽象出所有公共代码(几乎全部代码); public方法提供了一个以类适当的方式修改copy_frommanifest值的地方。

请注意,manifest可能不是最佳参数名称,因为它以相当不同的方式使用。但是请注意,在这两种情况下,它只是一些特定于类的SQL附加到共享查询的末尾。

使用您传递给spark的{​​{1}}布尔值的想法,可以在一个类中完成相同的重构。