DMC & Atari

DMC example
DrQV2 is a state-of-the-art, data-efficient reinforcement learning algorithm for continuous control from pixels. Here it is in ~70 lines of code:

(See here for a slightly longer from-scratch implementation that uses only PyTorch primitives, made executable with just a single Antelope import statement and run call.)
import torch
from antelope import Utils
from antelope.Agents.Blocks.Augmentations import RandomShiftsAug
from antelope.Agents.Blocks.Encoders import CNNEncoder
from antelope.Agents.Blocks.Actors import EnsemblePiActor
from antelope.Agents.Blocks.Critics import EnsembleQCritic
from antelope.Agents.Losses import QLearning, PolicyLearning
from antelope import ml
class DrQV2Agent(torch.nn.Module):
    """Data Regularized Q-Learning version 2 (https://arxiv.org/abs/2107.09645)"""
    def __init__(self,
                 obs_spec, action_spec, trunk_dim, hidden_dim,  # Architecture
                 lr, ema_decay,  # Optimization
                 rand_steps, stddev_schedule,  # Exploration
                 log,  # On-boarding
                 ):
        super().__init__()

        self.aug = RandomShiftsAug(pad=4)

        self.encoder = CNNEncoder(obs_spec, norm=0.5, lr=lr)

        self.actor = EnsemblePiActor(self.encoder.repr_shape, trunk_dim, hidden_dim, action_spec,
                                     stddev_schedule=stddev_schedule, rand_steps=rand_steps, lr=lr)

        self.critic = EnsembleQCritic(self.encoder.repr_shape, trunk_dim, hidden_dim, action_spec,
                                      lr=lr, ema_decay=ema_decay)

        self.log = log

    def act(self, obs):
        obs = self.encoder(obs)

        # Note: self.step (the current training step) is assumed to be tracked by the framework's run loop
        Pi = self.actor(obs, self.step)

        # Stochastic sample while training, deterministic best action at evaluation
        action = Pi.sample() if self.training else Pi.best

        return action

    def learn(self, replay, log):
        if not self.log:
            log = None

        batch = next(replay)

        # Augment, encode present
        batch.obs = self.aug(batch.obs)
        batch.obs = self.encoder(batch.obs)

        if replay.nstep:
            with torch.no_grad():
                # Augment, encode future
                batch.next_obs = self.aug(batch.next_obs)
                batch.next_obs = self.encoder(batch.next_obs)

        # Critic loss
        critic_loss = QLearning.ensembleQLearning(self.critic, self.actor, batch.obs, batch.action, batch.reward,
                                                  batch.discount, batch.next_obs, self.step, log=log)

        # Update encoder and critic
        Utils.optimize(critic_loss, self.encoder, self.critic)

        # Actor loss
        actor_loss = PolicyLearning.deepPolicyGradient(self.actor, self.critic, batch.obs.detach(),
                                                       step=self.step, log=log)

        # Update actor
        Utils.optimize(actor_loss, self.actor)
# Train on the DeepMind Control Suite cheetah_run task
ml(task='dmc/cheetah_run', agent=DrQV2Agent)
# For exact reproduction:
# ml(task='dmc/cheetah_run', agent=DrQV2Agent, index='episode', with_replacement=True, partition_workers=True)
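For reference, the two loss terms computed in learn follow (roughly) the DrQ-v2 paper's objectives: clipped double Q-learning over the critic ensemble, and a deterministic policy gradient for the actor, both with exploration noise drawn from a clipped, scheduled Gaussian:

$$
\mathcal{L}_{\text{critic}} = \mathbb{E}\Big[\big(Q_{\theta_k}(s_t, a_t) - y\big)^2\Big], \qquad
y = r + \gamma \min_{k'} Q_{\bar{\theta}_{k'}}\big(s_{t+1}, \pi_\phi(s_{t+1}) + \epsilon\big)
$$

$$
\mathcal{L}_{\text{actor}} = -\,\mathbb{E}\Big[\min_{k} Q_{\theta_k}\big(s_t, \pi_\phi(s_t) + \epsilon\big)\Big], \qquad
\epsilon \sim \operatorname{clip}\big(\mathcal{N}(0, \sigma_t^2), -c, c\big)
$$

where $\bar{\theta}$ denotes the EMA target critics (governed by ema_decay), $\sigma_t$ follows stddev_schedule, and, with n-step replay, $r$ and $\gamma$ stand for the n-step return and discount carried in batch.reward and batch.discount.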
(Training curves for dmc/cheetah_run.)
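Since this section covers Atari as well, the same run call should presumably work by swapping the task. The task string below is a sketch of that usage; the 'atari/<game>' naming is an assumption mirroring the dmc/ prefix above, not taken verbatim from the docs.

# Hypothetical usage sketch, assuming an 'atari/<game>' task-naming convention
ml(task='atari/pong', agent=DrQV2Agent)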