Compare commits

...

3 Commits

Author SHA1 Message Date
432978a122 Started with PreferenceExpectedImprovement.py 2023-07-13 10:48:46 +02:00
47f48662f4 Merge remote-tracking branch 'origin/master'
# Conflicts:
#	AcquistionFunctions/PreferenceExpectedImprovement.py
2023-07-13 10:48:18 +02:00
058e8706a5 Started with PreferenceExpectedImprovement.py 2023-07-13 10:48:10 +02:00

View File

@ -43,20 +43,21 @@ class PreferenceExpectedImprovement:
pass pass
def update_user_preference_model(self, preferred_input, preference_array): def update_user_preference_model(self, preferred_input, preference_array):
# Update mean to reflect preferred input # Update mean to reflect preferred input
self.user_model_mean = preferred_input self.user_model_mean = preferred_input
initial_variance = np.ones((self.nr_dims, )) * self.initial_variance initial_variance = np.ones((self.nr_dims, )) * self.initial_variance
reduced_variance = initial_variance / 10.0 reduced_variance = initial_variance / 10.0
variances = np.where(preference_array, reduced_variance, initial_variance) variances = np.where(preference_array, reduced_variance, initial_variance)
self.user_model_covariance = np.diag(variances) self.user_model_covariance = np.diag(variances)
def update_proposal_model(self, alpha=0.5): def update_proposal_model(self, alpha=0.5):
# Update proposal model to be a weighted average of the current proposal model and the user model # Update proposal model to be a weighted average of the current proposal model and the user model
self.proposal_model_mean = alpha * self.proposal_model_mean + (1 - alpha) * self.user_model_mean self.proposal_model_mean = alpha * self.proposal_model_mean + (1 - alpha) * self.user_model_mean
self.proposal_model_covariance = alpha * self.proposal_model_covariance + (1 - alpha) * self.user_model_covariance self.proposal_model_covariance = alpha * self.proposal_model_covariance + (1 - alpha) * self.user_model_covariance
if __name__ == '__main__': if __name__ == '__main__':
acquisition = PreferenceExpectedImprovement(10, 2, -1.0, 1.0) acquisition = PreferenceExpectedImprovement(10, 2, -1.0, 1.0)