Adding Empowerment functions
This commit is contained in:
parent
961e46c347
commit
69d1528077
23
empowerment/empowerment_functions.py
Normal file
23
empowerment/empowerment_functions.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def total_states(states, model, actions, states_tot=[]):
|
||||||
|
if type(states) == int:
|
||||||
|
states_tot.extend(list(set([model.act(states, a, check_wall=True) for a in actions])))
|
||||||
|
else:
|
||||||
|
for i in range(len(states)):
|
||||||
|
total_states(int(states[i]), model, actions, states_tot)
|
||||||
|
return list(set(states_tot))
|
||||||
|
|
||||||
|
def compute_empowerment(model, n_steps):
|
||||||
|
states = range(model.height * model.width)
|
||||||
|
|
||||||
|
actions = ["N", "S", "E", "W", "_"]
|
||||||
|
empowerment = []
|
||||||
|
|
||||||
|
for s in states:
|
||||||
|
for i in range(n_steps):
|
||||||
|
s = total_states(s, model, actions,[])
|
||||||
|
empowerment.append(np.log2(len(s)))
|
||||||
|
|
||||||
|
return np.asarray(empowerment)
|
@ -1,6 +1,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import OrderedDict
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
from collections import OrderedDict
|
||||||
|
from empowerment_functions import compute_empowerment
|
||||||
|
|
||||||
|
|
||||||
class Maze(object):
|
class Maze(object):
|
||||||
@ -62,12 +63,11 @@ class Maze(object):
|
|||||||
return cell[0] * self.width + cell[1]
|
return cell[0] * self.width + cell[1]
|
||||||
|
|
||||||
|
|
||||||
def plot(self):
|
def plot(self, colorMap = None):
|
||||||
G = np.zeros([self.height, self.width])
|
G = np.zeros([self.height, self.width]) if colorMap is None else colorMap.copy()
|
||||||
plt.pcolor(G, cmap = 'Greys')
|
plt.pcolor(G, cmap = 'gray_r')
|
||||||
plt.colorbar()
|
plt.colorbar()
|
||||||
|
|
||||||
print(self.wall_states)
|
|
||||||
for states in self.wall_states:
|
for states in self.wall_states:
|
||||||
s, a = states
|
s, a = states
|
||||||
if a == "N":
|
if a == "N":
|
||||||
@ -76,7 +76,7 @@ class Maze(object):
|
|||||||
plt.vlines(x = self.state_to_cell(s)[1]+1, ymin=self.state_to_cell(s)[0], ymax=self.state_to_cell(s)[0]+1, color = 'w', linewidth = 2)
|
plt.vlines(x = self.state_to_cell(s)[1]+1, ymin=self.state_to_cell(s)[0], ymax=self.state_to_cell(s)[0]+1, color = 'w', linewidth = 2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def klyubin_world():
|
||||||
height, width = 10, 10
|
height, width = 10, 10
|
||||||
m = Maze(height, width)
|
m = Maze(height, width)
|
||||||
|
|
||||||
@ -98,5 +98,12 @@ if __name__ == "__main__":
|
|||||||
m.add_wall([6, 4], "E")
|
m.add_wall([6, 4], "E")
|
||||||
m.add_wall([7, 4], "E")
|
m.add_wall([7, 4], "E")
|
||||||
|
|
||||||
m.plot()
|
return m
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
maze = klyubin_world()
|
||||||
|
emp = compute_empowerment(maze,n_steps=10).reshape(10,10)
|
||||||
|
print(emp.min(), emp.max())
|
||||||
|
|
||||||
|
maze.plot(colorMap=emp)
|
||||||
plt.show()
|
plt.show()
|
Loading…
Reference in New Issue
Block a user