interaction_query finished
This commit is contained in:
parent
bb7b3e992d
commit
4b7dd8621a
@ -0,0 +1,8 @@
|
||||
pytest~=6.2.5
|
||||
setuptools==58.2.0
|
||||
numpy~=1.26.4
|
||||
pydot~=1.4.2
|
||||
empy~=3.3.4
|
||||
lark~=1.1.1
|
||||
scipy~=1.12.0
|
||||
scikit-learn~=1.4.0
|
@ -1,4 +0,0 @@
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
|
||||
from interaction_msgs.srv import Query
|
@ -1,12 +0,0 @@
|
||||
class RegularQuery:
|
||||
def __init__(self, regular, episode):
|
||||
self.regular = int(regular)
|
||||
self.counter = episode
|
||||
|
||||
def query(self):
|
||||
|
||||
if self.counter % self.regular == 0 and self.counter != 0:
|
||||
return True
|
||||
|
||||
else:
|
||||
return False
|
@ -1,4 +0,0 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/InteractionQuery
|
||||
[install]
|
||||
install_scripts=$base/lib/InteractionQuery
|
@ -1,18 +0,0 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>ObjectiveFunctions</name>
|
||||
<version>0.0.0</version>
|
||||
<description>TODO: Package description</description>
|
||||
<maintainer email="nikolaus.feith@unileoben.ac.at">niko</maintainer>
|
||||
<license>TODO: License declaration</license>
|
||||
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
@ -1,4 +0,0 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/ObjectiveFunctions
|
||||
[install]
|
||||
install_scripts=$base/lib/ObjectiveFunctions
|
@ -1,4 +0,0 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/Optimizers
|
||||
[install]
|
||||
install_scripts=$base/lib/Optimizers
|
@ -1,4 +0,0 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/RepresentationModels
|
||||
[install]
|
||||
install_scripts=$base/lib/RepresentationModels
|
@ -1,25 +0,0 @@
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
package_name = 'RepresentationModels'
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version='0.0.0',
|
||||
packages=find_packages(exclude=['test']),
|
||||
data_files=[
|
||||
('share/ament_index/resource_index/packages',
|
||||
['resource/' + package_name]),
|
||||
('share/' + package_name, ['package.xml']),
|
||||
],
|
||||
install_requires=['setuptools'],
|
||||
zip_safe=True,
|
||||
maintainer='niko',
|
||||
maintainer_email='nikolaus.feith@unileoben.ac.at',
|
||||
description='TODO: Package description',
|
||||
license='TODO: License declaration',
|
||||
tests_require=['pytest'],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
],
|
||||
},
|
||||
)
|
@ -1,25 +0,0 @@
|
||||
# Copyright 2015 Open Source Robotics Foundation, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ament_copyright.main import main
|
||||
import pytest
|
||||
|
||||
|
||||
# Remove the `skip` decorator once the source file(s) have a copyright header
|
||||
@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.')
|
||||
@pytest.mark.copyright
|
||||
@pytest.mark.linter
|
||||
def test_copyright():
|
||||
rc = main(argv=['.', 'test'])
|
||||
assert rc == 0, 'Found errors'
|
@ -1,25 +0,0 @@
|
||||
# Copyright 2017 Open Source Robotics Foundation, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ament_flake8.main import main_with_errors
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.flake8
|
||||
@pytest.mark.linter
|
||||
def test_flake8():
|
||||
rc, errors = main_with_errors(argv=[])
|
||||
assert rc == 0, \
|
||||
'Found %d code style errors / warnings:\n' % len(errors) + \
|
||||
'\n'.join(errors)
|
@ -1,23 +0,0 @@
|
||||
# Copyright 2015 Open Source Robotics Foundation, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ament_pep257.main import main
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.linter
|
||||
@pytest.mark.pep257
|
||||
def test_pep257():
|
||||
rc = main(argv=['.', 'test'])
|
||||
assert rc == 0, 'Found code style errors / warnings'
|
@ -1,3 +1,6 @@
|
||||
# MODES: random:=0, regular:=1, improvement:=2
|
||||
uint16 modes
|
||||
|
||||
# random query
|
||||
float32 threshold
|
||||
|
||||
@ -7,9 +10,9 @@ uint16 current_episode
|
||||
|
||||
# improvement query
|
||||
# float32 threshold
|
||||
uint16 period
|
||||
# uint16 frequency
|
||||
uint16 last_queried_episode
|
||||
float32[] rewards
|
||||
float32[] last_rewards
|
||||
|
||||
---
|
||||
bool interaction
|
@ -1,7 +1,7 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>RepresentationModels</name>
|
||||
<name>interaction_objective_function</name>
|
||||
<version>0.0.0</version>
|
||||
<description>TODO: Package description</description>
|
||||
<maintainer email="nikolaus.feith@unileoben.ac.at">niko</maintainer>
|
4
src/interaction_objective_function/setup.cfg
Normal file
4
src/interaction_objective_function/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/interaction_objective_function
|
||||
[install]
|
||||
install_scripts=$base/lib/interaction_objective_function
|
@ -1,6 +1,6 @@
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
package_name = 'Optimizers'
|
||||
package_name = 'interaction_objective_function'
|
||||
|
||||
setup(
|
||||
name=package_name,
|
@ -0,0 +1,4 @@
|
||||
from .confidence_bounds import ConfidenceBounds
|
||||
from .probability_of_improvement import ProbabilityOfImprovement
|
||||
from .expected_improvement import ExpectedImprovement
|
||||
from .preference_expected_improvement import PreferenceExpectedImprovement
|
@ -0,0 +1,31 @@
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ConfidenceBounds:
|
||||
def __init__(self, nr_weights, nr_samples=100, beta=1.2, seed=None, lower_bound=-1.0, upper_bound=1.0):
|
||||
self.nr_weights = nr_weights
|
||||
self.nr_samples = nr_samples
|
||||
self.beta = beta # if beta negative => lower confidence bounds
|
||||
self.lower_bound = lower_bound
|
||||
self.upper_bound = upper_bound
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, gauss_process, _, seed=None):
|
||||
# if seed is set for whole experiment
|
||||
if self.seed is not None:
|
||||
seed = self.seed
|
||||
|
||||
# random generator
|
||||
rng = np.random.default_rng(seed)
|
||||
|
||||
# sample from the surrogate
|
||||
x_test = rng.uniform(self.lower_bound, self.upper_bound, size=(self.nr_samples, self.nr_weights))
|
||||
mu, sigma = gauss_process.predict(x_test, return_std=True)
|
||||
|
||||
# upper/lower confidence bounds
|
||||
cb = mu + self.beta * sigma
|
||||
|
||||
# get the best result and return it
|
||||
idx = np.argmax(cb)
|
||||
return x_test[idx, :]
|
@ -0,0 +1,37 @@
|
||||
|
||||
import numpy as np
|
||||
from scipy.stats import norm
|
||||
|
||||
|
||||
class ExpectedImprovement:
|
||||
def __init__(self, nr_weights, nr_samples=100, kappa=0.0, seed=None, lower_bound=-1.0, upper_bound=1.0):
|
||||
self.nr_weights = nr_weights
|
||||
self.nr_samples = nr_samples
|
||||
self.kappa = kappa
|
||||
self.lower_bound = lower_bound
|
||||
self.upper_bound = upper_bound
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, gauss_process, x_observed, seed=None):
|
||||
# if seed is set for whole experiment
|
||||
if self.seed is not None:
|
||||
seed = self.seed
|
||||
|
||||
# random generator
|
||||
rng = np.random.default_rng(seed)
|
||||
|
||||
# get the best so far observed y
|
||||
mu = gauss_process.predict(x_observed)
|
||||
y_best = max(mu)
|
||||
|
||||
# sample from surrogate
|
||||
x_test = rng.uniform(self.lower_bound, self.upper_bound, size=(self.nr_samples, self.nr_weights))
|
||||
mu, sigma = gauss_process.predict(x_test, return_std=True)
|
||||
|
||||
# expected improvement
|
||||
z = (mu - y_best - self.kappa) / sigma
|
||||
ei = (mu - y_best - self.kappa) * norm.cdf(z) + sigma * norm.pdf(z)
|
||||
|
||||
# get the best result and return it
|
||||
idx = np.argmax(ei)
|
||||
return x_test[idx, :]
|
@ -0,0 +1,93 @@
|
||||
|
||||
import numpy as np
|
||||
from scipy.stats import norm
|
||||
|
||||
|
||||
class PreferenceExpectedImprovement:
|
||||
def __init__(self, nr_dims, initial_variance, update_variance, nr_samples=100,
|
||||
kappa=0.0, lower_bound=None, upper_bound=None, seed=None, fixed_dims=None):
|
||||
self.nr_dims = nr_dims
|
||||
|
||||
self.initial_variance = initial_variance
|
||||
self.update_variance = update_variance
|
||||
|
||||
self.nr_samples = nr_samples
|
||||
self.kappa = kappa
|
||||
|
||||
if lower_bound is None:
|
||||
self.lower_bound = [-1.] * self.nr_dims
|
||||
else:
|
||||
self.lower_bound = lower_bound
|
||||
|
||||
if upper_bound is None:
|
||||
self.upper_bound = [1.] * self.nr_dims
|
||||
else:
|
||||
self.upper_bound = upper_bound
|
||||
|
||||
self.seed = seed
|
||||
|
||||
# initial proposal distribution
|
||||
self.proposal_mean = np.zeros((nr_dims, 1))
|
||||
self.proposal_cov = np.diag(np.ones((nr_dims,)) * self.initial_variance)
|
||||
|
||||
# fixed dimension for robot experiment
|
||||
self.fixed_dims = fixed_dims
|
||||
|
||||
def rejection_sampling(self, seed=None):
|
||||
rng = np.random.default_rng(seed)
|
||||
|
||||
samples = np.empty((0, self.nr_dims))
|
||||
while samples.shape[0] < self.nr_samples:
|
||||
# sample from the multi variate gaussian distribution
|
||||
sample = np.zeros((1, self.nr_dims))
|
||||
for i in range(self.nr_dims):
|
||||
if i in self.fixed_dims:
|
||||
sample[0, i] = self.fixed_dims[i]
|
||||
else:
|
||||
check = False
|
||||
while not check:
|
||||
sample[0, i] = rng.normal(self.proposal_mean[i], self.proposal_cov[i, i])
|
||||
if self.lower_bound[i] <= sample[0, i] <= self.upper_bound[i]:
|
||||
check = True
|
||||
|
||||
samples = np.append(samples, sample, axis=0)
|
||||
|
||||
return samples
|
||||
|
||||
def __call__(self, gauss_process, x_observed, seed=None):
|
||||
# if seed is set for whole experiment
|
||||
if self.seed is not None:
|
||||
seed = self.seed
|
||||
|
||||
# get the best so far observed y
|
||||
mu = gauss_process.predict(x_observed)
|
||||
y_best = max(mu)
|
||||
|
||||
# sample from surrogate
|
||||
x_test = self.rejection_sampling(seed)
|
||||
mu, sigma = gauss_process.predict(x_test, return_std=True)
|
||||
|
||||
# expected improvement
|
||||
z = (mu - y_best - self.kappa) / sigma
|
||||
ei = (mu - y_best - self.kappa) * norm.cdf(z) + sigma * norm.pdf(z)
|
||||
|
||||
# get the best result and return it
|
||||
idx = np.argmax(ei)
|
||||
return x_test[idx, :]
|
||||
|
||||
def update_proposal_model(self, preference_mean, preference_bool):
|
||||
cov_diag = np.ones((self.nr_dims,)) * self.initial_variance
|
||||
cov_diag[preference_bool] = self.update_variance
|
||||
|
||||
preference_cov = np.diag(cov_diag)
|
||||
|
||||
preference_mean = preference_mean.reshape(-1, 1)
|
||||
|
||||
posterior_mean = np.linalg.inv(np.linalg.inv(self.proposal_cov) + np.linalg.inv(preference_cov))\
|
||||
.dot(np.linalg.inv(self.proposal_cov).dot(self.proposal_mean)
|
||||
+ np.linalg.inv(preference_cov).dot(preference_mean))
|
||||
|
||||
posterior_cov = np.linalg.inv(np.linalg.inv(self.proposal_cov) + np.linalg.inv(preference_cov))
|
||||
|
||||
self.proposal_mean = posterior_mean
|
||||
self.proposal_cov = posterior_cov
|
@ -0,0 +1,37 @@
|
||||
|
||||
import numpy as np
|
||||
from scipy.stats import norm
|
||||
|
||||
|
||||
class ProbabilityOfImprovement:
|
||||
def __init__(self, nr_weights, nr_samples=100, kappa=0.0, seed=None, lower_bound=-1.0, upper_bound=1.0):
|
||||
self.nr_weights = nr_weights
|
||||
self.nr_samples = nr_samples
|
||||
self.kappa = kappa
|
||||
self.lower_bound = lower_bound
|
||||
self.upper_bound = upper_bound
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, gauss_process, x_observed, seed=None):
|
||||
# if seed is set for whole experiment
|
||||
if self.seed is not None:
|
||||
seed = self.seed
|
||||
|
||||
# random generator
|
||||
rng = np.random.default_rng(seed)
|
||||
|
||||
# get the best so far observed y
|
||||
mu = gauss_process.predict(x_observed)
|
||||
y_best = max(mu)
|
||||
|
||||
# sample from surrogate
|
||||
x_test = rng.uniform(self.lower_bound, self.upper_bound, size=(self.nr_samples, self.nr_weights))
|
||||
mu, sigma = gauss_process.predict(x_test, return_std=True)
|
||||
|
||||
# probability of improvement
|
||||
z = (mu - y_best - self.kappa) / sigma
|
||||
pi = norm.cdf(z)
|
||||
|
||||
# get the best result and return it
|
||||
idx = np.argmax(pi)
|
||||
return x_test[idx, :]
|
@ -0,0 +1 @@
|
||||
|
@ -0,0 +1,137 @@
|
||||
|
||||
import numpy as np
|
||||
from sklearn.gaussian_process import GaussianProcessRegressor
|
||||
from sklearn.gaussian_process.kernels import Matern, RBF, ExpSineSquared
|
||||
|
||||
from ..acquisition_function import ConfidenceBounds
|
||||
from ..acquisition_function import ProbabilityOfImprovement
|
||||
from ..acquisition_function import ExpectedImprovement
|
||||
from ..acquisition_function import PreferenceExpectedImprovement
|
||||
|
||||
from sklearn.exceptions import ConvergenceWarning
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings('ignore', category=ConvergenceWarning)
|
||||
|
||||
|
||||
class BayesianOptimization:
|
||||
def __init__(self, nr_steps, nr_dimensions, nr_policy_parameters, seed=None,
|
||||
fixed_dimensions=None, lower_bound=None, upper_bound=None,
|
||||
acquisition_function_name="EI", kernel_name="Matern",
|
||||
**kwargs):
|
||||
|
||||
self.nr_steps = nr_steps
|
||||
self.nr_dimensions = nr_dimensions
|
||||
self.nr_policy_parameters = nr_policy_parameters
|
||||
self.nr_weights = nr_policy_parameters * nr_dimensions
|
||||
|
||||
if lower_bound is None:
|
||||
self.lower_bound = [-1.] * self.nr_weights
|
||||
else:
|
||||
self.lower_bound = lower_bound
|
||||
|
||||
if upper_bound is None:
|
||||
self.upper_bound = [-1.] * self.nr_weights
|
||||
else:
|
||||
self.upper_bound = upper_bound
|
||||
|
||||
self.seed = seed
|
||||
self.fixed_dimensions = fixed_dimensions
|
||||
|
||||
self.x_observed = None
|
||||
self.y_observed = None
|
||||
self.best_reward = None
|
||||
self.episode = 0
|
||||
|
||||
self.gauss_process = None
|
||||
self.n_restarts_optimizer = kwargs.get('n_restarts_optimizer', 5)
|
||||
|
||||
|
||||
|
||||
# region Kernel
|
||||
length_scale = kwargs.get('length_scale', 1.0)
|
||||
|
||||
if kernel_name == "Matern":
|
||||
nu = kwargs.get('nu', 1.5)
|
||||
self.kernel = Matern(nu=nu, length_scale=length_scale)
|
||||
|
||||
elif kernel_name == "RBF":
|
||||
self.kernel = RBF(length_scale=length_scale)
|
||||
|
||||
elif kernel_name == "ExpSineSquared":
|
||||
periodicity = kwargs.get('periodicity', 1.0)
|
||||
self.kernel = ExpSineSquared(length_scale=length_scale, periodicity=periodicity)
|
||||
|
||||
else:
|
||||
raise NotImplementedError("This kernel is not implemented!")
|
||||
# endregion
|
||||
|
||||
# region Acquisitionfunctions
|
||||
if 'nr_samples' in kwargs:
|
||||
nr_samples = kwargs['nr_samples']
|
||||
else:
|
||||
nr_samples = 100
|
||||
|
||||
if acquisition_function_name == "CB":
|
||||
beta = kwargs.get('beta', 1.2)
|
||||
self.acquisition_function = ConfidenceBounds(self.nr_weights, nr_samples=nr_samples, beta=beta, seed=seed,
|
||||
lower_bound=lower_bound, upper_bound=upper_bound)
|
||||
|
||||
elif acquisition_function_name == "PI":
|
||||
kappa = kwargs.get('kappa', 0.0)
|
||||
self.acquisition_function = ProbabilityOfImprovement(self.nr_weights, nr_samples=nr_samples, kappa=kappa,
|
||||
seed=seed, lower_bound=lower_bound,
|
||||
upper_bound=upper_bound)
|
||||
elif acquisition_function_name == "EI":
|
||||
kappa = kwargs.get('kappa', 0.0)
|
||||
self.acquisition_function = ExpectedImprovement(self.nr_weights, nr_samples=nr_samples, kappa=kappa,
|
||||
seed=seed, lower_bound=lower_bound, upper_bound=upper_bound)
|
||||
elif acquisition_function_name == "PEI":
|
||||
kappa = kwargs.get('kappa', 0.0)
|
||||
|
||||
initial_variance = kwargs.get('initial_variance', None)
|
||||
update_variance = kwargs.get('update_variance', None)
|
||||
|
||||
if initial_variance is None or update_variance is None:
|
||||
raise ValueError("Initial_variance and update_variance has to be provided in PEI!")
|
||||
|
||||
self.acquisition_function = PreferenceExpectedImprovement(self.nr_weights, initial_variance,
|
||||
update_variance, nr_samples=nr_samples,
|
||||
kappa=kappa, lower_bound=lower_bound,
|
||||
upper_bound=upper_bound, seed=seed,
|
||||
fixed_dims=fixed_dimensions)
|
||||
else:
|
||||
raise NotImplementedError("This acquisition function is not implemented!")
|
||||
# endregion
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.gauss_process = GaussianProcessRegressor(self.kernel, n_restarts_optimizer=self.n_restarts_optimizer)
|
||||
self.best_reward = np.empty((1, 1))
|
||||
self.x_observed = np.zeros((1, self.nr_weights), dtype=np.float64)
|
||||
self.y_observed = np.zeros((1, 1), dtype=np.float64)
|
||||
self.episode = 0
|
||||
|
||||
def next_observation(self):
|
||||
x_next = self.acquisition_function(self.gauss_process, self.x_observed, seed=self.seed)
|
||||
return x_next
|
||||
|
||||
def add_observation(self, y_new, x_new):
|
||||
if self.episode == 0:
|
||||
self.x_observed[0, :] = x_new
|
||||
self.y_observed[0] = y_new
|
||||
self.best_reward[0] = np.max(self.y_observed)
|
||||
else:
|
||||
self.x_observed = np.vstack((self.x_observed, np.around(x_new, decimals=8)))
|
||||
self.y_observed = np.vstack((self.y_observed, y_new))
|
||||
self.best_reward = np.vstack((self.best_reward, np.max(self.y_observed)))
|
||||
|
||||
self.gauss_process.fit(self.x_observed, self.y_observed)
|
||||
self.episode += 1
|
||||
|
||||
def get_best_result(self):
|
||||
y_max = np.max(self.y_observed)
|
||||
idx = np.argmax(self.y_observed)
|
||||
x_max = self.x_observed[idx, :]
|
||||
return y_max, x_max, idx
|
@ -1,7 +1,7 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>Optimizers</name>
|
||||
<name>interaction_optimizers</name>
|
||||
<version>0.0.0</version>
|
||||
<description>TODO: Package description</description>
|
||||
<maintainer email="nikolaus.feith@unileoben.ac.at">niko</maintainer>
|
4
src/interaction_optimizers/setup.cfg
Normal file
4
src/interaction_optimizers/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/interaction_optimizers
|
||||
[install]
|
||||
install_scripts=$base/lib/interaction_optimizers
|
@ -1,6 +1,6 @@
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
package_name = 'ObjectiveFunctions'
|
||||
package_name = 'interaction_optimizers'
|
||||
|
||||
setup(
|
||||
name=package_name,
|
@ -1,3 +1,4 @@
|
||||
|
||||
class ImprovementQuery:
|
||||
def __init__(self, threshold, period, last_query, rewards):
|
||||
self.threshold = threshold
|
71
src/interaction_query/interaction_query/query_node.py
Normal file
71
src/interaction_query/interaction_query/query_node.py
Normal file
@ -0,0 +1,71 @@
|
||||
#!/usr/bin/env python3
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
|
||||
from .random_query import RandomQuery
|
||||
from .regular_query import RegularQuery
|
||||
from .improvement_query import ImprovementQuery
|
||||
|
||||
from interaction_msgs.srv import Query
|
||||
|
||||
|
||||
class QueryNode(Node):
|
||||
def __init__(self):
|
||||
super().__init__('query_node')
|
||||
self.query_service = self.create_service(Query, 'user_query', self.query_callback)
|
||||
|
||||
self.get_logger().info('Query node started!')
|
||||
|
||||
def check_random_request(self, req):
|
||||
t = req.threshold
|
||||
if 0 < t <= 1:
|
||||
return True
|
||||
else:
|
||||
self.get_logger().error('Invalid random request in user query!')
|
||||
|
||||
def check_regular_request(self, req):
|
||||
f = req.frequency
|
||||
if f > 0:
|
||||
return True
|
||||
else:
|
||||
self.get_logger().error('Invalid regular request in user query!')
|
||||
|
||||
def check_improvement_request(self, req):
|
||||
t = req.threshold
|
||||
f = req.frequency
|
||||
last_rewards = req.last_rewards
|
||||
if 0 < t <= 1 and f > 0 and isinstance(last_rewards, list):
|
||||
return True
|
||||
else:
|
||||
self.get_logger().error('Invalid improvement request in user query!')
|
||||
|
||||
def query_callback(self, request, response):
|
||||
mode = response.mode
|
||||
query_obj = None
|
||||
if mode == 0:
|
||||
if self.check_random_request(request):
|
||||
query_obj = RandomQuery(request.threshold)
|
||||
elif mode == 1:
|
||||
if self.check_regular_request(request):
|
||||
query_obj = RegularQuery(request.frequency, request.current_episode)
|
||||
elif mode == 2:
|
||||
if self.check_improvement_request(request):
|
||||
query_obj = ImprovementQuery(request.threshold, request.frequency,
|
||||
request.last_queried_episode, request.last_rewards)
|
||||
else:
|
||||
self.get_logger().error('Invalid query mode!')
|
||||
|
||||
if query_obj is not None:
|
||||
response.interaction = query_obj.query()
|
||||
return response
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = QueryNode()
|
||||
rclpy.spin(node)
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
12
src/interaction_query/interaction_query/regular_query.py
Normal file
12
src/interaction_query/interaction_query/regular_query.py
Normal file
@ -0,0 +1,12 @@
|
||||
class RegularQuery:
|
||||
def __init__(self, frequency, episode):
|
||||
self.frequency = int(frequency)
|
||||
self.counter = episode
|
||||
|
||||
def query(self):
|
||||
|
||||
if self.counter % self.frequency == 0 and self.counter != 0:
|
||||
return True
|
||||
|
||||
else:
|
||||
return False
|
@ -1,12 +1,15 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>InteractionQuery</name>
|
||||
<name>interaction_query</name>
|
||||
<version>0.0.0</version>
|
||||
<description>TODO: Package description</description>
|
||||
<maintainer email="root@todo.todo">root</maintainer>
|
||||
<license>TODO: License declaration</license>
|
||||
|
||||
<exec_depend>interaction_msgs</exec_depend>
|
||||
<exec_depend>rclpy</exec_depend>
|
||||
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
4
src/interaction_query/setup.cfg
Normal file
4
src/interaction_query/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/interaction_query
|
||||
[install]
|
||||
install_scripts=$base/lib/interaction_query
|
@ -1,6 +1,6 @@
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
package_name = 'InteractionQuery'
|
||||
package_name = 'interaction_query'
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
@ -20,6 +20,7 @@ setup(
|
||||
tests_require=['pytest'],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'query_n = interaction_query.query_node:main',
|
||||
],
|
||||
},
|
||||
)
|
Loading…
Reference in New Issue
Block a user