ActiveBOToytask/runner/BOGymRunner.py

92 lines
2.5 KiB
Python
Raw Normal View History

2023-02-15 15:03:03 +00:00
from BayesianOptimization.BOwithGym import BayesianOptimization
from ToyTask.MountainCarGym import Continuous_MountainCarEnv
import numpy as np
import matplotlib.pyplot as plt
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
env = Continuous_MountainCarEnv()  # environment handed to BayesianOptimization
nr_steps = 100                     # rows of ``best_policy``; also passed to BayesianOptimization
acquisition_fun = 'ei'             # forwarded as ``acq=`` (presumably expected improvement — confirm)
iteration_steps = 500              # acquire/evaluate iterations performed in every run
nr_runs = 20                       # number of independent optimisation runs

# Result buffers: column ``i`` of each array is filled by run ``i`` in main().
finished_store = np.zeros(shape=(1, nr_runs))
best_policy = np.zeros(shape=(nr_steps, nr_runs))
reward_store = np.zeros(shape=(iteration_steps, nr_runs))
# post-processing
def post_processing(finished, policy, reward):
    """Summarise all runs by mean and standard deviation.

    Parameters
    ----------
    finished : array_like
        Per-run values from ``get_best_result``. NaN entries are ignored when
        computing its statistics.
    policy, reward : numpy.ndarray
        2-D arrays with one column per run. Statistics are taken across the
        runs, i.e. along ``axis=1``, giving one value per row.

    Returns
    -------
    tuple
        ``(finish_mean, finish_std, policy_mean, policy_std,
        reward_mean, reward_std)``.
    """
    def across_runs(samples):
        # Reduce over the column (run) axis, keeping one value per row.
        return np.mean(samples, axis=1), np.std(samples, axis=1)

    finish_stats = (np.nanmean(finished), np.nanstd(finished))
    policy_stats = across_runs(policy)
    reward_stats = across_runs(reward)
    return finish_stats + policy_stats + reward_stats
# plot functions
def plot_policy(mean, std, fin_mean, fin_std, y_min=-2, y_max=2):
    """Plot the mean best policy over runs with a 95 % band and a finish marker.

    Parameters
    ----------
    mean, std : numpy.ndarray
        Per-step mean and standard deviation of the best policy across runs
        (assumed 1-D, one entry per policy step).
    fin_mean, fin_std : float
        Mean and standard deviation of the ``finished`` value across runs.
        Drawn as a red vertical line with a shaded +/-1.96 std band
        (presumably the step at which the task was finished — confirm).
    y_min, y_max : float, optional
        Vertical extent of the finish line and its band. The defaults match
        the previously hard-coded range [-2, 2].
    """
    # One x position per step: 0, 1, ..., n-1. The former
    # np.linspace(0, n, n) spaced points n/(n-1) apart and put the last one
    # at n, so points drifted off the integer step indices.
    steps = np.arange(mean.shape[0])
    plt.plot(steps, mean)
    # +/-1.96 standard deviations is the 95 % interval of a normal distribution.
    plt.fill_between(
        steps,
        mean - 1.96 * std,
        mean + 1.96 * std,
        alpha=0.5,
    )

    y = np.linspace(y_min, y_max, 50)
    plt.vlines(fin_mean, y_min, y_max, colors='red')
    plt.fill_betweenx(
        y,
        fin_mean - 1.96 * fin_std,
        fin_mean + 1.96 * fin_std,
        alpha=0.5,
    )
    plt.show()
def plot_reward(mean, std):
    """Plot the mean best-reward curve across runs with a shaded 95 % band.

    Parameters
    ----------
    mean, std : numpy.ndarray
        Per-iteration mean and standard deviation of the best reward across
        runs (assumed 1-D, one entry per optimisation iteration).
    """
    # Integer iteration indices 0 .. n-1. The former np.linspace(0, n, n)
    # spread points n/(n-1) apart and ended at n instead of n-1.
    eps = np.arange(mean.shape[0])
    plt.plot(eps, mean)
    # +/-1.96 standard deviations is the 95 % interval of a normal distribution.
    plt.fill_between(
        eps,
        mean - 1.96 * std,
        mean + 1.96 * std,
        alpha=0.5,
    )
    plt.show()
# main
def main():
    """Repeat the Bayesian-optimisation experiment ``nr_runs`` times and plot it.

    Each run:

    1. re-initialises the optimiser;
    2. performs ``iteration_steps`` acquire/evaluate iterations;
    3. writes ``get_best_result``, the policy trajectory and the best-reward
       history into column ``i`` of the module-level storage arrays.

    Afterwards the results are aggregated with ``post_processing`` and plotted.
    """
    # The storage arrays are only written element-wise (never rebound), so no
    # ``global`` declaration is required to update them.
    # A single optimiser is reused across runs; this relies on initialize()
    # resetting its state — confirm in BayesianOptimization.
    bo = BayesianOptimization(env, nr_steps, acq=acquisition_fun)
    for run in range(nr_runs):
        print('Iteration:', run)
        bo.initialize()
        for _ in range(iteration_steps):
            x_next = bo.next_observation()
            bo.eval_new_observation(x_next)
        finished = bo.get_best_result(plotter=False)
        finished_store[:, run] = finished
        best_policy[:, run] = bo.policy_model.trajectory.T
        reward_store[:, run] = bo.best_reward.T

    (finish_mean, finish_std,
     policy_mean, policy_std,
     reward_mean, reward_std) = post_processing(finished_store, best_policy, reward_store)
    plot_policy(policy_mean, policy_std, finish_mean, finish_std)
    plot_reward(reward_mean, reward_std)


if __name__ == '__main__':
    main()