tf_agents reset environment using actor

I'm trying to understand how to use the Actor class in tf_agents. I am using DDPG (actor-critic, although this doesn't really matter per se), and I am working with the gym package for the environment, although again this isn't central to the question.
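
To make the setup concrete, here is a minimal, stripped-down sketch of what I mean. This is not my actual training code: a random policy stands in for the DDPG collect policy, Pendulum-v1 is just an example gym environment, and the step counts are arbitrary.

from tf_agents.environments import suite_gym
from tf_agents.policies import random_py_policy
from tf_agents.train import actor
from tf_agents.train.utils import train_utils

# Plain (non-batched) gym environment wrapped as a PyEnvironment.
env = suite_gym.load('Pendulum-v1')

# Stand-in for the DDPG collect policy.
policy = random_py_policy.RandomPyPolicy(
    time_step_spec=env.time_step_spec(), action_spec=env.action_spec())

train_step = train_utils.create_train_step()

# The Actor internally builds a PyDriver and calls its run() method.
collect_actor = actor.Actor(
    env,
    policy,
    train_step,
    steps_per_run=200)  # long enough to cross several episode boundaries

for _ in range(10):
    collect_actor.run()  # where does the env get reset between episodes?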

I went into the class definition of train.Actor and, under the hood, its run method calls py_driver.PyDriver. My understanding is that once the gym environment reaches a terminal state, it needs to be reset. However, following the Actor and PyDriver code, I don't see env.reset() being called anywhere outside of the __init__ method. Looking at the tutorial for sac_agent.SacAgent, I don't see env.reset() being called there either.

Can someone help me understand what I am missing? Do I not need to call env.reset()? Or is there some code being called that I have overlooked?
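
For comparison, this is the kind of hand-rolled collection loop I expected Actor/PyDriver to be doing internally, with an explicit reset at episode boundaries (env and policy as in the sketch above):

time_step = env.reset()
policy_state = policy.get_initial_state(env.batch_size or 1)

for _ in range(200):
    action_step = policy.action(time_step, policy_state)
    time_step = env.step(action_step.action)
    policy_state = action_step.state
    if time_step.is_last():
        # This explicit reset is the call I cannot find in Actor or PyDriver.
        time_step = env.reset()
        policy_state = policy.get_initial_state(env.batch_size or 1)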

Here is the PyDriver.run() method:

  def run(
      self,
      time_step: ts.TimeStep,
      policy_state: types.NestedArray = ()
      ) -> Tuple[ts.TimeStep, types.NestedArray]:
      
      num_steps = 0
      num_episodes = 0
      while num_steps < self._max_steps and num_episodes < self._max_episodes:
          # For now we reset the policy_state for non batched envs.
          if not self.env.batched and time_step.is_first() and num_episodes > 0:
              policy_state = self._policy.get_initial_state(self.env.batch_size or 1)

          action_step = self.policy.action(time_step, policy_state)
          next_time_step = self.env.step(action_step.action)

          # When using observer (for the purpose of training), only the previous
          # policy_state is useful. Therefore substitute it in the PolicyStep and
          # consume it w/ the observer.
          action_step_with_previous_state = action_step._replace(state=policy_state)
          traj = trajectory.from_transition(time_step, action_step_with_previous_state, next_time_step)
          for observer in self._transition_observers:
              observer((time_step, action_step_with_previous_state, next_time_step))
          for observer in self.observers:
              observer(traj)
          for observer in self.info_observers:
              observer(self.env.get_info())

          if self._end_episode_on_boundary:
              num_episodes += np.sum(traj.is_boundary())
          else:
              num_episodes += np.sum(traj.is_last())

          num_steps += np.sum(~traj.is_boundary())

          time_step = next_time_step
          policy_state = action_step.state

      return time_step, policy_state

As you can see, it counts a step for every non-boundary transition and increments the episode count when it reaches a terminal state (or an episode boundary), but there is no call to env.reset() anywhere in the loop.
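
For completeness, here is roughly how I would invoke the driver directly. The only env.reset() I can account for is the initial one that produces the first time_step (again using the env and policy from the sketch above, with an arbitrary max_steps):

from tf_agents.drivers import py_driver

driver = py_driver.PyDriver(env, policy, observers=[], max_steps=200)

time_step = env.reset()  # the only reset I can find
policy_state = policy.get_initial_state(env.batch_size or 1)

# run() steps straight through episode boundaries without ever calling
# env.reset() itself, as shown above.
time_step, policy_state = driver.run(time_step, policy_state)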
