63 lines
1.8 KiB
Python
63 lines
1.8 KiB
Python
|
import numpy as np
|
||
|
import matplotlib.pyplot as plt
|
||
|
|
||
|
class GaussianPolicy:
    """Open-loop policy formed by a weighted sum of Gaussian basis functions.

    ``nr_weights`` Gaussians are centred at evenly spaced points on
    ``[0, nr_steps]`` and share one standard deviation.  A rollout evaluates
    their weighted sum at every integer step ``0 .. nr_steps - 1``.

    Attributes:
        nr_weights: Number of basis functions (one weight per function).
        nr_steps: Number of steps produced by a rollout.
        mean: Basis centres, ``np.linspace(0, nr_steps, nr_weights)``.
        std: Standard deviation shared by all basis functions.
        weights: Current weights; shape ``(nr_weights, 1)`` after ``reset()``
            and ``(nr_weights,)`` after ``random_policy()``.
        trajectory: Result of the last rollout, shape ``(nr_steps, 1)``.
        low, upper: Bounds used when sampling random weights.
        rng: NumPy random generator used by ``random_policy()``.
    """

    def __init__(self, nr_weights, nr_steps, seed=None, lowerb=-1.0, upperb=1.0):
        """Set up the basis functions and start from all-zero weights.

        Args:
            nr_weights: Number of Gaussian basis functions.
            nr_steps: Number of steps in a rollout.
            seed: Seed forwarded to ``np.random.default_rng``.
            lowerb: Lower bound for randomly sampled weights.
            upperb: Upper bound for randomly sampled weights.
        """
        self.nr_weights = nr_weights
        self.nr_steps = nr_steps
        self.weights = None
        self.trajectory = None

        self.mean = np.linspace(0, self.nr_steps, self.nr_weights)
        if nr_weights > 1:
            # mean[0] is 0, so mean[1] is the spacing between neighbouring
            # centres.  That spacing is used as the full width at half
            # maximum: std = FWHM / (2 * sqrt(2 * ln 2)).
            self.std = self.mean[1] / (2 * np.sqrt(2 * np.log(2)))
        else:
            # A single basis function has no spacing to derive a width from.
            self.std = self.nr_steps / 2

        self.rng = np.random.default_rng(seed=seed)
        self.low = lowerb
        self.upper = upperb

        self.reset()

    def reset(self):
        """Zero the weights and the stored trajectory."""
        self.weights = np.zeros((self.nr_weights, 1))
        self.trajectory = np.zeros((self.nr_steps, 1))

    def random_policy(self):
        """Draw new weights uniformly from ``[low, upper)``.

        The weights are stored as a 1-D array of length ``nr_weights``.
        """
        self.weights = self.rng.uniform(self.low, self.upper, self.nr_weights)

    def _basis(self, x):
        """Evaluate every basis function at the points ``x``.

        Returns an array of shape ``(len(x), nr_weights)`` whose entry
        ``[i, j]`` is the ``j``-th Gaussian evaluated at ``x[i]``.
        """
        offsets = np.asarray(x, dtype=float)[:, None] - self.mean[None, :]
        return np.exp(-0.5 * offsets**2 / self.std**2)

    def policy_rollout(self):
        """Evaluate the weighted basis sum at every integer step.

        Accepts ``weights`` either as a 1-D array or as an
        ``(nr_weights, 1)`` column.

        Returns:
            The new trajectory, shape ``(nr_steps, 1)``; it is also stored
            in ``self.trajectory``.

        Raises:
            ValueError: If ``weights`` does not hold exactly ``nr_weights``
                values.
        """
        steps = np.arange(self.nr_steps)
        column = np.reshape(self.weights, (-1, 1))
        # (nr_steps, nr_weights) @ (nr_weights, 1) -> (nr_steps, 1)
        self.trajectory = self._basis(steps) @ column
        return self.trajectory

    def plot_policy(self):
        """Plot the stored trajectory together with each basis function."""
        # trajectory[i] was evaluated at integer step i, so the x axis must
        # be the step indices for the curves to line up.
        x = np.arange(self.nr_steps)
        plt.plot(x, self.trajectory)
        for curve in self._basis(x).T:
            plt.plot(x, curve)

        plt.show()
|
||
|
|
||
|
|
||
|
def main():
    """Demo: sample a random single-basis policy, roll it out and plot it.

    Prints the sampled weights, then shows two stacked axes: the rolled-out
    trajectory on top and the weights as bars at their basis centres below.
    """
    policy = GaussianPolicy(1, 50)
    policy.random_policy()
    policy.policy_rollout()
    print(policy.weights)

    _, (ax1, ax2) = plt.subplots(2, 1)

    # trajectory[i] is the policy evaluated at integer step i, so it is
    # plotted against the step indices 0 .. nr_steps - 1.
    x = np.arange(policy.nr_steps)
    ax1.plot(x, policy.trajectory)
    ax2.bar(policy.mean, policy.weights)

    plt.show()
|
||
|
|
||
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|