"""Rectangular grid maze with walls between neighbouring cells."""

from collections import OrderedDict

import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError:  # plotting is optional; the grid logic does not need it
    plt = None


class Maze(object):
    """Grid world of ``height`` x ``width`` cells with optional walls.

    A *cell* is a ``[row, col]`` pair and a *state* is the flat index
    ``row * width + col``.  Walls are stored in ``wall_states`` as
    ``(state, action)`` pairs meaning "``action`` is blocked in ``state``".
    """

    def __init__(self, height, width):
        """Create an empty maze (no walls) with the given grid size."""
        self.height = height
        self.width = width

        # Action name -> [row, col] displacement.  "N" increases the row
        # index, which is the upward direction in plot().
        self.actions = OrderedDict()
        self.actions["N"] = np.array([1, 0])  # UP
        self.actions["S"] = np.array([-1, 0])  # DOWN
        self.actions["E"] = np.array([0, 1])  # RIGHT
        self.actions["W"] = np.array([0, -1])  # LEFT
        self.actions["_"] = np.array([0, 0])  # STAY

        # Direction that undoes each move; "_" has no opposite.
        self.opposite = OrderedDict()
        self.opposite["N"] = "S"
        self.opposite["S"] = "N"
        self.opposite["W"] = "E"
        self.opposite["E"] = "W"

        # List of (state, action) pairs that are blocked by a wall.
        self.wall_states = []

        # Element-wise modulo helper; not used by the methods of this class
        # but kept as a public attribute.
        self.vecmod = np.vectorize(lambda x, y: x % y)

    def act(self, s, a, check_wall=True):
        """Return the state reached by taking action ``a`` in state ``s``.

        s : state, index of grid position
        a : action, one of the keys of ``self.actions``
        check_wall : when True (default), a wall registered for ``(s, a)``
            blocks the move

        The agent stays in ``s`` if the move would leave the grid or is
        blocked by a wall.  Raises ``KeyError`` for an unknown action.
        """
        new_cell = self.state_to_cell(s) + self.actions[a]
        if self.is_outside(new_cell):
            return s
        if check_wall and (s, a) in self.wall_states:
            return s
        return self.cell_to_state(new_cell)

    def add_wall(self, cell, direction):
        """Put a wall on the ``direction`` side of ``cell``.

        cell : ``[row, col]`` position of the cell
        direction : one of "N", "S", "E", "W"

        The wall is registered from both sides: moving ``direction`` out of
        ``cell`` and moving the opposite way out of the neighbouring cell.
        Raises ``KeyError`` (before anything is recorded) if ``direction``
        has no opposite.
        """
        # Look the opposite up first so an invalid direction cannot leave a
        # half-registered wall behind.
        opposite = self.opposite[direction]

        state = self.cell_to_state(cell)
        self.wall_states.append((state, direction))

        neighbour = self.act(state, direction, check_wall=False)
        # At the grid border act() returns the cell itself; registering the
        # opposite direction there would wrongly block the move back into
        # the grid, so only record the far side for a real neighbour.
        if neighbour != state:
            self.wall_states.append((neighbour, opposite))

    def is_outside(self, cell):
        """Return True if ``cell`` (``[row, col]``) lies outside the grid."""
        return not (0 <= cell[0] < self.height and 0 <= cell[1] < self.width)

    def state_to_cell(self, s):
        """Convert a flat state index into a ``[row, col]`` numpy array."""
        # divmod keeps the arithmetic in integers; float division loses
        # precision for large indices.
        return np.array(divmod(s, self.width))

    def cell_to_state(self, cell):
        """Convert a ``[row, col]`` cell into its flat state index."""
        return cell[0] * self.width + cell[1]

    def plot(self):
        """Draw the grid and its walls on the current matplotlib axes.

        The background is an all-zero ``pcolor`` grid with a colorbar; each
        wall is drawn as a white line segment.  Only "N" and "E" entries are
        drawn, since add_wall() records every interior wall from both sides.
        Also prints ``self.wall_states``.  Raises ``ImportError`` if
        matplotlib is not available.
        """
        if plt is None:
            raise ImportError("matplotlib is required for Maze.plot()")

        G = np.zeros([self.height, self.width])
        plt.pcolor(G, cmap='Greys')
        plt.colorbar()

        print(self.wall_states)
        for s, a in self.wall_states:
            row, col = self.state_to_cell(s)
            if a == "N":
                # Horizontal segment along the upper edge of the cell.
                plt.hlines(y=row + 1, xmin=col, xmax=col + 1,
                           color='w', linewidth=2)
            elif a == "E":
                # Vertical segment along the right edge of the cell.
                plt.vlines(x=col + 1, ymin=row, ymax=row + 1,
                           color='w', linewidth=2)


def main():
    """Build a 10x10 demo maze and show it."""
    height, width = 10, 10
    m = Maze(height, width)

    for i in range(width - 4):
        m.add_wall([1, i], "N")
    for i in range(width - 3, width - 1):
        m.add_wall([1, i], "N")
    for i in range(width - 5, width - 3):
        m.add_wall([2, i], "E")
        m.add_wall([3, i], "E")
    m.add_wall([3, width - 4], "N")

    for i in range(2, 5):
        m.add_wall([5, i], "N")
    m.add_wall([6, 4], "N")
    m.add_wall([7, 4], "N")
    m.add_wall([8, 3], "N")
    m.add_wall([8, 3], "E")
    m.add_wall([6, 4], "E")
    m.add_wall([7, 4], "E")

    m.plot()
    plt.show()


if __name__ == "__main__":
    main()