I'm working with an LSTM-based RecurrentPPO agent that needs a behavioural cloning implementation.
The imitation library built around Stable Baselines3 (see https://imitation.readthedocs.io/en/latest/) does not seem to be made for SB3-contrib's RecurrentPPO.
I found this approach that could be adapted for RecurrentPPO: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pretraining.ipynb
I guess this part of the code has to be modified to take `lstm_states` and `episode_starts` into account, but I don't know how to implement it.
```python
import gym
import torch as th
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from stable_baselines3 import A2C, PPO

# `env`, `train_expert_dataset` and `test_expert_dataset` are defined earlier in the notebook.


def pretrain_agent(
    student,
    batch_size=64,
    epochs=1000,
    scheduler_gamma=0.7,
    learning_rate=1.0,
    log_interval=100,
    no_cuda=True,
    seed=1,
    test_batch_size=64,
):
    use_cuda = not no_cuda and th.cuda.is_available()
    th.manual_seed(seed)
    device = th.device("cuda" if use_cuda else "cpu")
    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    if isinstance(env.action_space, gym.spaces.Box):
        criterion = nn.MSELoss()
    else:
        criterion = nn.CrossEntropyLoss()

    # Extract initial policy
    model = student.policy.to(device)

    def train(model, device, train_loader, optimizer):
        model.train()

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            if isinstance(env.action_space, gym.spaces.Box):
                # A2C/PPO policy outputs actions, values, log_prob
                # SAC/TD3 policy outputs actions only
                if isinstance(student, (A2C, PPO)):
                    action, _, _ = model(data)
                else:
                    # SAC/TD3:
                    action = model(data)
                action_prediction = action.double()
            else:
                # Retrieve the logits for A2C/PPO when using discrete actions
                dist = model.get_distribution(data)
                action_prediction = dist.distribution.logits
                target = target.long()

            loss = criterion(action_prediction, target)
            loss.backward()
            optimizer.step()
            if batch_idx % log_interval == 0:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(data),
                        len(train_loader.dataset),
                        100.0 * batch_idx / len(train_loader),
                        loss.item(),
                    )
                )

    def test(model, device, test_loader):
        model.eval()
        test_loss = 0
        with th.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)

                if isinstance(env.action_space, gym.spaces.Box):
                    # A2C/PPO policy outputs actions, values, log_prob
                    # SAC/TD3 policy outputs actions only
                    if isinstance(student, (A2C, PPO)):
                        action, _, _ = model(data)
                    else:
                        # SAC/TD3:
                        action = model(data)
                    action_prediction = action.double()
                else:
                    # Retrieve the logits for A2C/PPO when using discrete actions
                    dist = model.get_distribution(data)
                    action_prediction = dist.distribution.logits
                    target = target.long()

                # Accumulate the per-batch loss
                test_loss += criterion(action_prediction, target).item()
        # Average over the number of batches
        test_loss /= len(test_loader)
        print(f"Test set: Average loss: {test_loss:.4f}")

    # Here, we use PyTorch's `DataLoader` to load the previously created `ExpertDataset`
    # for training and testing
    train_loader = th.utils.data.DataLoader(
        dataset=train_expert_dataset, batch_size=batch_size, shuffle=True, **kwargs
    )
    test_loader = th.utils.data.DataLoader(
        dataset=test_expert_dataset,
        batch_size=test_batch_size,
        shuffle=True,
        **kwargs,
    )

    # Define an Optimizer and a learning rate schedule.
    optimizer = optim.Adadelta(model.parameters(), lr=learning_rate)
    scheduler = StepLR(optimizer, step_size=1, gamma=scheduler_gamma)

    # Now we are finally ready to train the policy model.
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader)
        scheduler.step()

    # Implant the trained policy network back into the RL student agent
    student.policy = model
```
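For context, this is roughly how a trained RecurrentPPO model is queried at inference time (a minimal example following the usage snippet in the SB3-contrib docs); the supervised loop above never deals with `lstm_states` or `episode_starts` like this:

```python
import numpy as np
from sb3_contrib import RecurrentPPO

model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=0)
vec_env = model.get_env()

obs = vec_env.reset()
# Cell and hidden state of the LSTM; None means "start from zeros"
lstm_states = None
# Episode start signals are used to reset the LSTM state between episodes
episode_starts = np.ones((vec_env.num_envs,), dtype=bool)
for _ in range(100):
    action, lstm_states = model.predict(
        obs, state=lstm_states, episode_start=episode_starts, deterministic=True
    )
    obs, rewards, dones, infos = vec_env.step(action)
    episode_starts = dones
```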
Does anyone have a solution?
Just stumbled upon this problem as well.
The problem is obviously that `evaluate_actions` in `RecurrentActorCriticPolicy` has a different signature than the non-recurrent `evaluate_actions`: it needs `lstm_states` and `episode_starts` as well (see the quick check below).

My first thought was that this means this information also needs to be stored during rollout collection (which I assumed it would be, but it is not), and that the solution would be to store the missing info during rollout collection and handle it during BC if it is there and compatible with the policy at hand.
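For reference, the signature difference is easy to check (the exact parameter names are whatever your installed versions define):

```python
import inspect

from stable_baselines3.common.policies import ActorCriticPolicy
from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy

# The plain policy only takes observations and actions ...
print(inspect.signature(ActorCriticPolicy.evaluate_actions))
# ... while the recurrent one additionally expects lstm_states and episode_starts
print(inspect.signature(RecurrentActorCriticPolicy.evaluate_actions))
```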
But actually it is unclear what the expert `state` is when the expert policy used for rollout collection is not recurrent itself (but e.g. a near-optimal search algorithm). Thus, for recurrent policies the BC algorithm should train on whole trajectories from beginning to end, passing the `lstm_states` between timesteps.
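Something along these lines could serve as a starting point. It is only an untested sketch: `expert_trajectories` is a hypothetical list of per-episode `(observations, actions)` sequences, it assumes a discrete action space and a single environment, and it relies on `RecurrentActorCriticPolicy.get_distribution(obs, lstm_states, episode_starts)` returning the distribution together with the updated LSTM state (which is what the sb3_contrib source appears to do, but double-check against your installed version):

```python
import torch as th
import torch.nn as nn
from sb3_contrib import RecurrentPPO


def pretrain_recurrent(student: RecurrentPPO, expert_trajectories, epochs=10, lr=1e-3):
    """Naive BC for RecurrentPPO: replay whole trajectories and carry the LSTM state."""
    policy = student.policy
    device = policy.device
    criterion = nn.CrossEntropyLoss()  # discrete actions; use MSE on the mean for Box spaces
    optimizer = th.optim.Adam(policy.parameters(), lr=lr)
    lstm = policy.lstm_actor

    for _ in range(epochs):
        for obs_seq, act_seq in expert_trajectories:
            # Fresh (zeroed) hidden/cell state at the start of every expert trajectory
            h = th.zeros(lstm.num_layers, 1, lstm.hidden_size, device=device)
            lstm_states = (h, h.clone())
            episode_start = th.ones(1, device=device)  # 1.0 only at the first timestep
            losses = []
            for obs, expert_action in zip(obs_seq, act_seq):
                obs_tensor, _ = policy.obs_to_tensor(obs)
                # Pass the carried-over state and the episode start flag at every step
                dist, lstm_states = policy.get_distribution(obs_tensor, lstm_states, episode_start)
                logits = dist.distribution.logits
                target = th.as_tensor([expert_action], device=device).long()
                losses.append(criterion(logits, target))
                episode_start = th.zeros(1, device=device)
            # One optimizer step per trajectory, backpropagating through the whole sequence
            loss = th.stack(losses).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    student.policy = policy
```

Backpropagating through the full trajectory keeps the gradient flowing through the recurrent state; truncating the BPTT window or batching padded sequences (the way RecurrentPPO itself does during its updates) would be the obvious next refinements.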