From e9f24d0086bee6338c399980e09018981e57ee43 Mon Sep 17 00:00:00 2001 From: Niko Date: Wed, 12 Jul 2023 15:00:13 +0200 Subject: [PATCH] Started with PreferenceExpectedImprovement.py --- .../PreferenceExpectedImprovement.py | 21 ++++++++++++++++--- DistTesting/Testing.py | 2 +- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/AcquistionFunctions/PreferenceExpectedImprovement.py b/AcquistionFunctions/PreferenceExpectedImprovement.py index c47f937..b69595c 100644 --- a/AcquistionFunctions/PreferenceExpectedImprovement.py +++ b/AcquistionFunctions/PreferenceExpectedImprovement.py @@ -3,20 +3,30 @@ from scipy.stats import norm class PreferenceExpectedImprovement: - def __init__(self, nr_samples, upper_bound, lower_bound, nr_dims=2): + def __init__(self, nr_samples, nr_dims, lower_bound, upper_bound, seed=None): self.nr_samples = nr_samples + self.nr_dims = nr_dims + # check if upper_bound and lower_bound are numpy arrays of shape (nr_dims, 1) or (nr_dims,) or if they are floats + self.upper_bound = upper_bound self.lower_bound = lower_bound self.user_model = None - self.proposal_model_mean = np.array() + self.proposal_model_mean = np.array((nr_dims, 1)) + self.proposal_model_covariance = np.diag(np.ones((nr_dims, )) * 5) + + self.rng = np.random.default_rng(seed=seed) def initialize(self): pass def rejection_sampling(self): - pass + samples = np.empty((self.nr_samples, self.nr_dims)) + i = 0 + while i < self.nr_samples: + pass + def expected_improvement(self): pass @@ -26,3 +36,8 @@ class PreferenceExpectedImprovement: def update_proposal_model(self): pass + + +if __name__ == '__main__': + acquisition = PreferenceExpectedImprovement(10, 2, -1.0, 1.0) + diff --git a/DistTesting/Testing.py b/DistTesting/Testing.py index 50ae73b..8fda221 100644 --- a/DistTesting/Testing.py +++ b/DistTesting/Testing.py @@ -12,7 +12,7 @@ gaussian = gaussian[(gaussian > -1) & (gaussian < 1)] uniform = np.random.uniform(-1, 1, len(gaussian)) # Same number of samples as the Gaussian # Plot the distributions -plt.figure(figsize=(12,6)) +plt.figure(figsize=(12, 6)) plt.subplot(1, 2, 1) plt.hist(gaussian, bins=30, density=True, alpha=0.6, color='g')