Adding Empowerment functions

This commit is contained in:
ved1 2023-02-23 16:49:55 +01:00
parent 961e46c347
commit 69d1528077
2 changed files with 37 additions and 7 deletions

View File

@ -0,0 +1,23 @@
import numpy as np
def total_states(states, model, actions, states_tot=[]):
if type(states) == int:
states_tot.extend(list(set([model.act(states, a, check_wall=True) for a in actions])))
else:
for i in range(len(states)):
total_states(int(states[i]), model, actions, states_tot)
return list(set(states_tot))
def compute_empowerment(model, n_steps):
states = range(model.height * model.width)
actions = ["N", "S", "E", "W", "_"]
empowerment = []
for s in states:
for i in range(n_steps):
s = total_states(s, model, actions,[])
empowerment.append(np.log2(len(s)))
return np.asarray(empowerment)

View File

@ -1,6 +1,7 @@
import numpy as np import numpy as np
from collections import OrderedDict
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from collections import OrderedDict
from empowerment_functions import compute_empowerment
class Maze(object): class Maze(object):
@ -62,12 +63,11 @@ class Maze(object):
return cell[0] * self.width + cell[1] return cell[0] * self.width + cell[1]
def plot(self): def plot(self, colorMap = None):
G = np.zeros([self.height, self.width]) G = np.zeros([self.height, self.width]) if colorMap is None else colorMap.copy()
plt.pcolor(G, cmap = 'Greys') plt.pcolor(G, cmap = 'gray_r')
plt.colorbar() plt.colorbar()
print(self.wall_states)
for states in self.wall_states: for states in self.wall_states:
s, a = states s, a = states
if a == "N": if a == "N":
@ -76,7 +76,7 @@ class Maze(object):
plt.vlines(x = self.state_to_cell(s)[1]+1, ymin=self.state_to_cell(s)[0], ymax=self.state_to_cell(s)[0]+1, color = 'w', linewidth = 2) plt.vlines(x = self.state_to_cell(s)[1]+1, ymin=self.state_to_cell(s)[0], ymax=self.state_to_cell(s)[0]+1, color = 'w', linewidth = 2)
if __name__ == "__main__": def klyubin_world():
height, width = 10, 10 height, width = 10, 10
m = Maze(height, width) m = Maze(height, width)
@ -98,5 +98,12 @@ if __name__ == "__main__":
m.add_wall([6, 4], "E") m.add_wall([6, 4], "E")
m.add_wall([7, 4], "E") m.add_wall([7, 4], "E")
m.plot() return m
if __name__ == "__main__":
maze = klyubin_world()
emp = compute_empowerment(maze,n_steps=10).reshape(10,10)
print(emp.min(), emp.max())
maze.plot(colorMap=emp)
plt.show() plt.show()