added BO torch tryout
This commit is contained in:
parent
c0ae4097fe
commit
b09e44daa5
@ -4,7 +4,7 @@
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<excludeFolder url="file://$MODULE_DIR$/venv" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.8 (venv)" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="Python 3.10" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
|
@ -1,4 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (RlToyTask)" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
|
||||
</project>
|
26
BoTorchTest/botorchtest1.py
Normal file
26
BoTorchTest/botorchtest1.py
Normal file
@ -0,0 +1,26 @@
|
||||
import torch
|
||||
from botorch.models import SingleTaskGP
|
||||
from botorch.fit import fit_gpytorch_mll
|
||||
from botorch.utils import standardize
|
||||
from gpytorch.mlls import ExactMarginalLogLikelihood
|
||||
|
||||
from botorch.acquisition import UpperConfidenceBound
|
||||
from botorch.optim import optimize_acqf
|
||||
|
||||
|
||||
train_X = torch.rand(10, 2, dtype=torch.double)
|
||||
Y = 1 - torch.norm(train_X - 0.5, dim=-1, keepdim=True)
|
||||
Y = Y + 0.1 * torch.randn_like(Y) # add some noise
|
||||
train_Y = standardize(Y)
|
||||
|
||||
gp = SingleTaskGP(train_X, train_Y)
|
||||
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
|
||||
fit_gpytorch_mll(mll)
|
||||
|
||||
UCB = UpperConfidenceBound(gp, beta=0.1)
|
||||
|
||||
bounds = torch.stack([torch.zeros(2), torch.ones(2)])
|
||||
candidate, acq_value = optimize_acqf(
|
||||
UCB, bounds=bounds, q=1, num_restarts=5, raw_samples=20,
|
||||
)
|
||||
print(candidate)
|
@ -5,10 +5,10 @@ import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# BO parameters
|
||||
env = Continuous_MountainCarEnv()
|
||||
env = Continuous_MountainCarEnv()
|
||||
nr_steps = 100
|
||||
acquisition_fun = 'ei'
|
||||
iteration_steps = 200
|
||||
iteration_steps = 100
|
||||
|
||||
nr_runs = 100
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user