Adding Empowerment functions
This commit is contained in:
parent
961e46c347
commit
69d1528077
23
empowerment/empowerment_functions.py
Normal file
23
empowerment/empowerment_functions.py
Normal file
@ -0,0 +1,23 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def total_states(states, model, actions, states_tot=[]):
|
||||
if type(states) == int:
|
||||
states_tot.extend(list(set([model.act(states, a, check_wall=True) for a in actions])))
|
||||
else:
|
||||
for i in range(len(states)):
|
||||
total_states(int(states[i]), model, actions, states_tot)
|
||||
return list(set(states_tot))
|
||||
|
||||
def compute_empowerment(model, n_steps):
|
||||
states = range(model.height * model.width)
|
||||
|
||||
actions = ["N", "S", "E", "W", "_"]
|
||||
empowerment = []
|
||||
|
||||
for s in states:
|
||||
for i in range(n_steps):
|
||||
s = total_states(s, model, actions,[])
|
||||
empowerment.append(np.log2(len(s)))
|
||||
|
||||
return np.asarray(empowerment)
|
@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
import matplotlib.pyplot as plt
|
||||
from collections import OrderedDict
|
||||
from empowerment_functions import compute_empowerment
|
||||
|
||||
|
||||
class Maze(object):
|
||||
@ -62,12 +63,11 @@ class Maze(object):
|
||||
return cell[0] * self.width + cell[1]
|
||||
|
||||
|
||||
def plot(self):
|
||||
G = np.zeros([self.height, self.width])
|
||||
plt.pcolor(G, cmap = 'Greys')
|
||||
def plot(self, colorMap = None):
|
||||
G = np.zeros([self.height, self.width]) if colorMap is None else colorMap.copy()
|
||||
plt.pcolor(G, cmap = 'gray_r')
|
||||
plt.colorbar()
|
||||
|
||||
print(self.wall_states)
|
||||
for states in self.wall_states:
|
||||
s, a = states
|
||||
if a == "N":
|
||||
@ -76,7 +76,7 @@ class Maze(object):
|
||||
plt.vlines(x = self.state_to_cell(s)[1]+1, ymin=self.state_to_cell(s)[0], ymax=self.state_to_cell(s)[0]+1, color = 'w', linewidth = 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def klyubin_world():
|
||||
height, width = 10, 10
|
||||
m = Maze(height, width)
|
||||
|
||||
@ -98,5 +98,12 @@ if __name__ == "__main__":
|
||||
m.add_wall([6, 4], "E")
|
||||
m.add_wall([7, 4], "E")
|
||||
|
||||
m.plot()
|
||||
return m
|
||||
|
||||
if __name__ == "__main__":
|
||||
maze = klyubin_world()
|
||||
emp = compute_empowerment(maze,n_steps=10).reshape(10,10)
|
||||
print(emp.min(), emp.max())
|
||||
|
||||
maze.plot(colorMap=emp)
|
||||
plt.show()
|
Loading…
Reference in New Issue
Block a user