import numpy as np
from collections import OrderedDict
import matplotlib.pyplot as plt


class Maze(object):
    """Rectangular grid world whose cells can be separated by walls.

    Cells are addressed either as ``(row, col)`` pairs ("cells") or as flat
    integer indices ``row * width + col`` ("states").  Action ``"N"`` moves
    to ``row + 1``, ``"S"`` to ``row - 1``, ``"E"`` to ``col + 1``, ``"W"``
    to ``col - 1`` and ``"_"`` stays in place.

    Walls are stored in ``wall_states`` as ``(state, action)`` pairs meaning
    "moving from ``state`` in direction ``action`` is blocked".  Interior
    walls are recorded from both adjacent cells so they block both ways.
    """

    def __init__(self, height, width):
        self.height = height
        self.width = width

        # Displacement (d_row, d_col) applied by each action.
        self.actions = OrderedDict()
        self.actions["N"] = np.array([1, 0])  # UP
        self.actions["S"] = np.array([-1, 0])  # DOWN
        self.actions["E"] = np.array([0, 1])  # RIGHT
        self.actions["W"] = np.array([0, -1])  # LEFT
        self.actions["_"] = np.array([0, 0])  # STAY

        # Reverse of each movement action, used to mirror a wall onto the
        # neighbouring cell.  "_" has no entry: a wall cannot block staying.
        self.opposite = OrderedDict()
        self.opposite["N"] = "S"
        self.opposite["S"] = "N"
        self.opposite["W"] = "E"
        self.opposite["E"] = "W"

        # Blocked (state, action) pairs.
        self.wall_states = []
        # Element-wise modulo helper; not referenced elsewhere in this module.
        self.vecmod = np.vectorize(lambda x, y: x % y)

    def act(self, s, a, check_wall=True):
        """Return the state reached by taking action ``a`` from state ``s``.

        Parameters
        ----------
        s : int
            Current state (flat cell index).
        a : str
            Action key in ``self.actions`` ("N", "S", "E", "W" or "_").
        check_wall : bool, optional
            If True (default), a move blocked by a wall leaves the agent in
            ``s``.  Pass False to ignore walls.

        Returns
        -------
        int
            The new state, or ``s`` if the move would leave the grid or is
            blocked by a wall.

        Raises
        ------
        KeyError
            If ``a`` is not a known action.
        """
        new_cell = self.state_to_cell(s) + self.actions[a]
        if self.is_outside(new_cell):
            return s
        if check_wall and (s, a) in self.wall_states:
            return s
        return self.cell_to_state(new_cell)

    def add_wall(self, cell, direction):
        """Block movement out of ``cell`` in ``direction``.

        The wall is also recorded from the neighbouring cell (with the
        opposite direction) so it blocks travel both ways.  A wall on the
        grid border has no neighbour, so only the single entry is recorded.

        Parameters
        ----------
        cell : sequence of int
            ``(row, col)`` of the cell.
        direction : str
            One of "N", "S", "E", "W".

        Raises
        ------
        KeyError
            If ``direction`` is not a movement action.  Raised before any
            entry is added, so ``wall_states`` is left unchanged.
        """
        if direction not in self.opposite:
            raise KeyError(direction)
        state = self.cell_to_state(cell)
        self.wall_states.append((state, direction))
        neighbour = self.state_to_cell(state) + self.actions[direction]
        # Mirroring a border wall would wrongly block the opposite side of
        # ``cell`` itself, since a move off the grid "lands" back on ``cell``.
        if not self.is_outside(neighbour):
            self.wall_states.append(
                (self.cell_to_state(neighbour), self.opposite[direction])
            )

    def is_outside(self, cell):
        """Return True if ``cell`` = ``(row, col)`` lies outside the grid."""
        return (cell[0] < 0 or cell[0] >= self.height
                or cell[1] < 0 or cell[1] >= self.width)

    def state_to_cell(self, s):
        """Convert flat state ``s`` to ``np.array([row, col])``."""
        # divmod is exact for arbitrarily large ints, unlike int(s / width).
        row, col = divmod(s, self.width)
        return np.array([row, col])

    def cell_to_state(self, cell):
        """Convert ``cell`` = ``(row, col)`` to its flat state index."""
        return cell[0] * self.width + cell[1]

    def plot(self):
        """Draw the grid and its walls with matplotlib (does not call show).

        ``wall_states`` is printed to stdout first.  Each wall entry is drawn
        as the matching edge of its cell; interior walls are stored from both
        sides, so their two entries draw the same line.
        """
        grid = np.zeros([self.height, self.width])
        # All-zero data normalises to 0.  'Greys' maps 0 to white, which hid
        # the white walls.  'Greys_r' maps 0 to black so the walls show.
        plt.pcolor(grid, cmap='Greys_r')
        plt.colorbar()
        print(self.wall_states)
        wall_style = dict(color='w', linewidth=2)
        for s, a in self.wall_states:
            # pcolor places row 0 at the bottom, so "N" is the top edge.
            row, col = self.state_to_cell(s)
            if a == "N":
                plt.hlines(y=row + 1, xmin=col, xmax=col + 1, **wall_style)
            elif a == "S":
                plt.hlines(y=row, xmin=col, xmax=col + 1, **wall_style)
            elif a == "E":
                plt.vlines(x=col + 1, ymin=row, ymax=row + 1, **wall_style)
            elif a == "W":
                plt.vlines(x=col, ymin=row, ymax=row + 1, **wall_style)


if __name__ == "__main__":
    height, width = 10, 10
    m = Maze(height, width)
    for i in range(width - 4):
        m.add_wall([1, i], "N")
    for i in range(width - 3, width - 1):
        m.add_wall([1, i], "N")
    for i in range(width - 5, width - 3):
        m.add_wall([2, i], "E")
        m.add_wall([3, i], "E")
    m.add_wall([3, width - 4], "N")
    for i in range(2, 5):
        m.add_wall([5, i], "N")
    m.add_wall([6, 4], "N")
    m.add_wall([7, 4], "N")
    m.add_wall([8, 3], "N")
    m.add_wall([8, 3], "E")
    m.add_wall([6, 4], "E")
    m.add_wall([7, 4], "E")
    m.plot()
    plt.show()