diff --git a/src/InteractionMetrics/InteractionMetrics/Improvement.py b/src/InteractionMetrics/InteractionMetrics/Improvement.py new file mode 100644 index 0000000..416a832 --- /dev/null +++ b/src/InteractionMetrics/InteractionMetrics/Improvement.py @@ -0,0 +1,21 @@ +class ImprovementQuery: + def __init__(self, threshold, period, last_query, rewards): + self.threshold = threshold + self.period = period + self.last_query = last_query + self.rewards = rewards + + def query(self): + if self.rewards.shape[0] < self.period + 1: + return False + + elif self.rewards.shape[0] < self.last_query + self.period + 1: + return False + + else: + first = self.rewards[-self.period-1] + last = self.rewards[-1] + + slope = (last - first) / self.period + + return slope < self.threshold diff --git a/src/InteractionMetrics/InteractionMetrics/MaxAcquisition.py b/src/InteractionMetrics/InteractionMetrics/MaxAcquisition.py new file mode 100644 index 0000000..90dbd47 --- /dev/null +++ b/src/InteractionMetrics/InteractionMetrics/MaxAcquisition.py @@ -0,0 +1,55 @@ +import numpy as np +from scipy.stats import norm + + +class MaxAcqQuery: + def __init__(self, threshold, gp, + nr_test, nr_weights, + lower=-1.0, upper=1.0, + acq="Expected Improvement", + **kwargs): + self.threshold = threshold + self.gp = gp + self.nr_test = nr_test + self.nr_weights = nr_weights + self.lower = lower + self.upper = upper + self.acq = acq + + self.seed = kwargs.get('seed', None) + self.kappa = kwargs.get('kappa', 2.576) + self.beta = kwargs.get('beta', 1.2) + self.X = kwargs.get('X', None) + + self.rng = np.random.default_rng(self.seed) + + def query(self): + X_test = self.rng.uniform(self.lower, self.upper, (self.nr_test, self.nr_weights)) + max_acq = 0 + + if self.acq == "Expected Improvement": + if self.X is None: + raise ValueError + y_hat = self.gp.predict(self.X) + best_y = max(y_hat) + mu, sigma = self.gp.predict(X_test, return_std=True) + z = (mu - best_y - self.kappa) / sigma + ei = (mu - best_y - self.kappa) * norm.cdf(z) + sigma * norm.pdf(z) + max_acq = np.max(ei) + + if self.acq == "Probability of Improvement": + if self.X is None: + raise ValueError + y_hat = self.gp.predict(self.X) + best_y = max(y_hat) + mu, sigma = self.gp.predict(X_test, return_std=True) + z = (mu - best_y - self.kappa) / sigma + pi = norm.cdf(z) + max_acq = np.max(pi) + + if self.acq == "Upper Confidence Bound": + mu, sigma = self.gp.predict(X_test, return_std=True) + cb = mu + self.beta * sigma + max_acq = np.max(cb) + + return max_acq > self.threshold diff --git a/src/InteractionMetrics/InteractionMetrics/Random.py b/src/InteractionMetrics/InteractionMetrics/Random.py index e69de29..e53d125 100644 --- a/src/InteractionMetrics/InteractionMetrics/Random.py +++ b/src/InteractionMetrics/InteractionMetrics/Random.py @@ -0,0 +1,11 @@ +import numpy as np + + +class RandomQuery: + def __init__(self, threshold): + self.threshold = threshold + self.random = np.random.uniform(0.0, 1.0, 1) + + def query(self): + return self.random > self.threshold + \ No newline at end of file diff --git a/src/InteractionMetrics/InteractionMetrics/Regular.py b/src/InteractionMetrics/InteractionMetrics/Regular.py new file mode 100644 index 0000000..7349859 --- /dev/null +++ b/src/InteractionMetrics/InteractionMetrics/Regular.py @@ -0,0 +1,12 @@ +class RegularQuery: + def __init__(self, regular, episode): + self.regular = int(regular) + self.counter = episode + + def query(self): + + if self.counter % self.regular == 0 and self.counter != 0: + return True + + else: + return False diff --git a/src/Tasks/Tasks/__init__.py b/src/Tasks/Tasks/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/Tasks/package.xml b/src/Tasks/package.xml deleted file mode 100644 index 2c91b4f..0000000 --- a/src/Tasks/package.xml +++ /dev/null @@ -1,18 +0,0 @@ - - - - Tasks - 0.0.0 - TODO: Package description - niko - TODO: License declaration - - ament_copyright - ament_flake8 - ament_pep257 - python3-pytest - - - ament_python - - diff --git a/src/Tasks/resource/Tasks b/src/Tasks/resource/Tasks deleted file mode 100644 index e69de29..0000000 diff --git a/src/Tasks/setup.cfg b/src/Tasks/setup.cfg deleted file mode 100644 index 8d6e60f..0000000 --- a/src/Tasks/setup.cfg +++ /dev/null @@ -1,4 +0,0 @@ -[develop] -script_dir=$base/lib/Tasks -[install] -install_scripts=$base/lib/Tasks diff --git a/src/Tasks/setup.py b/src/Tasks/setup.py deleted file mode 100644 index d5fc4a1..0000000 --- a/src/Tasks/setup.py +++ /dev/null @@ -1,25 +0,0 @@ -from setuptools import find_packages, setup - -package_name = 'Tasks' - -setup( - name=package_name, - version='0.0.0', - packages=find_packages(exclude=['test']), - data_files=[ - ('share/ament_index/resource_index/packages', - ['resource/' + package_name]), - ('share/' + package_name, ['package.xml']), - ], - install_requires=['setuptools'], - zip_safe=True, - maintainer='niko', - maintainer_email='nikolaus.feith@unileoben.ac.at', - description='TODO: Package description', - license='TODO: License declaration', - tests_require=['pytest'], - entry_points={ - 'console_scripts': [ - ], - }, -) diff --git a/src/Tasks/test/test_copyright.py b/src/Tasks/test/test_copyright.py deleted file mode 100644 index 97a3919..0000000 --- a/src/Tasks/test/test_copyright.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2015 Open Source Robotics Foundation, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ament_copyright.main import main -import pytest - - -# Remove the `skip` decorator once the source file(s) have a copyright header -@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.') -@pytest.mark.copyright -@pytest.mark.linter -def test_copyright(): - rc = main(argv=['.', 'test']) - assert rc == 0, 'Found errors' diff --git a/src/Tasks/test/test_flake8.py b/src/Tasks/test/test_flake8.py deleted file mode 100644 index 27ee107..0000000 --- a/src/Tasks/test/test_flake8.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2017 Open Source Robotics Foundation, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ament_flake8.main import main_with_errors -import pytest - - -@pytest.mark.flake8 -@pytest.mark.linter -def test_flake8(): - rc, errors = main_with_errors(argv=[]) - assert rc == 0, \ - 'Found %d code style errors / warnings:\n' % len(errors) + \ - '\n'.join(errors) diff --git a/src/Tasks/test/test_pep257.py b/src/Tasks/test/test_pep257.py deleted file mode 100644 index b234a38..0000000 --- a/src/Tasks/test/test_pep257.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2015 Open Source Robotics Foundation, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ament_pep257.main import main -import pytest - - -@pytest.mark.linter -@pytest.mark.pep257 -def test_pep257(): - rc = main(argv=['.', 'test']) - assert rc == 0, 'Found code style errors / warnings'