initial
This commit is contained in:
parent
d4ed1065cc
commit
f4380fd8f2
21
src/InteractionMetrics/InteractionMetrics/Improvement.py
Normal file
21
src/InteractionMetrics/InteractionMetrics/Improvement.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
class ImprovementQuery:
|
||||||
|
def __init__(self, threshold, period, last_query, rewards):
|
||||||
|
self.threshold = threshold
|
||||||
|
self.period = period
|
||||||
|
self.last_query = last_query
|
||||||
|
self.rewards = rewards
|
||||||
|
|
||||||
|
def query(self):
|
||||||
|
if self.rewards.shape[0] < self.period + 1:
|
||||||
|
return False
|
||||||
|
|
||||||
|
elif self.rewards.shape[0] < self.last_query + self.period + 1:
|
||||||
|
return False
|
||||||
|
|
||||||
|
else:
|
||||||
|
first = self.rewards[-self.period-1]
|
||||||
|
last = self.rewards[-1]
|
||||||
|
|
||||||
|
slope = (last - first) / self.period
|
||||||
|
|
||||||
|
return slope < self.threshold
|
55
src/InteractionMetrics/InteractionMetrics/MaxAcquisition.py
Normal file
55
src/InteractionMetrics/InteractionMetrics/MaxAcquisition.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import numpy as np
|
||||||
|
from scipy.stats import norm
|
||||||
|
|
||||||
|
|
||||||
|
class MaxAcqQuery:
|
||||||
|
def __init__(self, threshold, gp,
|
||||||
|
nr_test, nr_weights,
|
||||||
|
lower=-1.0, upper=1.0,
|
||||||
|
acq="Expected Improvement",
|
||||||
|
**kwargs):
|
||||||
|
self.threshold = threshold
|
||||||
|
self.gp = gp
|
||||||
|
self.nr_test = nr_test
|
||||||
|
self.nr_weights = nr_weights
|
||||||
|
self.lower = lower
|
||||||
|
self.upper = upper
|
||||||
|
self.acq = acq
|
||||||
|
|
||||||
|
self.seed = kwargs.get('seed', None)
|
||||||
|
self.kappa = kwargs.get('kappa', 2.576)
|
||||||
|
self.beta = kwargs.get('beta', 1.2)
|
||||||
|
self.X = kwargs.get('X', None)
|
||||||
|
|
||||||
|
self.rng = np.random.default_rng(self.seed)
|
||||||
|
|
||||||
|
def query(self):
|
||||||
|
X_test = self.rng.uniform(self.lower, self.upper, (self.nr_test, self.nr_weights))
|
||||||
|
max_acq = 0
|
||||||
|
|
||||||
|
if self.acq == "Expected Improvement":
|
||||||
|
if self.X is None:
|
||||||
|
raise ValueError
|
||||||
|
y_hat = self.gp.predict(self.X)
|
||||||
|
best_y = max(y_hat)
|
||||||
|
mu, sigma = self.gp.predict(X_test, return_std=True)
|
||||||
|
z = (mu - best_y - self.kappa) / sigma
|
||||||
|
ei = (mu - best_y - self.kappa) * norm.cdf(z) + sigma * norm.pdf(z)
|
||||||
|
max_acq = np.max(ei)
|
||||||
|
|
||||||
|
if self.acq == "Probability of Improvement":
|
||||||
|
if self.X is None:
|
||||||
|
raise ValueError
|
||||||
|
y_hat = self.gp.predict(self.X)
|
||||||
|
best_y = max(y_hat)
|
||||||
|
mu, sigma = self.gp.predict(X_test, return_std=True)
|
||||||
|
z = (mu - best_y - self.kappa) / sigma
|
||||||
|
pi = norm.cdf(z)
|
||||||
|
max_acq = np.max(pi)
|
||||||
|
|
||||||
|
if self.acq == "Upper Confidence Bound":
|
||||||
|
mu, sigma = self.gp.predict(X_test, return_std=True)
|
||||||
|
cb = mu + self.beta * sigma
|
||||||
|
max_acq = np.max(cb)
|
||||||
|
|
||||||
|
return max_acq > self.threshold
|
@ -0,0 +1,11 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class RandomQuery:
|
||||||
|
def __init__(self, threshold):
|
||||||
|
self.threshold = threshold
|
||||||
|
self.random = np.random.uniform(0.0, 1.0, 1)
|
||||||
|
|
||||||
|
def query(self):
|
||||||
|
return self.random > self.threshold
|
||||||
|
|
12
src/InteractionMetrics/InteractionMetrics/Regular.py
Normal file
12
src/InteractionMetrics/InteractionMetrics/Regular.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
class RegularQuery:
|
||||||
|
def __init__(self, regular, episode):
|
||||||
|
self.regular = int(regular)
|
||||||
|
self.counter = episode
|
||||||
|
|
||||||
|
def query(self):
|
||||||
|
|
||||||
|
if self.counter % self.regular == 0 and self.counter != 0:
|
||||||
|
return True
|
||||||
|
|
||||||
|
else:
|
||||||
|
return False
|
@ -1,18 +0,0 @@
|
|||||||
<?xml version="1.0"?>
|
|
||||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
|
||||||
<package format="3">
|
|
||||||
<name>Tasks</name>
|
|
||||||
<version>0.0.0</version>
|
|
||||||
<description>TODO: Package description</description>
|
|
||||||
<maintainer email="nikolaus.feith@unileoben.ac.at">niko</maintainer>
|
|
||||||
<license>TODO: License declaration</license>
|
|
||||||
|
|
||||||
<test_depend>ament_copyright</test_depend>
|
|
||||||
<test_depend>ament_flake8</test_depend>
|
|
||||||
<test_depend>ament_pep257</test_depend>
|
|
||||||
<test_depend>python3-pytest</test_depend>
|
|
||||||
|
|
||||||
<export>
|
|
||||||
<build_type>ament_python</build_type>
|
|
||||||
</export>
|
|
||||||
</package>
|
|
@ -1,4 +0,0 @@
|
|||||||
[develop]
|
|
||||||
script_dir=$base/lib/Tasks
|
|
||||||
[install]
|
|
||||||
install_scripts=$base/lib/Tasks
|
|
@ -1,25 +0,0 @@
|
|||||||
from setuptools import find_packages, setup
|
|
||||||
|
|
||||||
package_name = 'Tasks'
|
|
||||||
|
|
||||||
setup(
|
|
||||||
name=package_name,
|
|
||||||
version='0.0.0',
|
|
||||||
packages=find_packages(exclude=['test']),
|
|
||||||
data_files=[
|
|
||||||
('share/ament_index/resource_index/packages',
|
|
||||||
['resource/' + package_name]),
|
|
||||||
('share/' + package_name, ['package.xml']),
|
|
||||||
],
|
|
||||||
install_requires=['setuptools'],
|
|
||||||
zip_safe=True,
|
|
||||||
maintainer='niko',
|
|
||||||
maintainer_email='nikolaus.feith@unileoben.ac.at',
|
|
||||||
description='TODO: Package description',
|
|
||||||
license='TODO: License declaration',
|
|
||||||
tests_require=['pytest'],
|
|
||||||
entry_points={
|
|
||||||
'console_scripts': [
|
|
||||||
],
|
|
||||||
},
|
|
||||||
)
|
|
@ -1,25 +0,0 @@
|
|||||||
# Copyright 2015 Open Source Robotics Foundation, Inc.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from ament_copyright.main import main
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
# Remove the `skip` decorator once the source file(s) have a copyright header
|
|
||||||
@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.')
|
|
||||||
@pytest.mark.copyright
|
|
||||||
@pytest.mark.linter
|
|
||||||
def test_copyright():
|
|
||||||
rc = main(argv=['.', 'test'])
|
|
||||||
assert rc == 0, 'Found errors'
|
|
@ -1,25 +0,0 @@
|
|||||||
# Copyright 2017 Open Source Robotics Foundation, Inc.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from ament_flake8.main import main_with_errors
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.flake8
|
|
||||||
@pytest.mark.linter
|
|
||||||
def test_flake8():
|
|
||||||
rc, errors = main_with_errors(argv=[])
|
|
||||||
assert rc == 0, \
|
|
||||||
'Found %d code style errors / warnings:\n' % len(errors) + \
|
|
||||||
'\n'.join(errors)
|
|
@ -1,23 +0,0 @@
|
|||||||
# Copyright 2015 Open Source Robotics Foundation, Inc.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from ament_pep257.main import main
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linter
|
|
||||||
@pytest.mark.pep257
|
|
||||||
def test_pep257():
|
|
||||||
rc = main(argv=['.', 'test'])
|
|
||||||
assert rc == 0, 'Found code style errors / warnings'
|
|
Loading…
Reference in New Issue
Block a user