* Pong (v4) environment: https://gymnasium.farama.org/environments/atari/pong/
* JAX installation: https://github.com/google/jax#installation
* Proximal Policy Optimization (coax stub): https://coax.readthedocs.io/en/latest/examples/stubs/ppo.html
* Original coax example this notebook is based on (PPO on Atari): https://coax.readthedocs.io/en/latest/examples/atari/ppo.html

# Libraries

In [1]:
import gymnasium
import jax
import coax
import haiku

from jax import numpy
from matplotlib import pyplot
from optax import adam
from os import environ
from IPython.display import clear_output

 return jax_config.define_bool_state('flax_' + name, default, help)
 return jax_config.define_bool_state('flax_' + name, default, help)
 return jax_config.define_bool_state('flax_' + name, default, help)
 return jax_config.define_bool_state('flax_' + name, default, help)
 return jax_config.define_bool_state('flax_' + name, default, help)
 return jax_config.define_bool_state('flax_' + name, default, help)


## Environment Variables

In [2]:
environ["JAX_PLATFORM_NAME"] = "gpu" # tell JAX to use GPU
environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.1" # don't use all gpu mem
environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # tell XLA to be quiet

# Environment

In [3]:
name = "pong"

# Wrapper chain, innermost first: raw ALE Pong -> AtariPreprocessing ->
# stack of the last 3 frames -> episode step limit -> coax TrainMonitor
# (logs episode statistics and writes TensorBoard data under ./data/tensorboard/).
environment = coax.wrappers.TrainMonitor(
    gymnasium.wrappers.TimeLimit(
        coax.wrappers.FrameStacking(
            gymnasium.wrappers.AtariPreprocessing(
                gymnasium.make('PongNoFrameskip-v4', render_mode='rgb_array'),
            ),
            num_frames=3,
        ),
        max_episode_steps=108000 // 3,
    ),
    name=name,
    tensorboard_dir=f"./data/tensorboard/{name}",
)

# Left as the last expression so the notebook displays the initial (observation, info).
environment.reset()

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


((array([[ 52, 52, 52, ..., 87, 87, 87],
 [ 87, 87, 87, ..., 87, 87, 87],
 [ 87, 87, 87, ..., 87, 87, 87],
 ...,
 [236, 236, 236, ..., 236, 236, 236],
 [236, 236, 236, ..., 236, 236, 236],
 [236, 236, 236, ..., 236, 236, 236]], dtype=uint8),
 array([[ 52, 52, 52, ..., 87, 87, 87],
 [ 87, 87, 87, ..., 87, 87, 87],
 [ 87, 87, 87, ..., 87, 87, 87],
 ...,
 [236, 236, 236, ..., 236, 236, 236],
 [236, 236, 236, ..., 236, 236, 236],
 [236, 236, 236, ..., 236, 236, 236]], dtype=uint8),
 array([[ 52, 52, 52, ..., 87, 87, 87],
 [ 87, 87, 87, ..., 87, 87, 87],
 [ 87, 87, 87, ..., 87, 87, 87],
 ...,
 [236, 236, 236, ..., 236, 236, 236],
 [236, 236, 236, ..., 236, 236, 236],
 [236, 236, 236, ..., 236, 236, 236]], dtype=uint8)),
 {'lives': 0, 'episode_frame_number': 20, 'frame_number': 20})

## Possible actions

In [4]:
# Size of the discrete action space, and the ALE's human-readable name for each action index.
actions, meanings = (
    environment.action_space.n,
    environment.unwrapped.get_action_meanings(),
)

print(f"{actions} possible actions: {meanings}")

6 possible actions: ['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE']


# Network Definitions (Haiku Forward Functions)

In [5]:
def shared(S, is_training):
    """Convolutional feature extractor used by both func_pi and func_v.

    S is the stacked observation. Its frames are presumably uint8 images
    (the reset output above shows dtype=uint8) — TODO confirm for batches.
    They are stacked along a new last axis and scaled by 1/255. The result
    goes through coax.utils.diff_transform, two ReLU-activated Conv2D layers
    and a flatten.

    is_training is not used here; presumably it is only present because
    coax calls forward functions with this signature.
    """
    torso = haiku.Sequential([
        coax.utils.diff_transform,
        haiku.Conv2D(16, kernel_shape=8, stride=4),
        jax.nn.relu,
        haiku.Conv2D(32, kernel_shape=4, stride=2),
        jax.nn.relu,
        haiku.Flatten(),
    ])
    frames = numpy.stack(S, axis=-1)
    return torso(frames / 255.)


def func_pi(S, is_training):
    """Policy forward function: maps observations S to action logits.

    Returns a dict {'logits': ...} with one entry per action in
    environment.action_space. The output layer starts with zero weights,
    so initially the logits do not depend on the extracted features.
    """
    # Built before calling shared(), matching the original module creation order.
    head = haiku.Sequential((
        haiku.Linear(256),
        jax.nn.relu,
        haiku.Linear(environment.action_space.n, w_init=numpy.zeros),
    ))
    features = shared(S, is_training)
    scores = head(features)
    return {'logits': scores}


def func_v(S, is_training):
    """State-value forward function: one scalar estimate per observation.

    The single-unit output layer starts with zero weights, and its output
    is flattened with numpy.ravel before being returned.
    """
    # Built before calling shared(), matching the original module creation order.
    head = haiku.Sequential((
        haiku.Linear(256),
        jax.nn.relu,
        haiku.Linear(1, w_init=numpy.zeros),
    ))
    features = shared(S, is_training)
    return numpy.ravel(head(features))

# Function Approximators

In [6]:
# Bind the Haiku forward functions to this environment as coax function approximators:
# pi from func_pi's logits, v from func_v's scalar output.
pi, v = coax.Policy(func_pi, environment), coax.V(func_v, environment)

# Behavior Policy and Target Network

In [7]:
# Lagged copies of the online approximators. pi_behavior selects actions during
# rollouts; v_targ is handed to SimpleTD as its target. Both are soft-updated
# towards pi / v after each update round.
pi_behavior, v_targ = pi.copy(), v.copy()

# Policy Regularizer (Avoid Premature Exploitation)

In [8]:
# Entropy bonus on pi (weight beta = 1e-3) that discourages the policy from
# collapsing onto a single action too early.
entropy = coax.regularizers.EntropyRegularizer(pi, beta=1e-3)

# Updaters

In [9]:
# Both updaters use Adam with the same step size, each with its own optimizer instance.
learning_rate = 3e-4
simpletd = coax.td_learning.SimpleTD(v, v_targ, optimizer=adam(learning_rate))
ppo_clip = coax.policy_objectives.PPOClip(
    pi,
    regularizer=entropy,
    optimizer=adam(learning_rate),
)

# Reward Tracer and Replay Buffer

In [10]:
# The training loop fills the buffer to capacity, updates on it, then clears it,
# so its capacity is effectively the rollout length.
rollout_size = 256
# 5-step reward tracing with discount factor 0.99.
tracer = coax.reward_tracing.NStep(n=5, gamma=0.99)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=rollout_size)

# Training

In [11]:
# Collect experience with the behavior policy and run PPO update rounds whenever
# the buffer is full, until the TrainMonitor has counted 3,000,000 env steps.
while environment.T < 3000000:
    s, info = environment.reset()
    # Start every episode with an empty tracer. The tracer releases its last
    # pending transitions only once it sees done=True (presumably; see the coax
    # NStep docs). After a TimeLimit truncation (truncated=True, done=False),
    # leftovers would otherwise be completed with rewards from this new episode.
    tracer.reset()

    for t in range(environment.spec.max_episode_steps):
        a, logp = pi_behavior(s, return_logp=True)
        s_next, r, done, truncated, info = environment.step(a)

        # Trace rewards and move every completed n-step transition into the buffer.
        tracer.add(s, a, r, done, logp)
        while tracer:
            buffer.add(tracer.pop())

        # Once a full rollout is collected: several minibatch updates, then discard it.
        if len(buffer) >= buffer.capacity:
            # 4 * capacity / 32 batches of 32 -> about 4 passes' worth of samples.
            num_batches = int(4 * buffer.capacity / 32)
            for _ in range(num_batches):
                transition_batch = buffer.sample(32)
                metrics_v, td_error = simpletd.update(transition_batch, return_td_error=True)
                metrics_pi = ppo_clip.update(transition_batch, td_error)
                environment.record_metrics(metrics_v)
                environment.record_metrics(metrics_pi)

            buffer.clear()

            # Move the lagged behavior policy / value target 10% towards the updated networks.
            pi_behavior.soft_update(pi, tau=0.1)
            v_targ.soft_update(v, tau=0.1)

        if done or truncated:
            break

        s = s_next

    # Every 10,000 env steps, record an animated GIF of the current policy.
    if environment.period(name='generate_gif', T_period=10000):
        T = environment.T - environment.T % 10000  # round to 10000s
        coax.utils.generate_gif(
            env=environment, policy=pi, resize_to=(320, 420),
            filepath=f"./data/gifs/{name}/T{T:08d}.gif")

INFO:TrainMonitor:ep: 1,	T: 1,	G: 0,	avg_r: nan,	avg_G: 0,	t: 0,	dt: nanms
INFO:TrainMonitor:ep: 2,	T: 933,	G: -20,	avg_r: -0.0215,	avg_G: -20,	t: 931,	dt: 6.052ms,	SimpleTD/loss: 0.0499,	PPOClip/EntropyRegularizer/entropy: 1.79,	PPOClip/loss: 0.0104
INFO:TrainMonitor:ep: 3,	T: 1,758,	G: -21,	avg_r: -0.0255,	avg_G: -20.5,	t: 824,	dt: 2.822ms,	SimpleTD/loss: 0.0456,	PPOClip/EntropyRegularizer/entropy: 1.78,	PPOClip/loss: -0.000834
INFO:TrainMonitor:ep: 4,	T: 2,788,	G: -20,	avg_r: -0.0194,	avg_G: -20.3,	t: 1029,	dt: 2.838ms,	SimpleTD/loss: 0.0168,	PPOClip/EntropyRegularizer/entropy: 1.79,	PPOClip/loss: 0.000254
INFO:TrainMonitor:ep: 5,	T: 3,629,	G: -20,	avg_r: -0.0238,	avg_G: -20.2,	t: 840,	dt: 2.921ms,	SimpleTD/loss: 0.00722,	PPOClip/EntropyRegularizer/entropy: 1.78,	PPOClip/loss: -0.00171
INFO:TrainMonitor:ep: 6,	T: 4,607,	G: -20,	avg_r: -0.0205,	avg_G: -20.2,	t: 977,	dt: 2.795ms,	SimpleTD/loss: 0.0162,	PPOClip/EntropyRegularizer/entropy: 1.78,	PPOClip/loss: -0.00592
INFO:TrainMonitor:

# Save Model

In [12]:
# Persist the online and lagged approximators together, in this order, in one
# file. Loading presumably yields the same 4-tuple — TODO confirm with coax.utils.load.
checkpoint = (pi, v, pi_behavior, v_targ)
coax.utils.dump(checkpoint, 'checkpoint.pkl.lz4')