
时间:2019-05-02 17:15:24

标签: theano pymc3 pymc



    def pol2PolsArray(self, pols):
        # Takes A (called 'pols') for joint policies and maps to joint pols array ('polsArray')
        # Input: (num_obs, 2, 2*dim, dim)-shaped theano variable
        # Output: (num_obs, 2, length, ..., length, 2*dim)-shaped numpy array

        theano.config.compute_test_value = 'warn'
        pols.tag.test_value = np.zeros((self.obs.num_obs, 2, 2*self.obs.environment.dim, self.obs.environment.dim))
        # Initialise theano variables
        state = tt.dvector('state')
        # state.tag.test_value = self.obs.environment.dim * (0,)
        ind0 = tt.bscalar('ind0')
        # state.tag.test_value = 0
        ind1 = tt.bscalar('ind1')
        # state.tag.test_value = 0

        # Create graph for computing softmax
        state_act_distn = tt.exp(tt.dot(pols[ind0, ind1], state)) / tt.sum(tt.exp(tt.dot(pols[ind0, ind1], state)))
        pol_softmax = theano.function(inputs=[state, ind0, ind1], outputs=state_act_distn)

        # Initialise polArray
        # Possibly boilerplate as shape is expressed elsewhere
        polsArray = np.zeros((self.obs.num_obs, 2) + self.obs.environment.dim*(self.obs.environment.length,) + (2*self.obs.environment.dim,))

        # Compute softmax for each state value, input into polsArray
        for np_ind0 in range(self.obs.num_obs):
            for np_ind1 in range(2):
                for np_state in product(range(self.obs.environment.length), repeat=self.obs.environment.dim):
                    polsArray[(np_ind0, np_ind1) + np_state] = pol_softmax(np_state, np_ind0, np_ind1)

        return polsArray

    def logp_traj(self, trajectories, pols):
        polsArray = self.pol2PolsArray(pols)

        return (trajectories * np.log(polsArray)).sum()


theano.gof.fg.MissingInputError: Input 0 of the graph (indices start from 0), used to compute Subtensor{int8, int8}(pols, ScalarFromTensor.0, ScalarFromTensor.0), was not provided and not given a value. Use the Theano flag exception_verbosity='high', for more information on this error.


0 个答案:
