diff --git a/cherry/td.py b/cherry/td.py
index 03338d9..c62a64c 100644
--- a/cherry/td.py
+++ b/cherry/td.py
@@ -52,7 +52,7 @@ def discount(gamma, rewards, dones, bootstrap=0.0):
     msg = 'dones and rewards must have equal length.'
     assert rewards.size(0) == dones.size(0), msg
 
-    R = th.zeros_like(rewards[0]) + bootstrap
+    R = th.zeros_like(rewards) + bootstrap
     discounted = th.zeros_like(rewards)
     length = discounted.size(0)
     for t in reversed(range(length)):
diff --git a/examples/actor_critic_cartpole.py b/examples/actor_critic_cartpole.py
index 61c217e..4696ffb 100644
--- a/examples/actor_critic_cartpole.py
+++ b/examples/actor_critic_cartpole.py
@@ -1,112 +1,131 @@
 #!/usr/bin/env python3
-"""
-Simple example of using cherry to solve cartpole with an actor-critic.
-
-The code is an adaptation of the PyTorch reinforcement learning example.
-"""
-
-import random
+import torch
+import cherry
 import gym
 import numpy as np
-
 from itertools import count
+import statistics
+
+NUM_ENVS = 6
+STEPS = 5
+TRAIN_STEPS = int(1e4)
+
+
+class A2C(torch.nn.Module):
+    def __init__(self, num_envs):
+        super(A2C, self).__init__()
+
+        self.num_envs = num_envs
+        self.gamma = 0.99
+        self.vf_coef = 0.25
+        self.ent_coef = 0.01
+        self.max_clip_norm = 0.5
+
+    def select_action(self, state):
+        probs, value = self(state)
+        mass = torch.distributions.Categorical(probs)
+        action = mass.sample()
+        # Return the selected action, its log-probability, the value estimate and the categorical entropy
+        return action, {"log_prob": mass.log_prob(action), "value": value, "entropy": mass.entropy()}
+
+    def learn_step(self, replay, optimizer):
+        policy_loss = []
+        value_loss = []
+        entropy_loss = []
+
+        # Discount rewards, bootstrapping with the value estimate of the last next-state
+        last_action, last_value = self(replay.next_state()[-1, :, :])
+        # Bootstrap from zero if the last state is terminal
+        last_value = last_value[:, 0] * (1 - replay.done()[-1])
+
+        rewards = cherry.td.discount(self.gamma, replay.reward(), replay.done(), last_value)
+        for sars, reward in zip(replay, rewards):
+            log_prob = sars.log_prob.view(self.num_envs, -1)
+            value = sars.value.view(self.num_envs, -1)
+            entropy = sars.entropy.view(self.num_envs, -1)
+            reward = reward.view(self.num_envs, -1)
+
+            # Compute the advantage
+            advantage = reward - value
+
+            # Compute the policy gradient loss
+            # (advantage.detach() because we do not backpropagate through the advantage)
+            policy_loss.append(-log_prob * advantage.detach())
+            # Compute the value estimation loss
+            value_loss.append((reward - value)**2)
+            # Compute the entropy loss
+            entropy_loss.append(entropy)
+
+        # Compute means over the accumulated losses
+        value_loss = torch.stack(value_loss).mean()
+        policy_loss = torch.stack(policy_loss).mean()
+        entropy_loss = torch.stack(entropy_loss).mean()
+
+        # Take an optimization step
+        optimizer.zero_grad()
+        loss = policy_loss + self.vf_coef * value_loss - self.ent_coef * entropy_loss
+        loss.backward()
+        # Clip gradients
+        torch.nn.utils.clip_grad_norm_(self.parameters(), self.max_clip_norm)
+        optimizer.step()
+
+
+class A2CPolicy(A2C):
+    def __init__(self, state_size, action_size, num_envs):
+        super(A2CPolicy, self).__init__(num_envs)
+        self.state_size = state_size
+        self.action_size = action_size
+        self.n_hidden = 128
+
+        # Backbone net
+        self.net = torch.nn.Sequential(
+            torch.nn.Linear(self.state_size, self.n_hidden),
+            torch.nn.LeakyReLU(),
+            torch.nn.Linear(self.n_hidden, self.n_hidden),
+            torch.nn.LeakyReLU(),
+        )
+
+        # Action head (policy gradient)
+        self.action_head = torch.nn.Sequential(
+            torch.nn.Linear(self.n_hidden, self.action_size),
+            torch.nn.Softmax(dim=1),
+        )
+
+        # Value estimation head (A2C)
+        self.value_head = torch.nn.Sequential(
+            torch.nn.Linear(self.n_hidden, 1),
+        )
-import torch as th
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-
-import cherry.envs as envs
-from cherry.td import discount
-from cherry import normalize
-import cherry.distributions as distributions
-
-SEED = 567
-GAMMA = 0.99
-RENDER = False
-V_WEIGHT = 0.5
-
-random.seed(SEED)
-np.random.seed(SEED)
-th.manual_seed(SEED)
-
-
-class ActorCriticNet(nn.Module):
-    def __init__(self, env):
-        super(ActorCriticNet, self).__init__()
-        self.affine1 = nn.Linear(env.state_size, 128)
-        self.action_head = nn.Linear(128, env.action_size)
-        self.value_head = nn.Linear(128, 1)
-        self.distribution = distributions.ActionDistribution(env,
-                                                             use_probs=True)
 
     def forward(self, x):
-        x = F.relu(self.affine1(x))
-        action_scores = self.action_head(x)
-        action_mass = self.distribution(F.softmax(action_scores, dim=1))
-        value = self.value_head(x)
-        return action_mass, value
-
-
-def update(replay, optimizer):
-    policy_loss = []
-    value_loss = []
-
-    # Discount and normalize rewards
-    rewards = discount(GAMMA, replay.reward(), replay.done())
-    rewards = normalize(rewards)
-
-    # Compute losses
-    for sars, reward in zip(replay, rewards):
-        log_prob = sars.log_prob
-        value = sars.value
-        policy_loss.append(-log_prob * (reward - value.item()))
-        value_loss.append(F.mse_loss(value, reward.detach()))
-
-    # Take optimization step
-    optimizer.zero_grad()
-    loss = th.stack(policy_loss).sum() + V_WEIGHT * th.stack(value_loss).sum()
-    loss.backward()
-    optimizer.step()
-
-
-def get_action_value(state, policy):
-    mass, value = policy(state)
-    action = mass.sample()
-    info = {
-        'log_prob': mass.log_prob(action),  # Cache log_prob for later
-        'value': value
-    }
-    return action, info
-
+        # Return both the action probabilities and the value estimates
+        return self.action_head(self.net(x)), self.value_head(self.net(x))
 
 if __name__ == '__main__':
-    env = gym.vector.make('CartPole-v0', num_envs=1)
-    env = envs.Logger(env, interval=1000)
-    env = envs.Torch(env)
-    env = envs.Runner(env)
-    env.seed(SEED)
-
-    policy = ActorCriticNet(env)
-    optimizer = optim.Adam(policy.parameters(), lr=1e-2)
-    running_reward = 10.0
-    get_action = lambda state: get_action_value(state, policy)
-
-    for episode in count(1):
-        # We use the Runner collector, but could've written our own
-        replay = env.run(get_action, episodes=1)
-
-        # Update policy
-        update(replay, optimizer)
-
-        # Compute termination criterion
-        running_reward = running_reward * 0.99 + len(replay) * 0.01
-        if episode % 10 == 0:
-            # Should start with 10.41, 12.21, 14.60, then 100:71.30, 200:135.74
-            print(episode, running_reward)
-        if running_reward > 190.0:
-            print('Solved! Running reward now {} and '
-                  'the last episode runs to {} time steps!'.format(running_reward,
-                                                                   len(replay)))
-            break
+    env = gym.vector.make('CartPole-v0', num_envs=NUM_ENVS)
+    env = cherry.envs.Logger(env, interval=1000)
+    env = cherry.envs.Torch(env)
+
+    policy = A2CPolicy(env.state_size, env.action_size, NUM_ENVS)
+    optimizer = torch.optim.RMSprop(policy.parameters(), lr=7e-4, eps=1e-5, alpha=0.99)
+
+    state = env.reset()
+    for train_step in range(0, TRAIN_STEPS):
+        replay = cherry.ExperienceReplay()
+        for step in range(0, STEPS):
+            action, info = policy.select_action(state)
+            new_state, reward, done, _ = env.step(action)
+            replay.append(state, action, reward, new_state, done, **info)
+            state = new_state
+
+        policy.learn_step(replay, optimizer)
+
+    env = gym.make('CartPole-v0')
+    env = cherry.envs.Torch(env)
+    env = cherry.envs.Runner(env)
+    while True:
+        env.run(lambda state: policy.select_action(state), episodes=1, render=True)
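
Note: the td.py change above lets the example pass a per-environment bootstrap tensor (the `last_value` computed in `learn_step`). For reference only, here is a minimal standalone sketch of the bootstrapped discounted return that call is meant to produce, following the standard recursion R_t = r_t + gamma * R_{t+1} * (1 - done_t) with R initialized from the bootstrap after the last step. The helper name `bootstrapped_returns` and the (steps, num_envs) layout are illustrative assumptions, not cherry's implementation or API.

    import torch

    def bootstrapped_returns(gamma, rewards, dones, bootstrap):
        # Illustrative sketch, not cherry's implementation.
        # rewards, dones: (steps, num_envs); bootstrap: (num_envs,)
        R = bootstrap.clone()
        returns = torch.zeros_like(rewards)
        for t in reversed(range(rewards.size(0))):
            # Zero the running return for environments whose episode ended at step t
            R = rewards[t] + gamma * R * (1.0 - dones[t])
            returns[t] = R
        return returns

    # Tiny example: 3 steps, 2 envs, second env terminates at step 1
    rewards = torch.ones(3, 2)
    dones = torch.tensor([[0., 0.], [0., 1.], [0., 0.]])
    bootstrap = torch.tensor([0.5, 0.5])
    print(bootstrapped_returns(0.99, rewards, dones, bootstrap))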