{ "cells": [ { "cell_type": "markdown", "id": "9c5c18a1", "metadata": {}, "source": [ "* Pong (PongNoFrameskip-v4): https://ale.farama.org/environments/pong/" ] }, { "cell_type": "markdown", "id": "2df4fa0b", "metadata": {}, "source": [ "# Libraries" ] }, { "cell_type": "code", "execution_count": 1, "id": "df27797f", "metadata": { "execution": { "iopub.execute_input": "2023-11-28T05:26:14.209983Z", "iopub.status.busy": "2023-11-28T05:26:14.208498Z", "iopub.status.idle": "2023-11-28T05:26:15.646218Z", "shell.execute_reply": "2023-11-28T05:26:15.645898Z", "shell.execute_reply.started": "2023-11-28T05:26:14.209935Z" }, "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/efren/.pyenv/versions/3.11.6/envs/pong/lib/python3.11/site-packages/flax/configurations.py:42: DeprecationWarning: jax.config.define_bool_state is deprecated. Please use other libraries for configuration instead.\n", " return jax_config.define_bool_state('flax_' + name, default, help)\n", "/home/efren/.pyenv/versions/3.11.6/envs/pong/lib/python3.11/site-packages/flax/configurations.py:42: DeprecationWarning: jax.config.define_bool_state is deprecated. Please use other libraries for configuration instead.\n", " return jax_config.define_bool_state('flax_' + name, default, help)\n", "/home/efren/.pyenv/versions/3.11.6/envs/pong/lib/python3.11/site-packages/flax/configurations.py:42: DeprecationWarning: jax.config.define_bool_state is deprecated. Please use other libraries for configuration instead.\n", " return jax_config.define_bool_state('flax_' + name, default, help)\n", "/home/efren/.pyenv/versions/3.11.6/envs/pong/lib/python3.11/site-packages/flax/configurations.py:42: DeprecationWarning: jax.config.define_bool_state is deprecated. 
Please use other libraries for configuration instead.\n", " return jax_config.define_bool_state('flax_' + name, default, help)\n", "/home/efren/.pyenv/versions/3.11.6/envs/pong/lib/python3.11/site-packages/flax/configurations.py:42: DeprecationWarning: jax.config.define_bool_state is deprecated. Please use other libraries for configuration instead.\n", " return jax_config.define_bool_state('flax_' + name, default, help)\n", "/home/efren/.pyenv/versions/3.11.6/envs/pong/lib/python3.11/site-packages/flax/configurations.py:42: DeprecationWarning: jax.config.define_bool_state is deprecated. Please use other libraries for configuration instead.\n", " return jax_config.define_bool_state('flax_' + name, default, help)\n" ] } ], "source": [ "import coax\n", "import gymnasium\n", "import haiku\n", "import jax\n", "import matplotlib.pyplot as pyplot\n", "\n", "from jax import numpy\n", "from IPython.display import clear_output" ] }, { "cell_type": "markdown", "id": "4876213b", "metadata": {}, "source": [ "# Environment" ] }, { "cell_type": "code", "execution_count": 2, "id": "092108a6", "metadata": { "execution": { "iopub.execute_input": "2023-11-28T05:26:15.646762Z", "iopub.status.busy": "2023-11-28T05:26:15.646571Z", "iopub.status.idle": "2023-11-28T05:26:15.731586Z", "shell.execute_reply": "2023-11-28T05:26:15.731252Z", "shell.execute_reply.started": "2023-11-28T05:26:15.646754Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)\n", "[Powered by Stella]\n" ] } ], "source": [ "# Identifier reused for the monitor name and the tensorboard log directory.\n", "name = \"pong\"\n", "\n", "# Base game, rendered as RGB arrays so the playthrough cell can draw frames with pyplot.\n", "environment = gymnasium.make(\n", "    \"PongNoFrameskip-v4\",\n", "    render_mode=\"rgb_array\",\n", ")\n", "# Wrappers are applied innermost first: preprocessing, frame stacking, step cap, monitoring.\n", "environment = gymnasium.wrappers.AtariPreprocessing(environment)\n", "environment = coax.wrappers.FrameStacking(\n", "    environment,\n", "    num_frames=3,\n", ")\n", "environment = gymnasium.wrappers.TimeLimit(\n", "    environment,\n", "    max_episode_steps=108000 // 3,\n", ")\n", "environment = coax.wrappers.TrainMonitor(environment, name=name, 
tensorboard_dir=f\"./data/tensorboard/{name}\")" ] }, { "cell_type": "markdown", "id": "29f90cc2", "metadata": {}, "source": [ "# Model" ] }, { "cell_type": "code", "execution_count": 3, "id": "992284e8-4a1e-49b4-89b0-c48e0676fda3", "metadata": { "execution": { "iopub.execute_input": "2023-11-28T05:26:15.731937Z", "iopub.status.busy": "2023-11-28T05:26:15.731860Z", "iopub.status.idle": "2023-11-28T05:26:15.734677Z", "shell.execute_reply": "2023-11-28T05:26:15.734499Z", "shell.execute_reply.started": "2023-11-28T05:26:15.731930Z" } }, "outputs": [], "source": [ "def shared(S, is_training):\n", "    \"\"\"Convolutional feature extractor called by both func_pi and func_v.\n", "\n", "    The entries of S are stacked along a new trailing axis and divided by 255\n", "    before being passed through the layers. is_training is accepted but unused.\n", "    \"\"\"\n", "    encoder = haiku.Sequential([\n", "        coax.utils.diff_transform,\n", "        haiku.Conv2D(16, kernel_shape=8, stride=4),\n", "        jax.nn.relu,\n", "        haiku.Conv2D(32, kernel_shape=4, stride=2),\n", "        jax.nn.relu,\n", "        haiku.Flatten(),\n", "    ])\n", "    frames = numpy.stack(S, axis=-1) / 255.\n", "    return encoder(frames)\n", "\n", "\n", "def func_pi(S, is_training):\n", "    \"\"\"Policy forward pass: returns a dict whose 'logits' entry has one output per action.\"\"\"\n", "    head = haiku.Sequential((\n", "        haiku.Linear(256),\n", "        jax.nn.relu,\n", "        # The output layer's weights start at zero.\n", "        haiku.Linear(environment.action_space.n, w_init=numpy.zeros),\n", "    ))\n", "    features = shared(S, is_training)\n", "    return {'logits': head(features)}\n", "\n", "\n", "def func_v(S, is_training):\n", "    \"\"\"State-value forward pass: a single linear output, flattened with numpy.ravel.\"\"\"\n", "    head = haiku.Sequential((\n", "        haiku.Linear(256),\n", "        jax.nn.relu,\n", "        # The output layer's weights start at zero.\n", "        haiku.Linear(1, w_init=numpy.zeros),\n", "        numpy.ravel,\n", "    ))\n", "    features = shared(S, is_training)\n", "    return head(features)" ] }, { "cell_type": "code", "execution_count": 4, "id": "59d1ea20-3810-473d-a543-f916762a9d5c", "metadata": { "execution": { "iopub.execute_input": "2023-11-28T05:26:15.735225Z", "iopub.status.busy": "2023-11-28T05:26:15.735116Z", "iopub.status.idle": "2023-11-28T05:26:19.475119Z", "shell.execute_reply": "2023-11-28T05:26:19.474806Z", "shell.execute_reply.started": "2023-11-28T05:26:15.735218Z" } }, "outputs": [], "source": [ "# Wrap the forward-pass functions as a coax policy and state-value function for this environment.\n", "pi = coax.Policy(func_pi, environment)\n", "v = coax.V(func_v, environment)" ] }, { "cell_type": "code", "execution_count": 5, "id": "abf769a7", "metadata": { "execution": { 
"iopub.execute_input": "2023-11-28T05:26:19.475579Z", "iopub.status.busy": "2023-11-28T05:26:19.475498Z", "iopub.status.idle": "2023-11-28T05:26:19.558226Z", "shell.execute_reply": "2023-11-28T05:26:19.557983Z", "shell.execute_reply.started": "2023-11-28T05:26:19.475571Z" } }, "outputs": [], "source": [ "# Restore the saved objects; this rebinds the pi and v created in the previous cell.\n", "# NOTE(review): the .pkl.lz4 suffix suggests a compressed pickle; only load checkpoints from a trusted source.\n", "checkpoint = coax.utils.load(\"checkpoint.pkl.lz4\")\n", "pi, v, pi_behavior, v_targ = checkpoint" ] }, { "cell_type": "markdown", "id": "e86a4fbe", "metadata": {}, "source": [ "# Playthrough" ] }, { "cell_type": "code", "execution_count": 6, "id": "b609ed20", "metadata": { "execution": { "iopub.execute_input": "2023-11-28T05:26:19.558727Z", "iopub.status.busy": "2023-11-28T05:26:19.558624Z", "iopub.status.idle": "2023-11-28T05:27:28.656071Z", "shell.execute_reply": "2023-11-28T05:27:28.655776Z", "shell.execute_reply.started": "2023-11-28T05:26:19.558719Z" } }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Steps taken: 1662\n", "Final score: 20.0\n", "Penalties: 0\n" ] } ], "source": [ "# Play one episode with the loaded behavior policy, redrawing the rendered frame after every step.\n", "state, _ = environment.reset()\n", "\n", "score, steps, penalties = 0, 0, 0\n", "\n", "done = False\n", "while not done:\n", "    steps += 1\n", "    action = pi_behavior(state)\n", "    state, reward, terminated, truncated, _ = environment.step(action)\n", "    # The episode is over when the game ends (terminated) or when the TimeLimit\n", "    # wrapper cuts it short (truncated); stepping past either is invalid.\n", "    done = terminated or truncated\n", "    score += reward\n", "    # Pong rewards are +1/-1 per point, so a -10 check can never fire;\n", "    # count conceded points (negative rewards) instead.\n", "    if reward < 0:\n", "        penalties += 1\n", "\n", "    clear_output(wait=True)\n", "    pyplot.imshow(environment.render())\n", "    pyplot.show()\n", "\n", "print(f\"Steps taken: {steps}\")\n", "print(f\"Final score: {score}\")\n", "print(f\"Penalties: {penalties}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 5 }