Curiosity/empowerment/maze.py

109 lines
3.2 KiB
Python
Raw Normal View History

2023-02-23 11:39:03 +00:00
import numpy as np
import matplotlib.pyplot as plt
2023-02-23 15:49:55 +00:00
from collections import OrderedDict
from empowerment_functions import compute_empowerment
2023-02-23 11:39:03 +00:00
class Maze(object):
    """Rectangular grid world with optional walls between neighbouring cells.

    A position is addressed either as a ``[row, col]`` pair (a "cell") or as
    the flat index ``row * width + col`` (a "state"). ``plot`` draws row 0 at
    the bottom, so action ``"N"`` (row + 1) moves up on screen.

    Attributes:
        height, width: grid dimensions in cells.
        actions: ordered map from action label to its ``[d_row, d_col]`` step.
        opposite: map from each move action to the move that undoes it.
        wall_states: list of ``(state, action)`` pairs whose move is blocked.
        vecmod: vectorised modulo helper (not used inside this class).
    """

    def __init__(self, height, width):
        self.height = height
        self.width = width
        self.actions = OrderedDict()
        self.actions["N"] = np.array([1, 0])  # UP
        self.actions["S"] = np.array([-1, 0])  # DOWN
        self.actions["E"] = np.array([0, 1])  # RIGHT
        self.actions["W"] = np.array([0, -1])  # LEFT
        self.actions["_"] = np.array([0, 0])  # STAY
        self.opposite = OrderedDict()
        self.opposite["N"] = "S"
        self.opposite["S"] = "N"
        self.opposite["W"] = "E"
        self.opposite["E"] = "W"
        self.wall_states = []
        self.vecmod = np.vectorize(lambda x, y: x % y)

    def act(self, s, a, check_wall=True):
        """Return the state reached by taking action ``a`` from state ``s``.

        Args:
            s: flat state index.
            a: action label, a key of ``self.actions``.
            check_wall: when False, walls are ignored (``add_wall`` uses this
                to locate the neighbouring state).

        Returns:
            The new state index, or ``s`` unchanged when the move would leave
            the grid or is blocked by a wall.
        """
        target = self.state_to_cell(s) + self.actions[a]
        if self.is_outside(target):
            return s
        if check_wall and (s, a) in self.wall_states:
            return s
        return self.cell_to_state(target)

    def add_wall(self, cell, direction):
        """Block the move ``direction`` out of ``cell`` and the matching reverse move.

        Args:
            cell: ``[row, col]`` of the cell the wall is attached to.
            direction: one of ``"N"``, ``"S"``, ``"E"``, ``"W"``.

        Raises:
            KeyError: if ``direction`` has no opposite (e.g. ``"_"``).
        """
        reverse = self.opposite[direction]
        state = self.cell_to_state(cell)
        self.wall_states.append((state, direction))
        neighbour = self.act(state, direction, check_wall=False)
        # On the grid border there is no neighbour: act() returns the same
        # state, and adding the reverse wall there would wrongly block the
        # opposite move out of this very cell.
        if neighbour != state:
            self.wall_states.append((neighbour, reverse))

    def is_outside(self, cell):
        """Return True when ``cell`` lies outside the ``height x width`` grid."""
        return cell[0] < 0 or cell[0] >= self.height or cell[1] < 0 or cell[1] >= self.width

    def state_to_cell(self, s):
        """Convert a flat state index into a ``[row, col]`` numpy array."""
        # Integer divmod avoids the float round-trip of int(s / width).
        return np.array(divmod(s, self.width))

    def cell_to_state(self, cell):
        """Convert a ``[row, col]`` cell into its flat state index."""
        return cell[0] * self.width + cell[1]

    def plot(self, colorMap=None):
        """Draw the grid, optionally shaded by ``colorMap``, with walls in white.

        Args:
            colorMap: optional per-cell values, shaped ``(height, width)`` or
                flat with ``height * width`` entries; defaults to all zeros.
                The caller's data is copied, never modified.
        """
        if colorMap is None:
            G = np.zeros([self.height, self.width])
        else:
            G = np.array(colorMap)  # np.array copies by default
            if G.ndim == 1:
                G = G.reshape(self.height, self.width)
        plt.pcolor(G, cmap='gray_r')
        plt.colorbar()
        # add_wall stores each interior wall from both sides, so drawing only
        # the N and E entries draws every interior wall exactly once.
        for s, a in self.wall_states:
            row, col = self.state_to_cell(s)
            if a == "N":
                plt.hlines(y=row + 1, xmin=col, xmax=col + 1, color='w', linewidth=2)
            elif a == "E":
                plt.vlines(x=col + 1, ymin=row, ymax=row + 1, color='w', linewidth=2)
2023-02-23 15:49:55 +00:00
def klyubin_world():
    """Build and return the fixed 10x10 walled maze used by the demo.

    The layout is presumably the maze from Klyubin et al.'s empowerment
    experiments (inferred from the name only - confirm). Walls are listed
    declaratively and added in a fixed order.
    """
    rows = cols = 10
    maze = Maze(rows, cols)

    walls = [([1, c], "N") for c in range(cols - 4)]
    walls += [([1, c], "N") for c in range(cols - 3, cols - 1)]
    for c in range(cols - 5, cols - 3):
        walls += [([2, c], "E"), ([3, c], "E")]
    walls.append(([3, cols - 4], "N"))
    walls += [([5, c], "N") for c in range(2, 5)]
    walls += [
        ([6, 4], "N"),
        ([7, 4], "N"),
        ([8, 3], "N"),
        ([8, 3], "E"),
        ([6, 4], "E"),
        ([7, 4], "E"),
    ]

    for cell, direction in walls:
        maze.add_wall(cell, direction)
    return maze
if __name__ == "__main__":
    # Demo: compute 10-step empowerment for every cell and show it as a heat map.
    maze = klyubin_world()
    # Reshape using the maze's own size rather than a duplicated hard-coded 10x10.
    emp = compute_empowerment(maze, n_steps=10).reshape(maze.height, maze.width)
    print(emp.min(), emp.max())
    maze.plot(colorMap=emp)
    plt.show()