Compare commits
3 Commits
3f9b12e7f2
...
432978a122
Author | SHA1 | Date | |
---|---|---|---|
432978a122 | |||
47f48662f4 | |||
058e8706a5 |
@ -43,20 +43,21 @@ class PreferenceExpectedImprovement:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def update_user_preference_model(self, preferred_input, preference_array):
|
def update_user_preference_model(self, preferred_input, preference_array):
|
||||||
# Update mean to reflect preferred input
|
# Update mean to reflect preferred input
|
||||||
self.user_model_mean = preferred_input
|
self.user_model_mean = preferred_input
|
||||||
|
|
||||||
initial_variance = np.ones((self.nr_dims, )) * self.initial_variance
|
initial_variance = np.ones((self.nr_dims, )) * self.initial_variance
|
||||||
reduced_variance = initial_variance / 10.0
|
reduced_variance = initial_variance / 10.0
|
||||||
variances = np.where(preference_array, reduced_variance, initial_variance)
|
variances = np.where(preference_array, reduced_variance, initial_variance)
|
||||||
self.user_model_covariance = np.diag(variances)
|
self.user_model_covariance = np.diag(variances)
|
||||||
|
|
||||||
def update_proposal_model(self, alpha=0.5):
|
def update_proposal_model(self, alpha=0.5):
|
||||||
# Update proposal model to be a weighted average of the current proposal model and the user model
|
# Update proposal model to be a weighted average of the current proposal model and the user model
|
||||||
self.proposal_model_mean = alpha * self.proposal_model_mean + (1 - alpha) * self.user_model_mean
|
self.proposal_model_mean = alpha * self.proposal_model_mean + (1 - alpha) * self.user_model_mean
|
||||||
self.proposal_model_covariance = alpha * self.proposal_model_covariance + (1 - alpha) * self.user_model_covariance
|
self.proposal_model_covariance = alpha * self.proposal_model_covariance + (1 - alpha) * self.user_model_covariance
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
acquisition = PreferenceExpectedImprovement(10, 2, -1.0, 1.0)
|
acquisition = PreferenceExpectedImprovement(10, 2, -1.0, 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user