From 69d1528077337e320c90e4eb4f927123a56cdc0e Mon Sep 17 00:00:00 2001 From: ved1 Date: Thu, 23 Feb 2023 16:49:55 +0100 Subject: [PATCH] Adding Empowerment functions --- empowerment/empowerment_functions.py | 23 +++++++++++++++++++++++ empowerment/maze.py | 21 ++++++++++++++------- 2 files changed, 37 insertions(+), 7 deletions(-) create mode 100644 empowerment/empowerment_functions.py diff --git a/empowerment/empowerment_functions.py b/empowerment/empowerment_functions.py new file mode 100644 index 0000000..0bdf773 --- /dev/null +++ b/empowerment/empowerment_functions.py @@ -0,0 +1,23 @@ +import numpy as np + + +def total_states(states, model, actions, states_tot=[]): + if type(states) == int: + states_tot.extend(list(set([model.act(states, a, check_wall=True) for a in actions]))) + else: + for i in range(len(states)): + total_states(int(states[i]), model, actions, states_tot) + return list(set(states_tot)) + +def compute_empowerment(model, n_steps): + states = range(model.height * model.width) + + actions = ["N", "S", "E", "W", "_"] + empowerment = [] + + for s in states: + for i in range(n_steps): + s = total_states(s, model, actions,[]) + empowerment.append(np.log2(len(s))) + + return np.asarray(empowerment) \ No newline at end of file diff --git a/empowerment/maze.py b/empowerment/maze.py index 14af149..37410f0 100644 --- a/empowerment/maze.py +++ b/empowerment/maze.py @@ -1,6 +1,7 @@ import numpy as np -from collections import OrderedDict import matplotlib.pyplot as plt +from collections import OrderedDict +from empowerment_functions import compute_empowerment class Maze(object): @@ -62,12 +63,11 @@ class Maze(object): return cell[0] * self.width + cell[1] - def plot(self): - G = np.zeros([self.height, self.width]) - plt.pcolor(G, cmap = 'Greys') + def plot(self, colorMap = None): + G = np.zeros([self.height, self.width]) if colorMap is None else colorMap.copy() + plt.pcolor(G, cmap = 'gray_r') plt.colorbar() - print(self.wall_states) for states in self.wall_states: s, a = states if a == "N": @@ -76,7 +76,7 @@ class Maze(object): plt.vlines(x = self.state_to_cell(s)[1]+1, ymin=self.state_to_cell(s)[0], ymax=self.state_to_cell(s)[0]+1, color = 'w', linewidth = 2) -if __name__ == "__main__": +def klyubin_world(): height, width = 10, 10 m = Maze(height, width) @@ -98,5 +98,12 @@ if __name__ == "__main__": m.add_wall([6, 4], "E") m.add_wall([7, 4], "E") - m.plot() + return m + +if __name__ == "__main__": + maze = klyubin_world() + emp = compute_empowerment(maze,n_steps=10).reshape(10,10) + print(emp.min(), emp.max()) + + maze.plot(colorMap=emp) plt.show() \ No newline at end of file