{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SARSA on CartPole-v1 with keras-rl\n",
    "\n",
    "Builds a small dense network, wraps it in a keras-rl `SARSAAgent` with an epsilon-greedy policy, trains it on gym's `CartPole-v1`, and then evaluates the trained agent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import gym\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Activation, Flatten\n",
    "from keras.optimizers import Adam\n",
    "from rl.agents import SARSAAgent\n",
    "from rl.policy import EpsGreedyQPolicy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# load the environment\n",
    "env = gym.make('CartPole-v1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# seed the environment and NumPy's global RNG with the same value\n",
    "seed_val = 456\n",
    "env.seed(seed_val)\n",
    "np.random.seed(seed_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# problem dimensions: length of one observation vector and number of discrete actions\n",
    "states = env.observation_space.shape[0]\n",
    "actions = env.action_space.n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# define the network that the SARSA agent will train\n",
    "def agent(states, actions, hidden_units=(16, 16, 16)):\n",
    "    \"\"\"Build the fully connected network handed to the SARSA agent.\n",
    "\n",
    "    Args:\n",
    "        states: Number of values in one observation (input vector length).\n",
    "        actions: Number of discrete actions; the network emits one linear\n",
    "            output per action.\n",
    "        hidden_units: Width of each hidden ReLU layer, in order. The default\n",
    "            reproduces the original three 16-unit layers.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled keras ``Sequential`` model.\n",
    "    \"\"\"\n",
    "    model = Sequential()\n",
    "    # Input arrives shaped (1, states); Flatten collapses it to a flat vector.\n",
    "    # NOTE(review): the leading 1 presumably matches keras-rl's observation\n",
    "    # window of a single step -- confirm before changing the window length.\n",
    "    model.add(Flatten(input_shape=(1, states)))\n",
    "    # Hidden stack: one Dense + ReLU pair per requested width.\n",
    "    for units in hidden_units:\n",
    "        model.add(Dense(units))\n",
    "        model.add(Activation('relu'))\n",
    "    # Linear head: one unbounded output per action.\n",
    "    model.add(Dense(actions))\n",
    "    model.add(Activation('linear'))\n",
    "    return model\n",
    "\n",
    "model = agent(states, actions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training for 50000 steps ...\n",
      "Interval 1 (0 
steps performed)\n",
      "10000/10000 [==============================] - 95s 10ms/step - reward: 1.0000\n",
      "319 episodes - episode_reward: 30.984 [8.000, 500.000] - loss: 7.445 - mean_squared_error: 552.109 - mean_q: 29.475\n",
      "\n",
      "Interval 2 (10000 steps performed)\n",
      "10000/10000 [==============================] - 93s 9ms/step - reward: 1.0000\n",
      "122 episodes - episode_reward: 82.467 [9.000, 435.000] - loss: 6.987 - mean_squared_error: 831.830 - mean_q: 39.221\n",
      "\n",
      "Interval 3 (20000 steps performed)\n",
      "10000/10000 [==============================] - 89s 9ms/step - reward: 1.0000\n",
      "81 episodes - episode_reward: 122.617 [14.000, 500.000] - loss: 10.447 - mean_squared_error: 1416.096 - mean_q: 52.411\n",
      "\n",
      "Interval 4 (30000 steps performed)\n",
      "10000/10000 [==============================] - 95s 9ms/step - reward: 1.0000\n",
      "75 episodes - episode_reward: 133.933 [15.000, 500.000] - loss: 6.933 - mean_squared_error: 1527.035 - mean_q: 54.172\n",
      "\n",
      "Interval 5 (40000 steps performed)\n",
      "10000/10000 [==============================] - 99s 10ms/step - reward: 1.0000\n",
      "done, took 470.141 seconds\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# epsilon-greedy exploration policy (constructed with the library defaults)\n",
    "policy = EpsGreedyQPolicy()\n",
    "\n",
    "# SARSA agent driven by the network and the exploration policy\n",
    "sarsa = SARSAAgent(model=model, nb_actions=actions, nb_steps_warmup=10, policy=policy)\n",
    "\n",
    "# compile with the Adam optimizer; 'mse' is only an extra logged metric.\n",
    "# The training loss is chosen inside the agent, which is why the logged\n",
    "# `loss` and `mean_squared_error` values differ.\n",
    "sarsa.compile('adam', metrics=['mse'])\n",
    "\n",
    "# train the agent, keeping the returned history instead of echoing its repr\n",
    "nb_train_steps = 50000\n",
    "history = sarsa.fit(env, nb_steps=nb_train_steps, visualize=False, verbose=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing for 100 episodes ...\n",
      "Episode 1: reward: 350.000, steps: 350\n",
      "Episode 2: reward: 383.000, steps: 383\n",
      "Episode 3: reward: 316.000, 
steps: 316\n", "Episode 4: reward: 371.000, steps: 371\n", "Episode 5: reward: 356.000, steps: 356\n", "Episode 6: reward: 399.000, steps: 399\n", "Episode 7: reward: 317.000, steps: 317\n", "Episode 8: reward: 370.000, steps: 370\n", "Episode 9: reward: 364.000, steps: 364\n", "Episode 10: reward: 432.000, steps: 432\n", "Episode 11: reward: 331.000, steps: 331\n", "Episode 12: reward: 387.000, steps: 387\n", "Episode 13: reward: 361.000, steps: 361\n", "Episode 14: reward: 365.000, steps: 365\n", "Episode 15: reward: 350.000, steps: 350\n", "Episode 16: reward: 375.000, steps: 375\n", "Episode 17: reward: 396.000, steps: 396\n", "Episode 18: reward: 372.000, steps: 372\n", "Episode 19: reward: 386.000, steps: 386\n", "Episode 20: reward: 388.000, steps: 388\n", "Episode 21: reward: 326.000, steps: 326\n", "Episode 22: reward: 336.000, steps: 336\n", "Episode 23: reward: 413.000, steps: 413\n", "Episode 24: reward: 387.000, steps: 387\n", "Episode 25: reward: 356.000, steps: 356\n", "Episode 26: reward: 374.000, steps: 374\n", "Episode 27: reward: 354.000, steps: 354\n", "Episode 28: reward: 328.000, steps: 328\n", "Episode 29: reward: 381.000, steps: 381\n", "Episode 30: reward: 308.000, steps: 308\n", "Episode 31: reward: 348.000, steps: 348\n", "Episode 32: reward: 367.000, steps: 367\n", "Episode 33: reward: 343.000, steps: 343\n", "Episode 34: reward: 387.000, steps: 387\n", "Episode 35: reward: 327.000, steps: 327\n", "Episode 36: reward: 378.000, steps: 378\n", "Episode 37: reward: 339.000, steps: 339\n", "Episode 38: reward: 383.000, steps: 383\n", "Episode 39: reward: 367.000, steps: 367\n", "Episode 40: reward: 333.000, steps: 333\n", "Episode 41: reward: 352.000, steps: 352\n", "Episode 42: reward: 340.000, steps: 340\n", "Episode 43: reward: 365.000, steps: 365\n", "Episode 44: reward: 357.000, steps: 357\n", "Episode 45: reward: 363.000, steps: 363\n", "Episode 46: reward: 368.000, steps: 368\n", "Episode 47: reward: 385.000, steps: 385\n", "Episode 
48: reward: 353.000, steps: 353\n", "Episode 49: reward: 342.000, steps: 342\n", "Episode 50: reward: 340.000, steps: 340\n", "Episode 51: reward: 412.000, steps: 412\n", "Episode 52: reward: 366.000, steps: 366\n", "Episode 53: reward: 362.000, steps: 362\n", "Episode 54: reward: 339.000, steps: 339\n", "Episode 55: reward: 306.000, steps: 306\n", "Episode 56: reward: 391.000, steps: 391\n", "Episode 57: reward: 396.000, steps: 396\n", "Episode 58: reward: 371.000, steps: 371\n", "Episode 59: reward: 348.000, steps: 348\n", "Episode 60: reward: 387.000, steps: 387\n", "Episode 61: reward: 392.000, steps: 392\n", "Episode 62: reward: 394.000, steps: 394\n", "Episode 63: reward: 387.000, steps: 387\n", "Episode 64: reward: 388.000, steps: 388\n", "Episode 65: reward: 376.000, steps: 376\n", "Episode 66: reward: 385.000, steps: 385\n", "Episode 67: reward: 342.000, steps: 342\n", "Episode 68: reward: 382.000, steps: 382\n", "Episode 69: reward: 418.000, steps: 418\n", "Episode 70: reward: 351.000, steps: 351\n", "Episode 71: reward: 354.000, steps: 354\n", "Episode 72: reward: 355.000, steps: 355\n", "Episode 73: reward: 355.000, steps: 355\n", "Episode 74: reward: 390.000, steps: 390\n", "Episode 75: reward: 384.000, steps: 384\n", "Episode 76: reward: 344.000, steps: 344\n", "Episode 77: reward: 384.000, steps: 384\n", "Episode 78: reward: 338.000, steps: 338\n", "Episode 79: reward: 375.000, steps: 375\n", "Episode 80: reward: 380.000, steps: 380\n", "Episode 81: reward: 353.000, steps: 353\n", "Episode 82: reward: 381.000, steps: 381\n", "Episode 83: reward: 350.000, steps: 350\n", "Episode 84: reward: 329.000, steps: 329\n", "Episode 85: reward: 319.000, steps: 319\n", "Episode 86: reward: 370.000, steps: 370\n", "Episode 87: reward: 390.000, steps: 390\n", "Episode 88: reward: 382.000, steps: 382\n", "Episode 89: reward: 395.000, steps: 395\n", "Episode 90: reward: 334.000, steps: 334\n", "Episode 91: reward: 392.000, steps: 392\n", "Episode 92: reward: 
393.000, steps: 393\n",
      "Episode 93: reward: 359.000, steps: 359\n",
      "Episode 94: reward: 370.000, steps: 370\n",
      "Episode 95: reward: 363.000, steps: 363\n",
      "Episode 96: reward: 349.000, steps: 349\n",
      "Episode 97: reward: 338.000, steps: 338\n",
      "Episode 98: reward: 387.000, steps: 387\n",
      "Episode 99: reward: 367.000, steps: 367\n",
      "Episode 100: reward: 425.000, steps: 425\n",
      "Average score over 100 test games: 365.67\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the trained agent on fresh episodes.\n",
    "# A single constant feeds both the test run and the report line so the\n",
    "# printed count can never disagree with the number of episodes played.\n",
    "nb_test_episodes = 100\n",
    "scores = sarsa.test(env, nb_episodes=nb_test_episodes, visualize=False)\n",
    "\n",
    "# `scores.history['episode_reward']` holds one total reward per test episode\n",
    "mean_reward = np.mean(scores.history['episode_reward'])\n",
    "print('Average score over {} test games: {}'.format(nb_test_episodes, mean_reward))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}