Curiosity/empowerment/empowerment_functions.py
2023-02-23 16:49:55 +01:00

23 lines
689 B
Python

import numpy as np
def total_states(states, model, actions, states_tot=None):
    """Return the distinct states reachable in one step from ``states``.

    Parameters
    ----------
    states : int or iterable of int-convertible values
        A single state, or a collection of states to expand. A plain ``int``
        or NumPy integer scalar is treated as one state; anything else is
        iterated and each element converted with ``int()``.
    model : object
        Must provide ``act(state, action, check_wall=True)``. Presumably it
        returns the successor state for ``action`` — confirm against the
        model class.
    actions : iterable
        Actions to try from every state.
    states_tot : list, optional
        Accumulator that the successors are appended to in place. If it is
        omitted, a fresh list is created for this call.

    Returns
    -------
    list
        The de-duplicated contents of the accumulator. This includes any
        entries the caller pre-filled. The order is arbitrary because it
        comes from a set.
    """
    if states_tot is None:
        # Fresh accumulator per call: a ``[]`` default would be shared by
        # every call and leak states from earlier invocations.
        states_tot = []
    if isinstance(states, (int, np.integer)):
        state = int(states)
        successors = {model.act(state, a, check_wall=True) for a in actions}
        states_tot.extend(successors)
    else:
        # Expand each state; the recursive calls all share one accumulator.
        for state in states:
            total_states(int(state), model, actions, states_tot)
    return list(set(states_tot))
def compute_empowerment(model, n_steps):
    """Compute the ``n_steps``-horizon empowerment of every state of ``model``.

    Each state index in ``range(model.height * model.width)`` is the start of
    a set of reachable states. That set is expanded ``n_steps`` times with
    ``total_states`` using the actions ``["N", "S", "E", "W", "_"]``. The
    empowerment of the start state is ``log2`` of the final set's size.

    Parameters
    ----------
    model : object
        Must expose integer ``height`` and ``width`` and the
        ``act(state, action, check_wall=True)`` method used by
        ``total_states``.
    n_steps : int
        Number of expansion steps; must be non-negative. With ``0``, only the
        start state itself is counted, giving an empowerment of ``0.0``.

    Returns
    -------
    numpy.ndarray
        One float per state index, in index order.

    Raises
    ------
    ValueError
        If ``n_steps`` is negative.
    """
    if n_steps < 0:
        raise ValueError(f"n_steps must be non-negative, got {n_steps}")
    actions = ["N", "S", "E", "W", "_"]
    empowerment = []
    for start in range(model.height * model.width):
        # Start from a one-element collection so a zero-step horizon yields
        # log2(1) == 0 instead of calling len() on a bare int.
        reachable = [start]
        for _ in range(n_steps):
            reachable = total_states(reachable, model, actions, [])
        empowerment.append(np.log2(len(reachable)))
    return np.asarray(empowerment)