309 lines
11 KiB
Python
309 lines
11 KiB
Python
|
"""
|
||
|
Classic cart-pole system implemented by Rich Sutton et al.
|
||
|
Copied from http://incompleteideas.net/sutton/book/code/pole.c
|
||
|
permalink: https://perma.cc/C9ZM-652R
|
||
|
"""
|
||
|
import math
|
||
|
from typing import Optional, Union
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
import gym
|
||
|
from gym import logger, spaces
|
||
|
from gym.spaces import Box
|
||
|
from gym.envs.classic_control import utils
|
||
|
from gym.error import DependencyNotInstalled
|
||
|
|
||
|
|
||
|
class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
    """
    ### Description

    This environment corresponds to the version of the cart-pole problem described by Barto, Sutton, and Anderson in
    ["Neuronlike Adaptive Elements That Can Solve Difficult Learning Control Problems"](https://ieeexplore.ieee.org/document/6313077).
    A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track.
    The pendulum is placed upright on the cart and the goal is to balance the pole by applying forces
    in the left and right direction on the cart.

    ### Action Space

    Due to the policy shaping approach, the action is an `ndarray` with shape `(1,)` taking values in `[-1, 1]`.
    The value is scaled by `force_mag`: the cart is pushed to the left if it is lower than 0, to the right if it
    is higher than 0, and no force is applied if it is equal to 0.

    ### Observation Space

    The observation is an `ndarray` with shape `(4,)` with the values corresponding to the following positions and velocities:

    | Num | Observation           | Min                 | Max               |
    |-----|-----------------------|---------------------|-------------------|
    | 0   | Cart Position         | -4.8                | 4.8               |
    | 1   | Cart Velocity         | -Inf                | Inf               |
    | 2   | Pole Angle            | ~ -0.418 rad (-24°) | ~ 0.418 rad (24°) |
    | 3   | Pole Angular Velocity | -Inf                | Inf               |

    **Note:** While the ranges above denote the possible values for the observation space of each element,
    they do not reflect the allowed values of the state space in an unterminated episode. Particularly:
    - The cart x-position (index 0) can take values between `(-4.8, 4.8)`, but the episode terminates
      if the cart leaves the `(-2.4, 2.4)` range.
    - The pole angle can be observed between `(-.418, .418)` radians (or **±24°**), but the episode terminates
      if the pole angle is not in the range `(-.2095, .2095)` (or **±12°**).

    ### Rewards

    Since the goal is to keep the pole upright for as long as possible, a reward of `+1` is allotted for every
    step taken, including the termination step. The threshold for rewards is 475 for v1.

    ### Starting State

    All observations are assigned a uniformly random value in `(-0.05, 0.05)`.

    ### Episode End

    The episode ends if any one of the following occurs:

    1. Termination: Pole Angle is greater than ±12°
    2. Termination: Cart Position is greater than ±2.4 (center of the cart reaches the edge of the display)
    3. Truncation: Episode length is greater than 500 (200 for v0)

    ### Arguments

    ```
    gym.make('CartPole-v1')
    ```

    Apart from `render_mode` (`"human"` or `"rgb_array"`), no additional arguments are currently supported.
    """

    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 50,
    }

    def __init__(self, render_mode: Optional[str] = None):
        """Set up the physical constants, the action/observation spaces and rendering state.

        Args:
            render_mode: ``None``, ``"human"`` or ``"rgb_array"``; selects what ``render`` does.
        """
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = self.masspole + self.masscart
        self.length = 0.5  # actually half the pole's length
        self.polemass_length = self.masspole * self.length
        self.force_mag = 10.0
        self.tau = 0.02  # seconds between state updates
        self.kinematics_integrator = "euler"

        # Angle at which to fail the episode (12 degrees, in radians)
        self.theta_threshold_radians = 12 * 2 * math.pi / 360
        self.x_threshold = 2.4

        # Angle limit set to 2 * theta_threshold_radians so failing observation
        # is still within bounds.
        high = np.array(
            [
                self.x_threshold * 2,
                np.finfo(np.float32).max,
                self.theta_threshold_radians * 2,
                np.finfo(np.float32).max,
            ],
            dtype=np.float32,
        )

        # Continuous action in [-1, 1], later scaled by force_mag in step().
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)

        self.render_mode = render_mode

        self.screen_width = 600
        self.screen_height = 400
        self.screen = None
        self.clock = None
        self.isopen = True
        self.state = None

        # None until the first terminating step; then counts steps taken past it.
        self.steps_beyond_terminated = None

    def step(self, action):
        """Advance the simulation by one time step of ``self.tau`` seconds.

        Args:
            action: array-like of shape ``(1,)`` with a value in ``[-1, 1]``; it is
                multiplied by ``force_mag`` to obtain the horizontal force on the cart.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where ``observation``
            is a float32 array ``[x, x_dot, theta, theta_dot]``, ``truncated`` is always
            ``False`` and ``info`` is an empty dict.

        Raises:
            AssertionError: if the action is outside the action space or ``reset``
                has not been called yet.
        """
        err_msg = f"{action!r} ({type(action)}) invalid"
        assert self.action_space.contains(action), err_msg
        assert self.state is not None, "Call reset before using step method."
        x, x_dot, theta, theta_dot = self.state

        # Changed usage of action due to the policy shaping approach: the continuous
        # action scales the force. Collapse the (1,)-shaped action to a Python float
        # so every derived quantity (and hence the state) stays a scalar. Otherwise
        # the semi-implicit branch would store (1,)-arrays for x/theta, and list
        # actions like [0.5] would fail on `list * float`.
        force = float(np.asarray(action, dtype=np.float64).item()) * self.force_mag
        costheta = math.cos(theta)
        sintheta = math.sin(theta)

        # For the interested reader:
        # https://coneural.org/florian/papers/05_cart_pole.pdf
        temp = (
            force + self.polemass_length * theta_dot**2 * sintheta
        ) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass)
        )
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        if self.kinematics_integrator == "euler":
            x = x + self.tau * x_dot
            x_dot = x_dot + self.tau * xacc
            theta = theta + self.tau * theta_dot
            theta_dot = theta_dot + self.tau * thetaacc
        else:  # semi-implicit euler: update velocities first, then positions
            x_dot = x_dot + self.tau * xacc
            x = x + self.tau * x_dot
            theta_dot = theta_dot + self.tau * thetaacc
            theta = theta + self.tau * theta_dot

        # All four variables are scalars here, for either integrator.
        self.state = (x, x_dot, theta, theta_dot)

        terminated = bool(
            x < -self.x_threshold
            or x > self.x_threshold
            or theta < -self.theta_threshold_radians
            or theta > self.theta_threshold_radians
        )

        if not terminated:
            reward = 1.0
        elif self.steps_beyond_terminated is None:
            # Pole just fell!
            self.steps_beyond_terminated = 0
            reward = 1.0
        else:
            # Stepping after termination: warn once, then keep returning zero reward.
            if self.steps_beyond_terminated == 0:
                logger.warn(
                    "You are calling 'step()' even though this "
                    "environment has already returned terminated = True. You "
                    "should always call 'reset()' once you receive 'terminated = "
                    "True' -- any further steps are undefined behavior."
                )
            self.steps_beyond_terminated += 1
            reward = 0.0

        if self.render_mode == "human":
            self.render()
        return np.array(self.state, dtype=np.float32), reward, terminated, False, {}

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ):
        """Start a new episode with every state variable drawn uniformly at random.

        Args:
            seed: forwarded to ``gym.Env.reset`` to seed ``self.np_random``.
            options: forwarded to ``utils.maybe_parse_reset_bounds`` together with the
                default bounds ``(-0.05, 0.05)``; presumably it may carry custom
                reset bounds (TODO confirm against the helper).

        Returns:
            ``(observation, info)`` with a float32 observation of shape ``(4,)`` and an
            empty info dict.
        """
        super().reset(seed=seed)
        # Note that if you use custom reset bounds, it may lead to out-of-bound
        # state/observations.
        low, high = utils.maybe_parse_reset_bounds(
            options, -0.05, 0.05  # default low, default high
        )
        self.state = self.np_random.uniform(low=low, high=high, size=(4,))
        self.steps_beyond_terminated = None

        if self.render_mode == "human":
            self.render()
        return np.array(self.state, dtype=np.float32), {}

    def render(self):
        """Draw the cart and pole for the current state.

        Returns:
            For ``render_mode="rgb_array"``, the screen pixels transposed to
            ``(height, width, 3)``; ``None`` otherwise, when no render mode is set,
            or before ``reset`` has been called.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        if self.render_mode is None:
            # The env may be constructed directly (without gym.make), in which case
            # it has no spec; don't crash while building the warning text.
            env_id = self.spec.id if self.spec is not None else "CartPole-v1"
            gym.logger.warn(
                "You are calling render method without specifying any render mode. "
                "You can specify the render_mode at initialization, "
                f'e.g. gym.make("{env_id}", render_mode="rgb_array")'
            )
            return

        try:
            import pygame
            from pygame import gfxdraw
        except ImportError as err:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gym[classic_control]`"
            ) from err

        # Lazily create the drawing target: a window for "human", an off-screen
        # surface for "rgb_array".
        if self.screen is None:
            pygame.init()
            if self.render_mode == "human":
                pygame.display.init()
                self.screen = pygame.display.set_mode(
                    (self.screen_width, self.screen_height)
                )
            else:  # mode == "rgb_array"
                self.screen = pygame.Surface((self.screen_width, self.screen_height))
        if self.clock is None:
            self.clock = pygame.time.Clock()

        # Map the track [-x_threshold, x_threshold] onto the full screen width.
        world_width = self.x_threshold * 2
        scale = self.screen_width / world_width
        polewidth = 10.0
        polelen = scale * (2 * self.length)
        cartwidth = 50.0
        cartheight = 30.0

        if self.state is None:
            return None

        x = self.state

        self.surf = pygame.Surface((self.screen_width, self.screen_height))
        self.surf.fill((255, 255, 255))

        # Cart rectangle, centred horizontally on the cart position.
        l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
        axleoffset = cartheight / 4.0
        cartx = x[0] * scale + self.screen_width / 2.0  # MIDDLE OF CART
        carty = 100  # TOP OF CART
        cart_coords = [(l, b), (l, t), (r, t), (r, b)]
        cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords]
        gfxdraw.aapolygon(self.surf, cart_coords, (0, 0, 0))
        gfxdraw.filled_polygon(self.surf, cart_coords, (0, 0, 0))

        # Pole rectangle, rotated by the pole angle around the axle.
        l, r, t, b = (
            -polewidth / 2,
            polewidth / 2,
            polelen - polewidth / 2,
            -polewidth / 2,
        )

        pole_coords = []
        for coord in [(l, b), (l, t), (r, t), (r, b)]:
            coord = pygame.math.Vector2(coord).rotate_rad(-x[2])
            coord = (coord[0] + cartx, coord[1] + carty + axleoffset)
            pole_coords.append(coord)
        gfxdraw.aapolygon(self.surf, pole_coords, (202, 152, 101))
        gfxdraw.filled_polygon(self.surf, pole_coords, (202, 152, 101))

        # Axle.
        gfxdraw.aacircle(
            self.surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )
        gfxdraw.filled_circle(
            self.surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )

        # Track line.
        gfxdraw.hline(self.surf, 0, self.screen_width, carty, (0, 0, 0))

        # Drawing used y-up coordinates; flip vertically for pygame's y-down screen.
        self.surf = pygame.transform.flip(self.surf, False, True)
        self.screen.blit(self.surf, (0, 0))
        if self.render_mode == "human":
            pygame.event.pump()
            self.clock.tick(self.metadata["render_fps"])
            pygame.display.flip()

        elif self.render_mode == "rgb_array":
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
            )

    def close(self):
        """Shut down pygame if a screen was created; further calls are no-ops."""
        if self.screen is not None:
            import pygame

            pygame.display.quit()
            pygame.quit()
            self.isopen = False
            # The old surface/clock are invalid after pygame.quit(); drop them so a
            # later render() re-initialises pygame instead of using a dead surface.
            self.screen = None
            self.clock = None
|