{ "cells": [ { "cell_type": "markdown", "id": "9c5c18a1", "metadata": {}, "source": [ "* Pong (Atari / ALE): https://ale.farama.org/environments/pong/" ] }, { "cell_type": "markdown", "id": "2df4fa0b", "metadata": {}, "source": [ "# Libraries" ] }, { "cell_type": "code", "execution_count": 1, "id": "df27797f", "metadata": { "execution": { "iopub.execute_input": "2023-11-28T05:26:14.209983Z", "iopub.status.busy": "2023-11-28T05:26:14.208498Z", "iopub.status.idle": "2023-11-28T05:26:15.646218Z", "shell.execute_reply": "2023-11-28T05:26:15.645898Z", "shell.execute_reply.started": "2023-11-28T05:26:14.209935Z" }, "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/efren/.pyenv/versions/3.11.6/envs/pong/lib/python3.11/site-packages/flax/configurations.py:42: DeprecationWarning: jax.config.define_bool_state is deprecated. Please use other libraries for configuration instead.\n", " return jax_config.define_bool_state('flax_' + name, default, help)\n", "/home/efren/.pyenv/versions/3.11.6/envs/pong/lib/python3.11/site-packages/flax/configurations.py:42: DeprecationWarning: jax.config.define_bool_state is deprecated. Please use other libraries for configuration instead.\n", " return jax_config.define_bool_state('flax_' + name, default, help)\n", "/home/efren/.pyenv/versions/3.11.6/envs/pong/lib/python3.11/site-packages/flax/configurations.py:42: DeprecationWarning: jax.config.define_bool_state is deprecated. Please use other libraries for configuration instead.\n", " return jax_config.define_bool_state('flax_' + name, default, help)\n", "/home/efren/.pyenv/versions/3.11.6/envs/pong/lib/python3.11/site-packages/flax/configurations.py:42: DeprecationWarning: jax.config.define_bool_state is deprecated. 
Please use other libraries for configuration instead.\n", " return jax_config.define_bool_state('flax_' + name, default, help)\n", "/home/efren/.pyenv/versions/3.11.6/envs/pong/lib/python3.11/site-packages/flax/configurations.py:42: DeprecationWarning: jax.config.define_bool_state is deprecated. Please use other libraries for configuration instead.\n", " return jax_config.define_bool_state('flax_' + name, default, help)\n", "/home/efren/.pyenv/versions/3.11.6/envs/pong/lib/python3.11/site-packages/flax/configurations.py:42: DeprecationWarning: jax.config.define_bool_state is deprecated. Please use other libraries for configuration instead.\n", " return jax_config.define_bool_state('flax_' + name, default, help)\n" ] } ], "source": [ "import coax\n", "import gymnasium\n", "import haiku\n", "import jax\n", "import matplotlib.pyplot as pyplot\n", "\n", "from jax import numpy\n", "from IPython.display import clear_output" ] }, { "cell_type": "markdown", "id": "4876213b", "metadata": {}, "source": [ "# Environment" ] }, { "cell_type": "code", "execution_count": 2, "id": "092108a6", "metadata": { "execution": { "iopub.execute_input": "2023-11-28T05:26:15.646762Z", "iopub.status.busy": "2023-11-28T05:26:15.646571Z", "iopub.status.idle": "2023-11-28T05:26:15.731586Z", "shell.execute_reply": "2023-11-28T05:26:15.731252Z", "shell.execute_reply.started": "2023-11-28T05:26:15.646754Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)\n", "[Powered by Stella]\n" ] } ], "source": [ "name = \"pong\"\n", "\n", "environment = gymnasium.make(\"PongNoFrameskip-v4\", render_mode=\"rgb_array\")\n", "environment = gymnasium.wrappers.AtariPreprocessing(environment)\n", "environment = coax.wrappers.FrameStacking(environment, num_frames=3)\n", "environment = gymnasium.wrappers.TimeLimit(environment, max_episode_steps=108000 // 3)\n", "environment = coax.wrappers.TrainMonitor(environment, name=name, 
tensorboard_dir=f\"./data/tensorboard/{name}\")" ] }, { "cell_type": "markdown", "id": "29f90cc2", "metadata": {}, "source": [ "# Model" ] }, { "cell_type": "code", "execution_count": 3, "id": "992284e8-4a1e-49b4-89b0-c48e0676fda3", "metadata": { "execution": { "iopub.execute_input": "2023-11-28T05:26:15.731937Z", "iopub.status.busy": "2023-11-28T05:26:15.731860Z", "iopub.status.idle": "2023-11-28T05:26:15.734677Z", "shell.execute_reply": "2023-11-28T05:26:15.734499Z", "shell.execute_reply.started": "2023-11-28T05:26:15.731930Z" } }, "outputs": [], "source": [ "def shared(S, is_training):\n", "    \"\"\"Feature extractor called by both `func_pi` and `func_v`; `is_training` is unused.\"\"\"\n", "    layers = [\n", "        coax.utils.diff_transform,\n", "        haiku.Conv2D(16, kernel_shape=8, stride=4),\n", "        jax.nn.relu,\n", "        haiku.Conv2D(32, kernel_shape=4, stride=2),\n", "        jax.nn.relu,\n", "        haiku.Flatten(),\n", "    ]\n", "    torso = haiku.Sequential(layers)\n", "    # Stack the entries of S along a new trailing axis and rescale by 1/255.\n", "    scaled = numpy.stack(S, axis=-1) / 255.\n", "    return torso(scaled)\n", "\n", "\n", "def func_pi(S, is_training):\n", "    \"\"\"Policy head: maps `S` to a dict holding one logit per action.\"\"\"\n", "    num_actions = environment.action_space.n\n", "    head = haiku.Sequential((\n", "        haiku.Linear(256),\n", "        jax.nn.relu,\n", "        haiku.Linear(num_actions, w_init=numpy.zeros),\n", "    ))\n", "    features = shared(S, is_training)\n", "    return {'logits': head(features)}\n", "\n", "\n", "def func_v(S, is_training):\n", "    \"\"\"Value head: maps `S` to the raveled output of a single-unit linear layer.\"\"\"\n", "    head = haiku.Sequential((\n", "        haiku.Linear(256),\n", "        jax.nn.relu,\n", "        haiku.Linear(1, w_init=numpy.zeros),\n", "        numpy.ravel,\n", "    ))\n", "    features = shared(S, is_training)\n", "    return head(features)" ] }, { "cell_type": "code", "execution_count": 4, "id": "59d1ea20-3810-473d-a543-f916762a9d5c", "metadata": { "execution": { "iopub.execute_input": "2023-11-28T05:26:15.735225Z", "iopub.status.busy": "2023-11-28T05:26:15.735116Z", "iopub.status.idle": "2023-11-28T05:26:19.475119Z", "shell.execute_reply": "2023-11-28T05:26:19.474806Z", "shell.execute_reply.started": "2023-11-28T05:26:15.735218Z" } }, "outputs": [], "source": [ "pi = coax.Policy(func_pi, environment)\n", "v = coax.V(func_v, environment)" ] }, { "cell_type": "code", "execution_count": 5, "id": "abf769a7", "metadata": { "execution": { 
"iopub.execute_input": "2023-11-28T05:26:19.475579Z", "iopub.status.busy": "2023-11-28T05:26:19.475498Z", "iopub.status.idle": "2023-11-28T05:26:19.558226Z", "shell.execute_reply": "2023-11-28T05:26:19.557983Z", "shell.execute_reply.started": "2023-11-28T05:26:19.475571Z" } }, "outputs": [], "source": [ "pi, v, pi_behavior, v_targ = coax.utils.load('checkpoint.pkl.lz4')" ] }, { "cell_type": "markdown", "id": "e86a4fbe", "metadata": {}, "source": [ "# Playthrough" ] }, { "cell_type": "code", "execution_count": 6, "id": "b609ed20", "metadata": { "execution": { "iopub.execute_input": "2023-11-28T05:26:19.558727Z", "iopub.status.busy": "2023-11-28T05:26:19.558624Z", "iopub.status.idle": "2023-11-28T05:27:28.656071Z", "shell.execute_reply": "2023-11-28T05:27:28.655776Z", "shell.execute_reply.started": "2023-11-28T05:26:19.558719Z" } }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVEAAAGhCAYAAADY5IdbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkJ0lEQVR4nO3df3BU9b3/8dfm1xIguzEJyWY1gUAVRCUCaky1XCkpSbC0Kr1XKN4LyoC1gY6JXjF3lB9OZ4J66+1oUebOtNBORSwzglfuyAwESbSGKEGGKpovoVFA2KAwySbBbH7s+f7RL/vtNgmw+exmk/h8zJyZnPP5nLPv8zG8PHt+xWZZliUAwIDERLsAABjOCFEAMECIAoABQhQADBCiAGCAEAUAA4QoABggRAHAACEKAAYIUQAwENUQ3bhxoyZMmKBRo0YpLy9PH3zwQTTLAYCQRS1EX3/9dZWVlWnt2rU6dOiQcnNzVVhYqLNnz0arJAAImS1aLyDJy8vTrbfeqt/85jeSJL/fr6ysLK1atUpPPvnkJdf1+/06ffq0kpKSZLPZBqNcAN8ylmWptbVVbrdbMTH9H2/GDWJNAZ2dnaqrq1N5eXlgWUxMjAoKClRTU9Orv8/nk8/nC8x/+eWXmjp16qDUCuDb7eTJk7rmmmv6bY9KiH799dfq6elRRkZG0PKMjAx99tlnvfpXVFRo/fr1vZY/PcuhUXGhHYnG2DSij14nuN0af7W7z7aTZzz666lTg1wR/pElm/5693S1ZKeGdbvZVUeVdvTLPttO5V+rszdPCOvnjfv4hLLerQ/rNoeSjm5L6/a3KCkp6ZL9ohKioSovL1dZWVlg3uv1KisrS2MSYkIO0ZFutD1WjsSEPtvG2GMZryHAstkUPyZOCUn2sG7Xbo/r97+vPTE+7J+XkBj/rfh9utxBV1RCNC0tTbGxsWpqagpa3tTUJJfL1au/3W6X3R7eXwAACIeoXJ1PSEjQzJkzVVlZGVjm9/tVWVmp/Pz8aJQEAAMSta/zZWVlWrJkiW655Rbddttt+vWvf6329nY9+OCD0SoJAEIWtRC9//779dVXX2nNmjXyeDy6+eabtXv37l4XmwBgKIvqhaWVK1
dq5cqV0SwBiAJLiV+3yh8XG/KaHSlj1T06zBeIWi4oofWbPts6kxLV6Rwd1s8baYbF1XlgRLGkq9+rlwZwq93nc6fp/PVXh7Wc1KNfKvODhj7bmmbm6Ms7p4T180YaQhQYZDZJNr8lKbSHBS1JisADhjbLUkyPv+9GP39R/XJ4ixMAGCBEAcAAIQoABghRADDAhSVgkFmSfMlj1GMP/Z9fdz/vRUD0EKLAYLPZdGrWFHnHp4W8qj+WL49DDSEKRIE/Lkb+eP75jQT8bw0ADBCiAGCAEAUAA4QoABjgzDYwlFiWElo7FNvZFfKqsR2dESgIl0OIAkPM1X+ul/OvTZfv+A9iunoiUA0uhxAFhpiYrh7F+bqjXQauEOdEAcAAIQoABghRADBAiAKAAS4sAUNM1xi7OpJD/+Nw8Rc6FdsZ+gWp7lHx/X4eb426PEIUGGJO3TlZMfnXhrzeNdWfKu3olyGv9/VNWTo/xd1nmz8+9L9I+m1DiAJDic0mvz1e/fzZuEuyYgcWeP74ON4oZYBzogBggBAFAAOEKAAYIEQBwABnk0eY7p4edfh8fbZ1dfOCiqHBUtyFTsW3fhPWrcZ09X97U6yvK+yfF9cR+pumRiJCdIT5sumsmr4+12dbdw8hOiRYUvb+o/LHhfeLYGxH/yGa/tHnSvvkZFg/L6aT3yeJEB1xenp61ENYDmk2Df5RXFxntzSAG/FxeZwTBQADhCgAGBjmX+dtks0W7SIAfIuFPUQrKir0xhtv6LPPPlNiYqK++93v6tlnn9XkyZMDfe666y5VVVUFrffwww9r06ZNIX3WHT//lcaOCf1FDQBwOW3tF6S9D122X9hDtKqqSiUlJbr11lvV3d2t//iP/9DcuXN19OhRjRkzJtBv+fLleuaZZwLzo0eHHobXTJ+tpKSksNQNAH+vtbX1ivqFPUR3794dNL9lyxalp6errq5Os2bNCiwfPXq0XC5XuD8eAAZVxC8stbS0SJJSUlKClr/66qtKS0vTjTfeqPLycl24cKHfbfh8Pnm93qAJAIaCiF5Y8vv9evTRR3XHHXfoxhtvDCz/6U9/qvHjx8vtduvIkSNavXq16uvr9cYbb/S5nYqKCq1fvz6SpQLAgNgsy7IitfFHHnlEb7/9tt577z1dc801/fbbt2+f5syZo4aGBk2aNKlXu8/nk+/vHmX0er3KyspSY2Mj50QBRERra6tycnLU0tIih8PRb7+IHYmuXLlSu3btUnV19SUDVJLy8vIkqd8QtdvtstvtEakTAEyEPUQty9KqVau0Y8cO7d+/Xzk5OZdd5/Dhw5KkzMzMcJcDABEV9hAtKSnR1q1b9eabbyopKUkej0eS5HQ6lZiYqOPHj2vr1q2aN2+eUlNTdeTIEZWWlmrWrFmaNm1auMsBgIgK+zlRWz9PEG3evFlLly7VyZMn9cADD+jjjz9We3u7srKydO+99+qpp5665HmHv+f1euV0OjknCiBionZO9HKZnJWV1etpJQAYrngBCQAYIEQBwAAhCgAGCFEAMECIAoABQhQADAzrN9s3n2pQz9gxl+8IACFqbWu/on7DOkQrn3tQifEcTAMIv2+6/FfUb1iHaPc3berq4m8sAQi/7u4re5iTwzgAMECIAoABQhQADBCiAGCAEAUAA4QoABggRAHAACEKAAYIUQAwQIgCgAFCFAAMEKIAYIAQBQADhCgAGCBEAcAAIQoABghRADBAiAKAAUIUAAwQogBggBAFAAOEKAAYIEQBwAAhCgAGwh6i69atk81mC5qmTJkSaO/o6FBJSYlSU1M1duxYLViwQE1NTeEuAwAGRUSORG+44QadOXMmML333nuBttLSUr311lvavn27qqqqdPr0ad13332RKAMAIi4uIhuNi5PL5eq1vKWlRb/97W+1detWff/735ckbd68Wddff70OHDig22+/PRLlAEDERORI9NixY3K73Zo4caIWL16sEydOSJLq6urU1dWlgoKCQN
8pU6YoOztbNTU1/W7P5/PJ6/UGTQAwFIQ9RPPy8rRlyxbt3r1br7zyihobG/W9731Pra2t8ng8SkhIUHJyctA6GRkZ8ng8/W6zoqJCTqczMGVlZYW7bAAYkLB/nS8uLg78PG3aNOXl5Wn8+PH605/+pMTExAFts7y8XGVlZYF5r9dLkAIYEiJ+i1NycrKuu+46NTQ0yOVyqbOzU83NzUF9mpqa+jyHepHdbpfD4QiaAGAoiHiItrW16fjx48rMzNTMmTMVHx+vysrKQHt9fb1OnDih/Pz8SJcCAGEX9q/zjz/+uObPn6/x48fr9OnTWrt2rWJjY7Vo0SI5nU4tW7ZMZWVlSklJkcPh0KpVq5Sfn8+VeQDDUthD9NSpU1q0aJHOnTuncePG6c4779SBAwc0btw4SdJ//dd/KSYmRgsWLJDP51NhYaFefvnlcJcBAIPCZlmWFe0iQuX1euV0OrWhIFmj4mzRLgfACNTRbenJvc1qaWm55HUYnp0HAAOEKAAYIEQBwAAhCgAGCFEAMECIAoABQhQADBCiAGCAEAUAA4QoABggRAHAACEKAAYIUQAwQIgCgAFCFAAMEKIAYIAQBQADhCgAGCBEAcAAIQoABghRADBAiAKAAUIUAAwQogBggBAFAAOEKAAYIEQBwAAhCgAGCFEAMECIAoABQhQADBCiAGCAEAUAA2EP0QkTJshms/WaSkpKJEl33XVXr7af/exn4S4DAAZFXLg3+OGHH6qnpycw//HHH+sHP/iB/vmf/zmwbPny5XrmmWcC86NHjw53GQAwKMIeouPGjQua37BhgyZNmqR/+qd/CiwbPXq0XC5XuD8aAAZdRM+JdnZ26o9//KMeeugh2Wy2wPJXX31VaWlpuvHGG1VeXq4LFy5ccjs+n09erzdoAoChIOxHon9v586dam5u1tKlSwPLfvrTn2r8+PFyu906cuSIVq9erfr6er3xxhv9bqeiokLr16+PZKkAMCA2y7KsSG28sLBQCQkJeuutt/rts2/fPs2ZM0cNDQ2aNGlSn318Pp98Pl9g3uv1KisrSxsKkjUqztbnOgBgoqPb0pN7m9XS0iKHw9Fvv4gdiX7xxRfau3fvJY8wJSkvL0+SLhmidrtddrs97DUCgKmInRPdvHmz0tPTdffdd1+y3+HDhyVJmZmZkSoFACImIkeifr9fmzdv1pIlSxQX9/8/4vjx49q6davmzZun1NRUHTlyRKWlpZo1a5amTZsWiVIAIKIiEqJ79+7ViRMn9NBDDwUtT0hI0N69e/XrX/9a7e3tysrK0oIFC/TUU09FogwAiLiIhOjcuXPV1/WqrKwsVVVVReIjASAqeHYeAAwQogBggBAFAAOEKAAYIEQBwAAhCgAGCFEAMECIAoABQhQADBCiAGCAEAUAA4QoABggRAHAACEKAAYIUQAwQIgCgAFCFAAMEKIAYIAQBQADhCgAGCBEAcAAIQoABghRADBAiAKAAUIUAAwQogBggBAFAAOEKAAYIEQBwAAhCgAGCFEAMECIAoABQhQADIQcotXV1Zo/f77cbrdsNpt27twZ1G5ZltasWaPMzEwlJiaqoKBAx44dC+pz/vx5LV68WA6HQ8nJyVq2bJna2tqMdgQAoiHkEG1vb1dubq42btzYZ/tzzz2nF198UZs2bVJtba3GjBmjwsJCdXR0BPosXrxYn3zyifbs2aNdu3apurpaK1asGPheAECU2CzLsga8ss2mHTt26J577pH0t6NQt9utxx57TI8//rgkqaWlRRkZGdqyZYsWLlyoTz/9VFOnTtWHH36oW265RZK0e/duzZs3T6dOnZLb7b7s53q9XjmdTm0oSNaoONtAyweAfnV0W3pyb7NaWlrkcDj67RfWc6KNjY3yeDwqKCgILHM6ncrLy1NNTY0kqaamRsnJyYEAlaSCggLFxMSotra2z+36fD55vd6gCQCGgrCGqMfjkSRlZGQELc/IyAi0eTwepaenB7XHxcUpJS
Ul0OcfVVRUyOl0BqasrKxwlg0AAzYsrs6Xl5erpaUlMJ08eTLaJQGApDCHqMvlkiQ1NTUFLW9qagq0uVwunT17Nqi9u7tb58+fD/T5R3a7XQ6HI2gCgKEgrCGak5Mjl8ulysrKwDKv16va2lrl5+dLkvLz89Xc3Ky6urpAn3379snv9ysvLy+c5QBAxMWFukJbW5saGhoC842NjTp8+LBSUlKUnZ2tRx99VL/85S917bXXKicnR08//bTcbnfgCv7111+voqIiLV++XJs2bVJXV5dWrlyphQsXXtGVeQAYSkIO0YMHD2r27NmB+bKyMknSkiVLtGXLFj3xxBNqb2/XihUr1NzcrDvvvFO7d+/WqFGjAuu8+uqrWrlypebMmaOYmBgtWLBAL774Yhh2BwAGl9F9otHCfaIAIi0q94kCwLcNIQoABghRADBAiAKAAUIUAAwQogBgIOT7RIey+MQkxcYnSJL83V3qvMDbngBE1sgJUZtNMxY9IdeN35UknTt+RO9vekL+nq4oFwZgJBs5ISpplDNNSenZkqQL5z0S9+EDiDDOiQKAAUIUAAwQogBggBAFAAMj6sJSd8cF+dpaJEld37RHuRoA3wYjJ0QtSx+9/rw+fvNlSVJ3Z4f83dzeBCCyRk6I6v/d1gQAg4hzogBggBAFAAOEKAAYIEQBwMCIurAUSTExMf0+it/j9w9qLQCGDkL0CthsNk3JmaCkMWN6tVmWVP95o1pa26JQGYBoI0Sv0OjERDnGju213G9ZiouNjUJFAIYCzokCgAFCFAAMEKIAYIAQBQADhCgAGCBEAcAAIQoABghRADBAiAKAAUIUAAyEHKLV1dWaP3++3G63bDabdu7cGWjr6urS6tWrddNNN2nMmDFyu936t3/7N50+fTpoGxMmTJDNZguaNmzYYLwzADDYQg7R9vZ25ebmauPGjb3aLly4oEOHDunpp5/WoUOH9MYbb6i+vl4/+tGPevV95plndObMmcC0atWqge0BAERRyC8gKS4uVnFxcZ9tTqdTe/bsCVr2m9/8RrfddptOnDih7OzswPKkpCS5XK5QPx4AhpSInxNtaWmRzWZTcnJy0PINGzYoNTVV06dP1/PPP6/u7u5+t+Hz+eT1eoMmABgKIvoqvI6ODq1evVqLFi2Sw+EILP/FL36hGTNmKCUlRe+//77Ky8t15swZvfDCC31up6KiQuvXr49kqQAwIBEL0a6uLv3Lv/yLLMvSK6+8EtRWVlYW+HnatGlKSEjQww8/rIqKCtnt9l7bKi8vD1rH6/UqKysrUqUDwBWLSIheDNAvvvhC+/btCzoK7UteXp66u7v1+eefa/Lkyb3a7XZ7n+EKANEW9hC9GKDHjh3TO++8o9TU1Muuc/jwYcXExCg9PT3c5QBARIUcom1tbWpoaAjMNzY26vDhw0pJSVFmZqZ+8pOf6NChQ9q1a5d6enrk8XgkSSkpKUpISFBNTY1qa2s1e/ZsJSUlqaamRqWlpXrggQd01VVXhW/PAGAQhByiBw8e1OzZswPzF89VLlmyROvWrdP//M//SJJuvvnmoPXeeecd3XXXXbLb7dq2bZvWrVsnn8+nnJwclZaWBp3zBIDhIuQQveuuu2RZVr/tl2qTpBkzZujAgQOhfiwADEk8Ow8ABghRADBAiAKAAUIUAAxE9LHPkabPi2aXvo4GYIQjRK+AZVn668lTSojvPVyWJbW2X4hCVQCGAkL0Cp1vaYl2CcCI54+xyd/HwYok2fx+xXT1yDbINV0OIQpgyGjPvEon5twgy9Y7Ksd4mjV+z19k8w+tc2iEKIAhoychVt+kJEkxvUM0/oIvChVdHlfnAcAAIQoABghRADBAiAKAAUIUAAwQogBggBAFAAOEKAAYIEQBwAAhCgAGCFEAMECIAoABQhQADBCiAGCAEAUAA4QoABggRAHAACEKAAYIUQAwQIgCgAFCFAAMEKIAYIAQBQADhCgAGAg5RKurqzV//ny53W
7ZbDbt3LkzqH3p0qWy2WxBU1FRUVCf8+fPa/HixXI4HEpOTtayZcvU1tZmtCMAEA0hh2h7e7tyc3O1cePGfvsUFRXpzJkzgem1114Lal+8eLE++eQT7dmzR7t27VJ1dbVWrFgRevUARhZLslmW1OcU7eL6FhfqCsXFxSouLr5kH7vdLpfL1Wfbp59+qt27d+vDDz/ULbfcIkl66aWXNG/ePP3nf/6n3G53qCUBGCHGnG3RpLfqZNlsvdriv+mUzT/0kjTkEL0S+/fvV3p6uq666ip9//vf1y9/+UulpqZKkmpqapScnBwIUEkqKChQTEyMamtrde+99/bans/nk8/nC8x7vd5IlA0gyuIvdCr5r2ejXUZIwn5hqaioSH/4wx9UWVmpZ599VlVVVSouLlZPT48kyePxKD09PWiduLg4paSkyOPx9LnNiooKOZ3OwJSVlRXusgFgQMJ+JLpw4cLAzzfddJOmTZumSZMmaf/+/ZozZ86AtlleXq6ysrLAvNfrJUgBDAkRv8Vp4sSJSktLU0NDgyTJ5XLp7Nngw/Xu7m6dP3++3/OodrtdDocjaAKAoSDiIXrq1CmdO3dOmZmZkqT8/Hw1Nzerrq4u0Gffvn3y+/3Ky8uLdDkAEFYhf51va2sLHFVKUmNjow4fPqyUlBSlpKRo/fr1WrBggVwul44fP64nnnhC3/nOd1RYWChJuv7661VUVKTly5dr06ZN6urq0sqVK7Vw4UKuzAMYdkI+Ej148KCmT5+u6dOnS5LKyso0ffp0rVmzRrGxsTpy5Ih+9KMf6brrrtOyZcs0c+ZMvfvuu7Lb7YFtvPrqq5oyZYrmzJmjefPm6c4779R///d/h2+vAGCQ2CzLGno3Xl2G1+uV0+nUhoJkjYrrfT8ZAJjq6Lb05N5mtbS0XPI6DM/OA4ABQhQADBCiAGCAEAUAA4QoABggRAHAACEKAAYIUQAwQIgCgAFCFAAMEKIAYIAQBQADhCgAGCBEAcAAIQoABghRADBAiAKAAUIUAAwQogBggBAFAAOEKAAYIEQBwAAhCgAGCFEAMECIAoABQhQADBCiAGCAEAUAA4QoABggRAHAACEKAAYIUQAwQIgCgIGQQ7S6ulrz58+X2+2WzWbTzp07g9ptNluf0/PPPx/oM2HChF7tGzZsMN4ZABhsIYdoe3u7cnNztXHjxj7bz5w5EzT97ne/k81m04IFC4L6PfPMM0H9Vq1aNbA9AIAoigt1heLiYhUXF/fb7nK5gubffPNNzZ49WxMnTgxanpSU1KsvAAw3ET0n2tTUpP/93//VsmXLerVt2LBBqampmj59up5//nl1d3f3ux2fzyev1xs0AcBQEPKRaCh+//vfKykpSffdd1/Q8l/84heaMWOGUlJS9P7776u8vFxnzpzRCy+80Od2KioqtH79+kiWCgADYrMsyxrwyjabduzYoXvuuafP9ilTpugHP/iBXnrppUtu53e/+50efvhhtbW1yW6392r3+Xzy+XyBea/Xq6ysLG0oSNaoONtAyweAfnV0W3pyb7NaWlrkcDj67RexI9F3331X9fX1ev311y/bNy8vT93d3fr88881efLkXu12u73PcAWAaIvYOdHf/va3mjlzpnJzcy/b9/Dhw4qJiVF6enqkygGAiAj5SLStrU0NDQ2B+cbGRh0+fFgpKSnKzs6W9Lev29u3b9evfvWrXuvX1NSotrZWs2fPVlJSkmpqalRaWqoHHnhAV111lcGuAMDgCzlEDx48qNmzZwfmy8rKJElLlizRli1bJEnbtm2TZVlatGhRr/Xtdru2bdumdevWyefzKScnR6WlpYHtAMBwYnRhKVq8Xq+cTicXlgBEzJVeWOLZeQAwQIgCgAFCFAAMEKIAYIAQBQADhCgAGCBEAcAAIQoABghRADBAiAKAAUIUAAwQogBggBAFAAOEKAAYIEQBwAAhCgAGCFEAMECIAoABQhQADBCiAGCAEAUAA4QoABggRAHAAC
EKAAbiol2AibTvzNBo+7DeBQBD1AVft7R332X72SzLsgahnrDyer1yOp1q+D/1SkpKinY5AEag1tZWfee6yWppaZHD4ei337A+jIuNT1BsfEK0ywAwAl1ptnBOFAAMEKIAYIAQBQADhCgAGCBEAcAAIQoABkIK0YqKCt16661KSkpSenq67rnnHtXX1wf16ejoUElJiVJTUzV27FgtWLBATU1NQX1OnDihu+++W6NHj1Z6err+/d//Xd3d3eZ7AwCDLKQQraqqUklJiQ4cOKA9e/aoq6tLc+fOVXt7e6BPaWmp3nrrLW3fvl1VVVU6ffq07rvvvkB7T0+P7r77bnV2dur999/X73//e23ZskVr1qwJ314BwCAxemLpq6++Unp6uqqqqjRr1iy1tLRo3Lhx2rp1q37yk59Ikj777DNdf/31qqmp0e233663335bP/zhD3X69GllZGRIkjZt2qTVq1frq6++UkLC5W9wvfjEUmNjI08sAYiI1tZW5eTkXPaJJaNzoi0tLZKklJQUSVJdXZ26urpUUFAQ6DNlyhRlZ2erpqZGklRTU6ObbropEKCSVFhYKK/Xq08++aTPz/H5fPJ6vUETAAwFAw5Rv9+vRx99VHfccYduvPFGSZLH41FCQoKSk5OD+mZkZMjj8QT6/H2AXmy/2NaXiooKOZ3OwJSVlTXQsgEgrAYcoiUlJfr444+1bdu2cNbTp/LycrW0tASmkydPRvwzAeBKDOgFJCtXrtSuXbtUXV2ta665JrDc5XKps7NTzc3NQUejTU1NcrlcgT4ffPBB0PYuXr2/2Ocf2e122e32gZQKABEV0pGoZVlauXKlduzYoX379iknJyeofebMmYqPj1dlZWVgWX19vU6cOKH8/HxJUn5+vv7yl7/o7NmzgT579uyRw+HQ1KlTTfYFAAZdSEeiJSUl2rp1q958800lJSUFzmE6nU4lJibK6XRq2bJlKisrU0pKihwOh1atWqX8/HzdfvvtkqS5c+dq6tSp+td//Vc999xz8ng8euqpp1RSUsLRJoBhJ6RbnGw2W5/LN2/erKVLl0r62832jz32mF577TX5fD4VFhbq5ZdfDvqq/sUXX+iRRx7R/v37NWbMGC1ZskQbNmxQXNyVZTq3OAGItCu9xWlYv9meEAUQKYNynygAfNsRogBggBAFAAOEKAAYIEQBwAAhCgAGCFEAMECIAoCBAb2AJNouPh/Q2toa5UoAjFQX8+VyzyMNyxC9uHPTpk2LciUARrrW1lY5nc5+24flY59+v1/19fWaOnWqTp48eclHsjAwXq9XWVlZjG+EML6RFY7xtSxLra2tcrvdionp/8znsDwSjYmJ0dVXXy1Jcjgc/BJGEOMbWYxvZJmO76WOQC/iwhIAGCBEAcDAsA1Ru92utWvX8iLnCGF8I4vxjazBHN9heWEJAIaKYXskCgBDASEKAAYIUQAwQIgCgAFCFAAMDMsQ3bhxoyZMmKBRo0YpLy9PH3zwQbRLGpbWrVsnm80WNE2ZMiXQ3tHRoZKSEqWmpmrs2LFasGCBmpqaoljx0FZdXa358+fL7XbLZrNp586dQe2WZWnNmjXKzMxUYmKiCgoKdOzYsaA+58+f1+LFi+VwOJScnKxly5apra1tEPdi6Lrc+C5durTX73NRUVFQn0iM77AL0ddff11lZWVau3atDh06pNzcXBUWFurs2bPRLm1YuuGGG3TmzJnA9N577wXaSktL9dZbb2n79u2qqqrS6dOndd9990Wx2qGtvb1dubm52rhxY5/tzz33nF588UVt2rRJtbW1GjNmjAoLC9XR0RHos3jxYn3yySfas2ePdu3aperqaq1YsWKwdmFIu9z4SlJRUVHQ7/Nrr70W1B6R8bWGmdtuu80qKSkJzPf09Fhut9uqqKiIYlXD09q1a63c3Nw+25qbm634+Hhr+/btgWWffvqpJcmqqakZpAqHL0nWjh07AvN+v99yuVzW888/H1jW3Nxs2e
1267XXXrMsy7KOHj1qSbI+/PDDQJ+3337bstls1pdffjlotQ8H/zi+lmVZS5YssX784x/3u06kxndYHYl2dnaqrq5OBQUFgWUxMTEqKChQTU1NFCsbvo4dOya3262JEydq8eLFOnHihCSprq5OXV1dQWM9ZcoUZWdnM9YD0NjYKI/HEzSeTqdTeXl5gfGsqalRcnKybrnllkCfgoICxcTEqLa2dtBrHo7279+v9PR0TZ48WY888ojOnTsXaIvU+A6rEP3666/V09OjjIyMoOUZGRnyeDxRqmr4ysvL05YtW7R792698soramxs1Pe+9z21trbK4/EoISFBycnJQesw1gNzccwu9bvr8XiUnp4e1B4XF6eUlBTG/AoUFRXpD3/4gyorK/Xss8+qqqpKxcXF6unpkRS58R2Wr8JDeBQXFwd+njZtmvLy8jR+/Hj96U9/UmJiYhQrA0K3cOHCwM833XSTpk2bpkmTJmn//v2aM2dOxD53WB2JpqWlKTY2ttcV4qamJrlcrihVNXIkJyfruuuuU0NDg1wulzo7O9Xc3BzUh7EemItjdqnfXZfL1esCaXd3t86fP8+YD8DEiROVlpamhoYGSZEb32EVogkJCZo5c6YqKysDy/x+vyorK5Wfnx/FykaGtrY2HT9+XJmZmZo5c6bi4+ODxrq+vl4nTpxgrAcgJydHLpcraDy9Xq9qa2sD45mfn6/m5mbV1dUF+uzbt09+v195eXmDXvNwd+rUKZ07d06ZmZmSIji+A74kFSXbtm2z7Ha7tWXLFuvo0aPWihUrrOTkZMvj8US7tGHnscces/bv3281NjZaf/7zn62CggIrLS3NOnv2rGVZlvWzn/3Mys7Otvbt22cdPHjQys/Pt/Lz86Nc9dDV2tpqffTRR9ZHH31kSbJeeOEF66OPPrK++OILy7Isa8OGDVZycrL15ptvWkeOHLF+/OMfWzk5OdY333wT2EZRUZE1ffp0q7a21nrvvfesa6+91lq0aFG0dmlIudT4tra2Wo8//rhVU1NjNTY2Wnv37rVmzJhhXXvttVZHR0dgG5EY32EXopZlWS+99JKVnZ1tJSQkWLfddpt14MCBaJc0LN1///1WZmamlZCQYF199dXW/fffbzU0NATav/nmG+vnP/+5ddVVV1mjR4+27r33XuvMmTNRrHhoe+eddyxJvaYlS5ZYlvW325yefvppKyMjw7Lb7dacOXOs+vr6oG2cO3fOWrRokTV27FjL4XBYDz74oNXa2hqFvRl6LjW+Fy5csObOnWuNGzfOio+Pt8aPH28tX76818FVJMaX94kCgIFhdU4UAIYaQhQADBCiAGCAEAUAA4QoABggRAHAACEKAAYIUQAwQIgCgAFCFAAMEKIAYOD/AhGOA+EqidnAAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Steps taken: 1662\n", "Final score: 20.0\n", "Penalties: 0\n" ] } ], "source": [ "state, _ = environment.reset()\n", "\n", "score, steps, penalties, reward = 0, 0, 0, 0\n", "\n", "done=False\n", "while not done:\n", " steps += 1\n", " action = pi_behavior(state)\n", " state, reward, done, truncated, info = environment.step(action)\n", " score += reward\n", " if reward == -10:\n", " penalties += 1\n", "\n", " clear_output(wait=True)\n", " pyplot.imshow(environment.render())\n", " pyplot.show()\n", "\n", "print(f\"Steps taken: {steps}\")\n", "print(f\"Final score: {score}\")\n", "print(f\"Penalties: {penalties}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 5 }