在组件之间传递渐变; pass_by_obj输出

时间:2015-12-29 21:40:01

标签: openmdao

我的情况是,必须在另一个组件中计算一个组件的梯度。我试图做的只是让渐变是第一个组件的输出和第二个组件的输入。我已将其设置为pass_by_obj,因此它不会影响其他计算。关于这是否是最佳方式的任何建议将不胜感激。不过,使用check_partial_derivatives()时出错。对于任何指定为pass_by_obj的输出,似乎都是错误的。这是一个简单的案例:

import numpy as np
from openmdao.api import Group, Problem, Component, ScipyGMRES, ExecComp, IndepVarComp

class Comp1(Component):
    def __init__(self):
        super(Comp1, self).__init__()
        self.add_param('x', shape=1)

        self.add_output('y', shape=1)
        self.add_output('dz_dy', shape=1, pass_by_obj=True)

    def solve_nonlinear(self, params, unknowns, resids):

        x = params['x']

        unknowns['y'] = 4.0*x + 1.0
        unknowns['dz_dy'] = 2.0*x

    def linearize(self, params, unknowns, resids):

        J = {}
        J['y', 'x'] = 4.0
        return J

class Comp2(Component):
    def __init__(self):
        super(Comp2, self).__init__()
        self.add_param('y', shape=1)
        self.add_param('dz_dy', shape=1, pass_by_obj=True)

        self.add_output('z', shape=1)

    def solve_nonlinear(self, params, unknowns, resids):
        y = params['y']
        unknowns['z'] = y*2.0

    def linearize(self, params, unknowns, resids):
        J = {}
        J['z', 'y'] = params['dz_dy']
        return J

class TestGroup(Group):
    def __init__(self):
        super(TestGroup, self).__init__()
        self.add('x', IndepVarComp('x', 0.0), promotes=['*'])
        self.add('c1', Comp1(), promotes=['*'])
        self.add('c2', Comp2(), promotes=['*'])

p = Problem()
p.root = TestGroup()
p.setup(check=False)

p['x'] = 2.0

p.run()

print p['z']
print 'gradients'
test_grad = open('partial_gradients_test.txt', 'w')
partial = p.check_partial_derivatives(out_stream=test_grad)

我收到以下错误消息:

partial = p.check_partial_derivatives(out_stream=test_grad)
  File "/usr/local/lib/python2.7/site-packages/openmdao/core/problem.py", line 1699, in check_partial_derivatives
    dresids._dat[u_name].val[idx] = 1.0
TypeError: '_ByObjWrapper' object does not support item assignment

我之前询问过在check_partial_derivatives()中检查pass_by_obj的params,它可能只是检查pass_by_obj的未知数。

1 个答案:

答案 0 :(得分:0)

您获得的错误是另一个与check_partial_derivatives函数相关的错误。它应该很容易修复,但在此期间你可以删除pass_by_obj设置。由于您在一个组件中计算一个值并将其传递给另一个组件,因此根本不需要执行pass_by_obj(如果您不这样做,则效率会更高)。

你说你这么做是因为它并没有影响其他计算,但我不太清楚你的意思。除非您在solve_nonlinear方法中使用它,否则它不会影响任何事情。