PreferenceExpectedImprovement added
This commit is contained in:
parent c9daf65917
commit c32e7cff84
@@ -0,0 +1,77 @@
import numpy as np
from scipy.stats import norm, multivariate_normal


class PreferenceExpectedImprovement:

    def __init__(self, nr_dims, nr_samples, lower_bound, upper_bound,
                 initial_variance, update_variance, seed=None):

        self.nr_dims = nr_dims
        self.nr_samples = int(nr_samples)

        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

        self.init_var = initial_variance
        self.update_var = update_variance

        self.rng = np.random.default_rng(seed=seed)

        # initial proposal distribution: zero mean, diagonal covariance
        self.proposal_mean = np.zeros((nr_dims, 1))
        self.proposal_cov = np.diag(np.ones((nr_dims,)) * self.init_var)

    def rejection_sampling(self):
        samples = np.empty((0, self.nr_dims))
        while samples.shape[0] < self.nr_samples:
            # sample each dimension from the marginal of the multivariate Gaussian
            # and resample until the value lies inside the bounds
            sample = np.zeros((1, self.nr_dims))
            for i in range(self.nr_dims):
                check = False
                while not check:
                    # np.random.Generator.normal expects the standard deviation,
                    # so take the square root of the diagonal variance
                    sample[0, i] = self.rng.normal(self.proposal_mean[i, 0],
                                                   np.sqrt(self.proposal_cov[i, i]))
                    if self.lower_bound <= sample[0, i] <= self.upper_bound:
                        check = True

            samples = np.append(samples, sample, axis=0)

        return samples

    def expected_improvement(self, gp, X, kappa=0.01):
        # candidates come from the preference-shaped proposal distribution
        X_sample = self.rejection_sampling()

        mu_sample, sigma_sample = gp.predict(X_sample, return_std=True)
        sigma_sample = sigma_sample.reshape(-1, 1)

        mu = gp.predict(X)
        mu_best = np.max(mu)

        with np.errstate(divide='warn'):
            imp = mu_sample - mu_best - kappa
            imp = imp.reshape(-1, 1)
            z = imp / sigma_sample
            ei = imp * norm.cdf(z) + sigma_sample * norm.pdf(z)
            ei[sigma_sample == 0.0] = 0.0

        idx = np.argmax(ei)
        x_next = X_sample[idx, :]

        return x_next

    def update_proposal_model(self, preference_mean, preference_bool):
        # dimensions the user overwrote get the small update variance,
        # all others keep the initial variance
        cov_diag = np.ones((self.nr_dims,)) * self.init_var
        cov_diag[preference_bool] = self.update_var

        preference_cov = np.diag(cov_diag)
        preference_mean = preference_mean.reshape(-1, 1)

        # product of the two Gaussians: precision-weighted posterior mean and covariance
        posterior_cov = np.linalg.inv(np.linalg.inv(self.proposal_cov)
                                      + np.linalg.inv(preference_cov))
        posterior_mean = posterior_cov.dot(np.linalg.inv(self.proposal_cov).dot(self.proposal_mean)
                                           + np.linalg.inv(preference_cov).dot(preference_mean))

        self.proposal_mean = posterior_mean
        self.proposal_cov = posterior_cov
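For reference, a minimal usage sketch of the new class on its own, assuming a scikit-learn GaussianProcessRegressor fitted on a toy two-dimensional weight problem; the objective function, kernel choice, weight values, and preference flags below are illustrative placeholders, not part of this commit:

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern

from active_bo_ros.AcquisitionFunctions.PreferenceExpectedImprovement import PreferenceExpectedImprovement

# toy objective over a 2-dimensional weight vector (illustrative only)
def objective(w):
    return -np.sum((w - 0.5) ** 2, axis=1)

rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(10, 2))   # previously evaluated policy weights
Y = objective(X)

gp = GaussianProcessRegressor(kernel=Matern(nu=1.5)).fit(X, Y)

acq = PreferenceExpectedImprovement(nr_dims=2, nr_samples=100,
                                    lower_bound=-1.0, upper_bound=1.0,
                                    initial_variance=10.0, update_variance=0.05,
                                    seed=0)

# propose the next weights: candidates are drawn by rejection sampling
# from the proposal distribution and ranked by expected improvement
x_next = acq.expected_improvement(gp, X, kappa=0.0)

# stand-in for weights returned after the user overwrote the second one
user_weights = np.array([0.3, 0.7])
overwritten = np.array([False, True])
acq.update_proposal_model(user_weights, overwritten)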
@@ -3,9 +3,11 @@ from sklearn.gaussian_process import GaussianProcessRegressor
 from sklearn.gaussian_process.kernels import Matern

 from active_bo_ros.PolicyModel.GaussianRBFModel import GaussianRBF

 from active_bo_ros.AcquisitionFunctions.ExpectedImprovement import ExpectedImprovement
 from active_bo_ros.AcquisitionFunctions.ProbabilityOfImprovement import ProbabilityOfImprovement
 from active_bo_ros.AcquisitionFunctions.ConfidenceBound import ConfidenceBound
+from active_bo_ros.AcquisitionFunctions.PreferenceExpectedImprovement import PreferenceExpectedImprovement

 from sklearn.exceptions import ConvergenceWarning
 import warnings
@@ -39,7 +41,16 @@ class BayesianOptimization:
                                self.lower_bound,
                                self.upper_bound)

-        self.eval_X = 100
+        self.nr_samples = 100

+        if acq == "Preference Expected Improvement":
+            self.acq_fun = PreferenceExpectedImprovement(self.nr_policy_weights,
+                                                         self.nr_samples,
+                                                         self.lower_bound,
+                                                         self.upper_bound,
+                                                         initial_variance=10.0,
+                                                         update_variance=0.05,
+                                                         seed=policy_seed)
+
     def reset_bo(self):
         self.counter_array = np.empty((1, 1))
@@ -94,7 +105,7 @@ class BayesianOptimization:
         if self.acq == "Expected Improvement":
             x_next = ExpectedImprovement(self.GP,
                                          self.X,
-                                         self.eval_X,
+                                         self.nr_samples,
                                          self.nr_policy_weights,
                                          kappa=0,
                                          seed=self.policy_seed,
@@ -104,7 +115,7 @@ class BayesianOptimization:
         elif self.acq == "Probability of Improvement":
             x_next = ProbabilityOfImprovement(self.GP,
                                               self.X,
-                                              self.eval_X,
+                                              self.nr_samples,
                                               self.nr_policy_weights,
                                               kappa=0,
                                               seed=self.policy_seed,
@@ -113,13 +124,18 @@ class BayesianOptimization:

         elif self.acq == "Upper Confidence Bound":
             x_next = ConfidenceBound(self.GP,
-                                     self.eval_X,
+                                     self.nr_samples,
                                      self.nr_policy_weights,
                                      beta=2.576,
                                      seed=self.policy_seed,
                                      lower=self.lower_bound,
                                      upper=self.upper_bound)

+        elif self.acq == "Preference Expected Improvement":
+            x_next = self.acq_fun.expected_improvement(self.GP,
+                                                       self.X,
+                                                       kappa=0)
+
         else:
             raise NotImplementedError

@@ -9,10 +9,10 @@ class ImprovementQuery:
         self.rewards = rewards

     def query(self):
-        if self.rewards.shape[0] < self.period:
+        if self.rewards.shape[0] < self.period + 1:
             return False

-        elif self.rewards.shape[0] < self.last_query + self.period:
+        elif self.rewards.shape[0] < self.last_query + self.period + 1:
             return False

         else:
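A quick sketch of what the adjusted thresholds mean in practice. The stand-in class below keeps only the fields the new condition touches (period, last_query, rewards) and reduces the final branch to returning True; the real ImprovementQuery does more there, and the concrete numbers are illustrative only:

import numpy as np

class _QueryTrigger:
    def __init__(self, period, last_query, rewards):
        self.period = period
        self.last_query = last_query
        self.rewards = rewards

    def query(self):
        # same early-exit conditions as the patched ImprovementQuery.query
        if self.rewards.shape[0] < self.period + 1:
            return False
        elif self.rewards.shape[0] < self.last_query + self.period + 1:
            return False
        else:
            return True

# with period=5 and no previous query, 5 recorded rewards are no longer enough;
# the user is only asked once a 6th episode reward has been recorded
print(_QueryTrigger(5, 0, np.zeros(5)).query())   # False
print(_QueryTrigger(5, 0, np.zeros(6)).query())   # True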
@@ -194,6 +194,8 @@ class ActiveBOTopic(Node):

            if self.user_asked:
                self.last_user_reward = self.rl_reward
+               if self.bo_acq_fcn == "Preference Expected Improvement":
+                   self.BO.acq_fun.update_proposal_model(self.rl_weights, self.overwrite_weight)
                self.user_asked = False

            self.rl_pending = False
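The update wired in here is the Gaussian product from update_proposal_model: dimensions the user overwrote get the small update variance, so the proposal contracts around the user's weights there, while the remaining dimensions are only averaged with the broad zero-mean prior. A small numeric illustration, with made-up weights and overwrite flags:

import numpy as np
from active_bo_ros.AcquisitionFunctions.PreferenceExpectedImprovement import PreferenceExpectedImprovement

acq = PreferenceExpectedImprovement(nr_dims=3, nr_samples=100,
                                    lower_bound=-1.0, upper_bound=1.0,
                                    initial_variance=10.0, update_variance=0.05,
                                    seed=0)

user_weights = np.array([0.2, -0.4, 0.9])      # illustrative stand-in for rl_weights
overwritten = np.array([False, True, False])   # illustrative stand-in for overwrite_weight

acq.update_proposal_model(user_weights, overwritten)

# the overwritten dimension lands almost exactly on the user's weight (~0.995 * w)
# with variance ~0.05; the untouched dimensions keep variance 5 and a mean of 0.5 * w
print(acq.proposal_mean.ravel())
print(np.diag(acq.proposal_cov))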
@@ -271,6 +273,8 @@ class ActiveBOTopic(Node):
            acq = 'pi'
        elif self.bo_acq_fcn == "Upper Confidence Bound":
            acq = 'cb'
+       elif self.bo_acq_fcn == "Preference Expected Improvement":
+           acq = 'pei'
        else:
            raise NotImplementedError

@@ -311,7 +315,7 @@ class ActiveBOTopic(Node):
        if self.init_pending:
            return
        else:
-           if self.current_episode < self.bo_episodes:
+           if self.current_episode < self.bo_episodes + self.nr_init - 1:
                # metrics
                if self.bo_metric == "random":
                    user_query = RandomQuery(self.bo_metric_parameter)
@@ -362,6 +366,9 @@ class ActiveBOTopic(Node):
                    self.rl_pending = True

                else:
+                   # if self.bo_acq_fcn == "Preference Expected Improvement":
+                   #     self.get_logger().info(f"{self.BO.acq_fun.proposal_mean}")
+                   #     self.get_logger().info(f"X: {self.BO.X}")
                    x_next = self.BO.next_observation()
                    # self.get_logger().info(f'x_next: {x_next}')
                    # self.get_logger().info(f'overwrite: {self.overwrite_weight}')
@@ -392,10 +399,11 @@ class ActiveBOTopic(Node):

                self.active_rl_pub.publish(rl_msg)

-               self.current_episode += 1
                self.reward[self.current_episode, self.current_run] = np.max(self.BO.Y)
                self.get_logger().info(f'Current Episode: {self.current_episode},'
                                       f' best reward: {self.reward[self.current_episode, self.current_run]}')
+               self.current_episode += 1

            else:
                self.best_policy[:, self.current_run], \