2023-02-02 17:54:09 +00:00
|
|
|
import numpy as np
|
|
|
|
from scipy.stats import norm
|
|
|
|
|
2023-04-21 10:27:29 +00:00
|
|
|
|
2023-02-02 17:54:09 +00:00
|
|
|
def ExpectedImprovement(gp, X, nr_test, nr_weights, kappa=2.576, seed=None, lower=-1.0, upper=1.0):
    """Pick the random candidate point with the highest Expected Improvement.

    ``nr_test`` candidates are drawn uniformly from ``[lower, upper)`` in each
    of ``nr_weights`` dimensions. Each candidate is scored with the
    closed-form EI of the surrogate ``gp`` (for maximisation), and the best
    one is returned.

    Args:
        gp: Fitted surrogate model. It must support ``gp.predict(X)`` and
            ``gp.predict(X, return_std=True) -> (mu, sigma)``. This presumably
            means a scikit-learn style GaussianProcessRegressor; verify
            against callers.
        X: Points already evaluated. The incumbent ``best_y`` is the largest
            *model prediction* at these points, not the observed targets.
        nr_test: Number of random candidates to score.
        nr_weights: Dimensionality of each candidate.
        kappa: Exploration margin subtracted from the improvement. This is
            usually called ``xi`` in EI. Note that 2.576 is a typical UCB
            ``kappa``, which is large for an EI margin.
        seed: Seed for ``np.random.default_rng``. Pass it for reproducible
            candidates.
        lower, upper: Bounds of the uniform candidate box.

    Returns:
        The 1-D candidate row of length ``nr_weights`` with maximal EI.
    """
    best_y = np.max(gp.predict(X))

    rng = np.random.default_rng(seed=seed)
    X_test = rng.uniform(lower, upper, (nr_test, nr_weights))

    mu, sigma = gp.predict(X_test, return_std=True)
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)

    improvement = mu - best_y - kappa

    # With sigma == 0 the closed form hits 0/0 (-> NaN) or x/0 (-> +-inf).
    # np.argmax would then pick a NaN entry, so silence the warnings here
    # and patch those entries below.
    with np.errstate(divide="ignore", invalid="ignore"):
        z = improvement / sigma
        ei = improvement * norm.cdf(z) + sigma * norm.pdf(z)

    # Deterministic prediction: EI reduces to the plain positive improvement
    # (the sigma -> 0 limit of the closed form).
    ei = np.where(sigma > 0, ei, np.maximum(improvement, 0.0))

    idx = np.argmax(ei)
    return X_test[idx, :]
|