added BO torch tryout
This commit is contained in:
parent
c0ae4097fe
commit
b09e44daa5
@ -4,7 +4,7 @@
|
|||||||
<content url="file://$MODULE_DIR$">
|
<content url="file://$MODULE_DIR$">
|
||||||
<excludeFolder url="file://$MODULE_DIR$/venv" />
|
<excludeFolder url="file://$MODULE_DIR$/venv" />
|
||||||
</content>
|
</content>
|
||||||
<orderEntry type="jdk" jdkName="Python 3.8 (venv)" jdkType="Python SDK" />
|
<orderEntry type="jdk" jdkName="Python 3.10" jdkType="Python SDK" />
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
</component>
|
</component>
|
||||||
<component name="PyDocumentationSettings">
|
<component name="PyDocumentationSettings">
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (RlToyTask)" project-jdk-type="Python SDK" />
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
|
||||||
</project>
|
</project>
|
26
BoTorchTest/botorchtest1.py
Normal file
26
BoTorchTest/botorchtest1.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import torch
|
||||||
|
from botorch.models import SingleTaskGP
|
||||||
|
from botorch.fit import fit_gpytorch_mll
|
||||||
|
from botorch.utils import standardize
|
||||||
|
from gpytorch.mlls import ExactMarginalLogLikelihood
|
||||||
|
|
||||||
|
from botorch.acquisition import UpperConfidenceBound
|
||||||
|
from botorch.optim import optimize_acqf
|
||||||
|
|
||||||
|
|
||||||
|
train_X = torch.rand(10, 2, dtype=torch.double)
|
||||||
|
Y = 1 - torch.norm(train_X - 0.5, dim=-1, keepdim=True)
|
||||||
|
Y = Y + 0.1 * torch.randn_like(Y) # add some noise
|
||||||
|
train_Y = standardize(Y)
|
||||||
|
|
||||||
|
gp = SingleTaskGP(train_X, train_Y)
|
||||||
|
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
|
||||||
|
fit_gpytorch_mll(mll)
|
||||||
|
|
||||||
|
UCB = UpperConfidenceBound(gp, beta=0.1)
|
||||||
|
|
||||||
|
bounds = torch.stack([torch.zeros(2), torch.ones(2)])
|
||||||
|
candidate, acq_value = optimize_acqf(
|
||||||
|
UCB, bounds=bounds, q=1, num_restarts=5, raw_samples=20,
|
||||||
|
)
|
||||||
|
print(candidate)
|
@ -5,10 +5,10 @@ import numpy as np
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
# BO parameters
|
# BO parameters
|
||||||
env = Continuous_MountainCarEnv()
|
env = Continuous_MountainCarEnv()
|
||||||
nr_steps = 100
|
nr_steps = 100
|
||||||
acquisition_fun = 'ei'
|
acquisition_fun = 'ei'
|
||||||
iteration_steps = 200
|
iteration_steps = 100
|
||||||
|
|
||||||
nr_runs = 100
|
nr_runs = 100
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user