# File: ActiveBOToytask/PolicyModel/GaussianModel.py (63 lines, 1.9 KiB, Python)
# Blame timestamp for the lines below: 2023-02-02 17:54:09 +00:00
import numpy as np
import matplotlib.pyplot as plt
class GaussianPolicy:
    """Policy whose trajectory is a weighted sum of Gaussian basis functions.

    ``nr_weights`` Gaussians are centred on ``np.linspace(0, nr_steps, nr_weights)``
    and evaluated at the integer time steps ``0 .. nr_steps - 1``. The weights
    are the policy parameters; ``policy_rollout`` turns them into a trajectory
    of shape ``(nr_steps, 1)``.

    Args:
        nr_weights: Number of Gaussian basis functions (one weight each).
        nr_steps: Number of time steps in a rolled-out trajectory.
        seed: Seed for the random generator used by ``random_policy``.
        lowerb: Lower bound for randomly drawn weights.
        upperb: Upper bound (exclusive) for randomly drawn weights.
    """

    def __init__(self, nr_weights, nr_steps, seed=None, lowerb=-1.0, upperb=1.0):
        self.nr_weights = nr_weights
        self.nr_steps = nr_steps
        self.weights = None
        self.trajectory = None
        # Centres of the basis functions, spread evenly over [0, nr_steps].
        self.mean = np.linspace(0, self.nr_steps, self.nr_weights)
        if nr_weights > 1:
            # mean[0] == 0, so mean[1] is the spacing between centres. Pick sigma
            # so that the full width at half maximum equals that spacing:
            # FWHM = 2 * sqrt(2 * ln 2) * sigma.
            self.std = self.mean[1] / (2 * np.sqrt(2 * np.log(2)))
        else:
            self.std = self.nr_steps / 2
        self.rng = np.random.default_rng(seed=seed)
        self.low = lowerb
        self.upper = upperb
        self.reset()

    def reset(self):
        """Zero the weights (shape ``(nr_weights, 1)``) and the trajectory (``(nr_steps, 1)``)."""
        self.weights = np.zeros((self.nr_weights, 1))
        self.trajectory = np.zeros((self.nr_steps, 1))

    def random_policy(self):
        """Draw new weights uniformly from ``[low, upper)`` using the seeded generator.

        Note: this stores a 1-D array of shape ``(nr_weights,)``, unlike ``reset``,
        which stores a column of shape ``(nr_weights, 1)``.
        """
        self.weights = self.rng.uniform(self.low, self.upper, self.nr_weights)

    def _basis(self, t, center):
        """Evaluate the unnormalised Gaussian centred at ``center`` at time(s) ``t``."""
        return np.exp(-0.5 * (t - center) ** 2 / self.std ** 2)

    def policy_rollout(self):
        """Compute the trajectory produced by the current weights.

        Returns:
            Array of shape ``(nr_steps, 1)`` whose entry ``i`` is
            ``sum_j weights[j] * exp(-0.5 * (i - mean[j])**2 / std**2)``.
            The same array is stored in ``self.trajectory``.
        """
        # Column of time steps, so each basis evaluation has shape (nr_steps, 1).
        steps = np.arange(self.nr_steps).reshape(-1, 1)
        trajectory = np.zeros((self.nr_steps, 1))
        # Vectorised over time, but still accumulating weight by weight.
        # weights[j] is a scalar for 1-D weights (random_policy) or a length-1
        # array for column weights (reset); both broadcast against (nr_steps, 1).
        for j in range(self.nr_weights):
            trajectory += self._basis(steps, self.mean[j]) * self.weights[j]
        self.trajectory = trajectory
        return self.trajectory

    def plot_policy(self, finished=np.nan, show_basis=False):
        """Plot the current trajectory on the active matplotlib axes.

        Args:
            finished: Time step at which to draw a red vertical line spanning
                y in [-1, 1]. NaN (the default) draws no line.
            show_basis: If True, also plot every Gaussian basis function.
        """
        # Use the same integer time steps that policy_rollout evaluates.
        x = np.arange(self.nr_steps)
        plt.plot(x, self.trajectory)
        # NaN never compares equal to anything, including itself, so the
        # "no marker" sentinel has to be detected with isnan, not `!=`.
        if not np.isnan(finished):
            plt.vlines(finished, -1, 1, colors='red')
        if show_basis:
            for center in self.mean:
                plt.plot(x, self._basis(x, center))
# Blame timestamp for the lines below: 2023-02-02 17:54:09 +00:00
def main():
    """Demo: roll out a random single-weight policy and plot it.

    The top panel shows the trajectory over time steps. The bottom panel shows
    the weights as bars placed at their basis-function centres.
    """
    policy = GaussianPolicy(1, 50)
    policy.random_policy()
    trajectory = policy.policy_rollout()
    print(policy.weights)
    _, (ax1, ax2) = plt.subplots(2, 1)
    # Plot against the integer steps the rollout evaluated (0 .. nr_steps - 1),
    # not a linspace stretched over [0, nr_steps].
    x = np.arange(policy.nr_steps)
    ax1.plot(x, trajectory)
    ax2.bar(policy.mean, policy.weights)
    plt.show()


if __name__ == "__main__":
    main()