Compare commits
No commits in common. "432978a1220496575f81c313f7e7f1292bc1d650" and "3f9b12e7f2f547ed1b61fa7d9584b92d0d37c443" have entirely different histories.
432978a122
...
3f9b12e7f2
@ -28,7 +28,7 @@ class PreferenceExpectedImprovement:
|
||||
while samples.shape[0] < self.nr_samples:
|
||||
# sample from the multi variate gaussian distribution
|
||||
sample = self.rng.multivariate_normal(
|
||||
self.proposal_model_mean,
|
||||
self.proposal_model_mean,
|
||||
self.proposal_model_covariance
|
||||
)
|
||||
|
||||
@ -43,21 +43,20 @@ class PreferenceExpectedImprovement:
|
||||
pass
|
||||
|
||||
def update_user_preference_model(self, preferred_input, preference_array):
|
||||
# Update mean to reflect preferred input
|
||||
self.user_model_mean = preferred_input
|
||||
# Update mean to reflect preferred input
|
||||
self.user_model_mean = preferred_input
|
||||
|
||||
initial_variance = np.ones((self.nr_dims, )) * self.initial_variance
|
||||
reduced_variance = initial_variance / 10.0
|
||||
variances = np.where(preference_array, reduced_variance, initial_variance)
|
||||
self.user_model_covariance = np.diag(variances)
|
||||
initial_variance = np.ones((self.nr_dims, )) * self.initial_variance
|
||||
reduced_variance = initial_variance / 10.0
|
||||
variances = np.where(preference_array, reduced_variance, initial_variance)
|
||||
self.user_model_covariance = np.diag(variances)
|
||||
|
||||
def update_proposal_model(self, alpha=0.5):
|
||||
# Update proposal model to be a weighted average of the current proposal model and the user model
|
||||
self.proposal_model_mean = alpha * self.proposal_model_mean + (1 - alpha) * self.user_model_mean
|
||||
self.proposal_model_covariance = alpha * self.proposal_model_covariance + (1 - alpha) * self.user_model_covariance
|
||||
# Update proposal model to be a weighted average of the current proposal model and the user model
|
||||
self.proposal_model_mean = alpha * self.proposal_model_mean + (1 - alpha) * self.user_model_mean
|
||||
self.proposal_model_covariance = alpha * self.proposal_model_covariance + (1 - alpha) * self.user_model_covariance
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
acquisition = PreferenceExpectedImprovement(10, 2, -1.0, 1.0)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user