Adding Klyubin Mazeworld
This commit is contained in:
parent
80760bb686
commit
961e46c347
102
empowerment/maze.py
Normal file
102
empowerment/maze.py
Normal file
@ -0,0 +1,102 @@
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class Maze(object):
|
||||
def __init__(self, height, width):
|
||||
self.height = height
|
||||
self.width = width
|
||||
|
||||
self.actions = OrderedDict()
|
||||
self.actions["N"] = np.array([1, 0]) # UP
|
||||
self.actions["S"] = np.array([-1, 0]) # DOWN
|
||||
self.actions["E"] = np.array([0, 1]) # RIGHT
|
||||
self.actions["W"] = np.array([0, -1]) # LEFT
|
||||
self.actions["_"] = np.array([0, 0]) # STAY
|
||||
|
||||
self.opposite = OrderedDict()
|
||||
self.opposite["N"] = "S"
|
||||
self.opposite["S"] = "N"
|
||||
self.opposite["W"] = "E"
|
||||
self.opposite["E"] = "W"
|
||||
|
||||
self.wall_states = []
|
||||
|
||||
self.vecmod = np.vectorize(lambda x, y : x % y)
|
||||
|
||||
|
||||
|
||||
def act(self, s, a, check_wall = True):
|
||||
""" get updated state after action
|
||||
|
||||
s : state, index of grid position
|
||||
a : action
|
||||
prob : probability of performing action
|
||||
"""
|
||||
state = self.state_to_cell(s)
|
||||
new_state = state + self.actions[a]
|
||||
if self.is_outside(new_state):
|
||||
return s
|
||||
if check_wall and (s, a) in self.wall_states:
|
||||
return s
|
||||
return self.cell_to_state(new_state)
|
||||
|
||||
|
||||
def add_wall(self, cell, direction):
|
||||
state = self.cell_to_state(cell)
|
||||
self.wall_states.append((state, direction))
|
||||
|
||||
resulting_states = self.act(state, direction, check_wall = False)
|
||||
self.wall_states.append((resulting_states, self.opposite[direction]))
|
||||
|
||||
|
||||
def is_outside(self, cell):
|
||||
return cell[0] < 0 or cell[0] >= self.height or cell[1] < 0 or cell[1] >= self.width
|
||||
|
||||
|
||||
def state_to_cell(self, s):
|
||||
return np.array([int(s / self.width), s % self.width])
|
||||
|
||||
def cell_to_state(self, cell):
|
||||
return cell[0] * self.width + cell[1]
|
||||
|
||||
|
||||
def plot(self):
|
||||
G = np.zeros([self.height, self.width])
|
||||
plt.pcolor(G, cmap = 'Greys')
|
||||
plt.colorbar()
|
||||
|
||||
print(self.wall_states)
|
||||
for states in self.wall_states:
|
||||
s, a = states
|
||||
if a == "N":
|
||||
plt.hlines(y = self.state_to_cell(s)[0]+1, xmin=self.state_to_cell(s)[1], xmax=self.state_to_cell(s)[1]+1, color = 'w', linewidth = 2)
|
||||
if a == "E":
|
||||
plt.vlines(x = self.state_to_cell(s)[1]+1, ymin=self.state_to_cell(s)[0], ymax=self.state_to_cell(s)[0]+1, color = 'w', linewidth = 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
height, width = 10, 10
|
||||
m = Maze(height, width)
|
||||
|
||||
for i in range(width-4):
|
||||
m.add_wall([1, i], "N")
|
||||
for i in range(width-3, width-1):
|
||||
m.add_wall([1, i], "N")
|
||||
for i in range(width-5, width-3):
|
||||
m.add_wall([2, i], "E")
|
||||
m.add_wall([3, i], "E")
|
||||
m.add_wall([3,width-4], "N")
|
||||
|
||||
for i in range(2,5):
|
||||
m.add_wall([5, i], "N")
|
||||
m.add_wall([6, 4], "N")
|
||||
m.add_wall([7, 4], "N")
|
||||
m.add_wall([8, 3], "N")
|
||||
m.add_wall([8, 3], "E")
|
||||
m.add_wall([6, 4], "E")
|
||||
m.add_wall([7, 4], "E")
|
||||
|
||||
m.plot()
|
||||
plt.show()
|
Loading…
Reference in New Issue
Block a user