Started with PreferenceExpectedImprovement.py

This commit is contained in:
Niko Feith 2023-07-12 15:00:13 +02:00
parent be2192ee90
commit e9f24d0086
2 changed files with 19 additions and 4 deletions

View File

@ -3,21 +3,31 @@ from scipy.stats import norm
class PreferenceExpectedImprovement: class PreferenceExpectedImprovement:
def __init__(self, nr_samples, upper_bound, lower_bound, nr_dims=2): def __init__(self, nr_samples, nr_dims, lower_bound, upper_bound, seed=None):
self.nr_samples = nr_samples self.nr_samples = nr_samples
self.nr_dims = nr_dims
# check if upper_bound and lower_bound are numpy arrays of shape (nr_dims, 1) or (nr_dims,) or if they are floats
self.upper_bound = upper_bound self.upper_bound = upper_bound
self.lower_bound = lower_bound self.lower_bound = lower_bound
self.user_model = None self.user_model = None
self.proposal_model_mean = np.array() self.proposal_model_mean = np.array((nr_dims, 1))
self.proposal_model_covariance = np.diag(np.ones((nr_dims, )) * 5)
self.rng = np.random.default_rng(seed=seed)
def initialize(self): def initialize(self):
pass pass
def rejection_sampling(self): def rejection_sampling(self):
samples = np.empty((self.nr_samples, self.nr_dims))
i = 0
while i < self.nr_samples:
pass pass
def expected_improvement(self): def expected_improvement(self):
pass pass
@ -26,3 +36,8 @@ class PreferenceExpectedImprovement:
def update_proposal_model(self): def update_proposal_model(self):
pass pass
if __name__ == '__main__':
acquisition = PreferenceExpectedImprovement(10, 2, -1.0, 1.0)

View File

@ -12,7 +12,7 @@ gaussian = gaussian[(gaussian > -1) & (gaussian < 1)]
uniform = np.random.uniform(-1, 1, len(gaussian)) # Same number of samples as the Gaussian uniform = np.random.uniform(-1, 1, len(gaussian)) # Same number of samples as the Gaussian
# Plot the distributions # Plot the distributions
plt.figure(figsize=(12,6)) plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1) plt.subplot(1, 2, 1)
plt.hist(gaussian, bins=30, density=True, alpha=0.6, color='g') plt.hist(gaussian, bins=30, density=True, alpha=0.6, color='g')