PyTorch is one of the most widely used frameworks in deep learning. We can apply this library across data science and machine learning, including reinforcement learning. In one of our articles, we discussed reinforcement learning and the procedure for building reinforcement learning models using TensorFlow in detail. In this article, we will discuss how we can build reinforcement learning models using PyTorch. The major points to be discussed in this article are listed below.
Table of contents
The CartPole problem
Importing libraries
Defining setup
Storing memory
Deep Q-network algorithm
Training the network
Let’s start by understanding the CartPole problem.
The CartPole problem
In this article, we are dealing with the CartPole problem, where we try to make an agent learn to decide whether to push the cart to the left or to the right so that the pole attached to the cart stays upright. We can see an example in the below image.
Any action the agent takes depends on the current state of the environment; after each action, the environment transitions to a new state and returns a reward, and this feedback determines the agent's action in the next state.
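To make this concrete, in CartPole the state the agent observes is a vector of four numbers (cart position, cart velocity, pole angle, and pole angular velocity), and the agent chooses between two discrete actions (push the cart left or push it right). A quick way to inspect both spaces:

import gym

env = gym.make("CartPole-v0")

# the observation is [cart position, cart velocity, pole angle, pole angular velocity]
print(env.observation_space)  # Box(4,)

# the agent picks between two discrete actions: 0 = push left, 1 = push right
print(env.action_space)       # Discrete(2)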
Here, the environment returns a reward of +1 at every timestep. The episode ends when the cart moves more than 2.4 units from the centre or the pole tilts too far from vertical, after which no further reward is given. In this way, only an agent that performs well keeps the episode running longer and collects a larger total reward.
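As a minimal sketch of this reward scheme, the loop below runs a single episode with a randomly acting agent and totals the rewards; since every surviving timestep is worth +1, the episode return is simply the number of steps the pole stayed up.

import gym

env = gym.make("CartPole-v0")
env.reset()

total_reward = 0.0
done = False
while not done:
    # sample a random action: 0 = push left, 1 = push right
    action = env.action_space.sample()
    obs, reward, done, info = env.step(action)
    total_reward += reward  # +1 for each timestep the pole stays up

print("Episode return:", total_reward)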
Using the below lines of code, we can render the CartPole environment.
import gym
import matplotlib.pyplot as plt
from IPython import display as ipythondisplay
from pyvirtualdisplay import Display

# start a virtual display so the environment can render in a headless session
display = Display(visible=0, size=(400, 300))
display.start()

env = gym.make("CartPole-v0")
env.reset()

# render the initial frame
prev_screen = env.render(mode='rgb_array')
plt.imshow(prev_screen)

for i in range(50):
    # take a random action and render the resulting frame
    action = env.action_space.sample()
    obs, reward, done, info = env.step(action)
    screen = env.render(mode='rgb_array')
    plt.imshow(screen)
    ipythondisplay.clear_output(wait=True)
    ipythondisplay.display(plt.gcf())
    # stop when the episode ends (pole fell or cart left the track)
    if done:
        break

ipythondisplay.clear_output(wait=True)
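Note that the virtual display from pyvirtualdisplay is only needed in headless sessions such as hosted notebooks; on a local machine with a screen, the environment can render directly. Once we are done with the environment, calling env.close() releases the rendering resources.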