initial
This commit is contained in:
parent
d4ed1065cc
commit
f4380fd8f2
21
src/InteractionMetrics/InteractionMetrics/Improvement.py
Normal file
21
src/InteractionMetrics/InteractionMetrics/Improvement.py
Normal file
@ -0,0 +1,21 @@
|
||||
class ImprovementQuery:
|
||||
def __init__(self, threshold, period, last_query, rewards):
|
||||
self.threshold = threshold
|
||||
self.period = period
|
||||
self.last_query = last_query
|
||||
self.rewards = rewards
|
||||
|
||||
def query(self):
|
||||
if self.rewards.shape[0] < self.period + 1:
|
||||
return False
|
||||
|
||||
elif self.rewards.shape[0] < self.last_query + self.period + 1:
|
||||
return False
|
||||
|
||||
else:
|
||||
first = self.rewards[-self.period-1]
|
||||
last = self.rewards[-1]
|
||||
|
||||
slope = (last - first) / self.period
|
||||
|
||||
return slope < self.threshold
|
55
src/InteractionMetrics/InteractionMetrics/MaxAcquisition.py
Normal file
55
src/InteractionMetrics/InteractionMetrics/MaxAcquisition.py
Normal file
@ -0,0 +1,55 @@
|
||||
import numpy as np
|
||||
from scipy.stats import norm
|
||||
|
||||
|
||||
class MaxAcqQuery:
|
||||
def __init__(self, threshold, gp,
|
||||
nr_test, nr_weights,
|
||||
lower=-1.0, upper=1.0,
|
||||
acq="Expected Improvement",
|
||||
**kwargs):
|
||||
self.threshold = threshold
|
||||
self.gp = gp
|
||||
self.nr_test = nr_test
|
||||
self.nr_weights = nr_weights
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
self.acq = acq
|
||||
|
||||
self.seed = kwargs.get('seed', None)
|
||||
self.kappa = kwargs.get('kappa', 2.576)
|
||||
self.beta = kwargs.get('beta', 1.2)
|
||||
self.X = kwargs.get('X', None)
|
||||
|
||||
self.rng = np.random.default_rng(self.seed)
|
||||
|
||||
def query(self):
|
||||
X_test = self.rng.uniform(self.lower, self.upper, (self.nr_test, self.nr_weights))
|
||||
max_acq = 0
|
||||
|
||||
if self.acq == "Expected Improvement":
|
||||
if self.X is None:
|
||||
raise ValueError
|
||||
y_hat = self.gp.predict(self.X)
|
||||
best_y = max(y_hat)
|
||||
mu, sigma = self.gp.predict(X_test, return_std=True)
|
||||
z = (mu - best_y - self.kappa) / sigma
|
||||
ei = (mu - best_y - self.kappa) * norm.cdf(z) + sigma * norm.pdf(z)
|
||||
max_acq = np.max(ei)
|
||||
|
||||
if self.acq == "Probability of Improvement":
|
||||
if self.X is None:
|
||||
raise ValueError
|
||||
y_hat = self.gp.predict(self.X)
|
||||
best_y = max(y_hat)
|
||||
mu, sigma = self.gp.predict(X_test, return_std=True)
|
||||
z = (mu - best_y - self.kappa) / sigma
|
||||
pi = norm.cdf(z)
|
||||
max_acq = np.max(pi)
|
||||
|
||||
if self.acq == "Upper Confidence Bound":
|
||||
mu, sigma = self.gp.predict(X_test, return_std=True)
|
||||
cb = mu + self.beta * sigma
|
||||
max_acq = np.max(cb)
|
||||
|
||||
return max_acq > self.threshold
|
@ -0,0 +1,11 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class RandomQuery:
|
||||
def __init__(self, threshold):
|
||||
self.threshold = threshold
|
||||
self.random = np.random.uniform(0.0, 1.0, 1)
|
||||
|
||||
def query(self):
|
||||
return self.random > self.threshold
|
||||
|
12
src/InteractionMetrics/InteractionMetrics/Regular.py
Normal file
12
src/InteractionMetrics/InteractionMetrics/Regular.py
Normal file
@ -0,0 +1,12 @@
|
||||
class RegularQuery:
|
||||
def __init__(self, regular, episode):
|
||||
self.regular = int(regular)
|
||||
self.counter = episode
|
||||
|
||||
def query(self):
|
||||
|
||||
if self.counter % self.regular == 0 and self.counter != 0:
|
||||
return True
|
||||
|
||||
else:
|
||||
return False
|
@ -1,18 +0,0 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>Tasks</name>
|
||||
<version>0.0.0</version>
|
||||
<description>TODO: Package description</description>
|
||||
<maintainer email="nikolaus.feith@unileoben.ac.at">niko</maintainer>
|
||||
<license>TODO: License declaration</license>
|
||||
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
@ -1,4 +0,0 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/Tasks
|
||||
[install]
|
||||
install_scripts=$base/lib/Tasks
|
@ -1,25 +0,0 @@
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
package_name = 'Tasks'
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version='0.0.0',
|
||||
packages=find_packages(exclude=['test']),
|
||||
data_files=[
|
||||
('share/ament_index/resource_index/packages',
|
||||
['resource/' + package_name]),
|
||||
('share/' + package_name, ['package.xml']),
|
||||
],
|
||||
install_requires=['setuptools'],
|
||||
zip_safe=True,
|
||||
maintainer='niko',
|
||||
maintainer_email='nikolaus.feith@unileoben.ac.at',
|
||||
description='TODO: Package description',
|
||||
license='TODO: License declaration',
|
||||
tests_require=['pytest'],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
],
|
||||
},
|
||||
)
|
@ -1,25 +0,0 @@
|
||||
# Copyright 2015 Open Source Robotics Foundation, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ament_copyright.main import main
|
||||
import pytest
|
||||
|
||||
|
||||
# Remove the `skip` decorator once the source file(s) have a copyright header
|
||||
@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.')
|
||||
@pytest.mark.copyright
|
||||
@pytest.mark.linter
|
||||
def test_copyright():
|
||||
rc = main(argv=['.', 'test'])
|
||||
assert rc == 0, 'Found errors'
|
@ -1,25 +0,0 @@
|
||||
# Copyright 2017 Open Source Robotics Foundation, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ament_flake8.main import main_with_errors
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.flake8
|
||||
@pytest.mark.linter
|
||||
def test_flake8():
|
||||
rc, errors = main_with_errors(argv=[])
|
||||
assert rc == 0, \
|
||||
'Found %d code style errors / warnings:\n' % len(errors) + \
|
||||
'\n'.join(errors)
|
@ -1,23 +0,0 @@
|
||||
# Copyright 2015 Open Source Robotics Foundation, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ament_pep257.main import main
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.linter
|
||||
@pytest.mark.pep257
|
||||
def test_pep257():
|
||||
rc = main(argv=['.', 'test'])
|
||||
assert rc == 0, 'Found code style errors / warnings'
|
Loading…
Reference in New Issue
Block a user