From 52fb25844bb5136d11d181b919c1b7d6e6d31750 Mon Sep 17 00:00:00 2001 From: "nikolaus.feith" Date: Fri, 24 Feb 2023 16:32:55 +0100 Subject: [PATCH] Policy Service complete --- src/active_bo_msgs/CMakeLists.txt | 14 ++--- src/active_bo_msgs/package.xml | 6 +++ src/active_bo_msgs/srv/WeightToPolicy.srv | 4 ++ .../AcquisitionFunctions/ConfidenceBound.py | 13 +++++ .../ExpectedImprovement.py | 15 ++++++ .../ProbabilityOfImprovement.py | 15 ++++++ .../PolicyModel/GaussianRBFModel.py | 53 +++++++++++++++++++ .../active_bo_ros/PolicyModel/__init__.py | 0 .../active_bo_ros/policy_service.py | 37 +++++++++++++ .../launch/policy_service.launch.py | 11 ++++ src/active_bo_ros/package.xml | 4 ++ src/active_bo_ros/setup.py | 8 ++- 12 files changed, 169 insertions(+), 11 deletions(-) create mode 100644 src/active_bo_msgs/srv/WeightToPolicy.srv create mode 100644 src/active_bo_ros/active_bo_ros/AcquisitionFunctions/ConfidenceBound.py create mode 100644 src/active_bo_ros/active_bo_ros/AcquisitionFunctions/ExpectedImprovement.py create mode 100644 src/active_bo_ros/active_bo_ros/AcquisitionFunctions/ProbabilityOfImprovement.py create mode 100644 src/active_bo_ros/active_bo_ros/PolicyModel/GaussianRBFModel.py create mode 100644 src/active_bo_ros/active_bo_ros/PolicyModel/__init__.py create mode 100644 src/active_bo_ros/active_bo_ros/policy_service.py create mode 100755 src/active_bo_ros/launch/policy_service.launch.py diff --git a/src/active_bo_msgs/CMakeLists.txt b/src/active_bo_msgs/CMakeLists.txt index 7e1f655..b100134 100644 --- a/src/active_bo_msgs/CMakeLists.txt +++ b/src/active_bo_msgs/CMakeLists.txt @@ -17,18 +17,14 @@ endif() # find dependencies find_package(ament_cmake REQUIRED) -# uncomment the following section in order to fill in -# further dependencies manually. 
-# find_package( REQUIRED) +find_package(rosidl_default_generators REQUIRED) + +rosidl_generate_interfaces(${PROJECT_NAME} + "srv/WeightToPolicy.srv" +) if(BUILD_TESTING) find_package(ament_lint_auto REQUIRED) - # the following line skips the linter which checks for copyrights - # uncomment the line when a copyright and license is not present in all source files - #set(ament_cmake_copyright_FOUND TRUE) - # the following line skips cpplint (only works in a git repo) - # uncomment the line when this package is not in a git repo - #set(ament_cmake_cpplint_FOUND TRUE) ament_lint_auto_find_test_dependencies() endif() diff --git a/src/active_bo_msgs/package.xml b/src/active_bo_msgs/package.xml index dc65d69..fd8e609 100644 --- a/src/active_bo_msgs/package.xml +++ b/src/active_bo_msgs/package.xml @@ -9,6 +9,12 @@ ament_cmake + rosidl_default_generators + + rosidl_default_runtime + + rosidl_interface_packages + ament_lint_auto ament_lint_common diff --git a/src/active_bo_msgs/srv/WeightToPolicy.srv b/src/active_bo_msgs/srv/WeightToPolicy.srv new file mode 100644 index 0000000..f7f3316 --- /dev/null +++ b/src/active_bo_msgs/srv/WeightToPolicy.srv @@ -0,0 +1,4 @@ +float32[] weights +uint16 nr_steps +--- +float32[] policy \ No newline at end of file diff --git a/src/active_bo_ros/active_bo_ros/AcquisitionFunctions/ConfidenceBound.py b/src/active_bo_ros/active_bo_ros/AcquisitionFunctions/ConfidenceBound.py new file mode 100644 index 0000000..48fd5b2 --- /dev/null +++ b/src/active_bo_ros/active_bo_ros/AcquisitionFunctions/ConfidenceBound.py @@ -0,0 +1,13 @@ +import numpy as np + +def ConfidenceBound(gp, X, nr_test, nr_weights, lam=1.2, seed=None, lower=-1.0, upper=1.0): + y_hat = gp.predict(X) + best_y = max(y_hat) + rng = np.random.default_rng(seed=seed) + X_test = rng.uniform(lower, upper, (nr_test, nr_weights)) + mu, sigma = gp.predict(X_test, return_std=True) + cb = mu + lam * sigma + + idx = np.argmax(cb) + X_next = X_test[idx, :] + return X_next diff --git 
a/src/active_bo_ros/active_bo_ros/AcquisitionFunctions/ExpectedImprovement.py b/src/active_bo_ros/active_bo_ros/AcquisitionFunctions/ExpectedImprovement.py new file mode 100644 index 0000000..6b17a06 --- /dev/null +++ b/src/active_bo_ros/active_bo_ros/AcquisitionFunctions/ExpectedImprovement.py @@ -0,0 +1,15 @@ +import numpy as np +from scipy.stats import norm + +def ExpectedImprovement(gp, X, nr_test, nr_weights, kappa=2.576, seed=None, lower=-1.0, upper=1.0): + y_hat = gp.predict(X) + best_y = max(y_hat) + rng = np.random.default_rng(seed=seed) + X_test = rng.uniform(lower, upper, (nr_test, nr_weights)) + mu, sigma = gp.predict(X_test, return_std=True) + z = (mu - best_y - kappa) / sigma + ei = (mu - best_y - kappa) * norm.cdf(z) + sigma * norm.pdf(z) + + idx = np.argmax(ei) + X_next = X_test[idx, :] + return X_next diff --git a/src/active_bo_ros/active_bo_ros/AcquisitionFunctions/ProbabilityOfImprovement.py b/src/active_bo_ros/active_bo_ros/AcquisitionFunctions/ProbabilityOfImprovement.py new file mode 100644 index 0000000..3253438 --- /dev/null +++ b/src/active_bo_ros/active_bo_ros/AcquisitionFunctions/ProbabilityOfImprovement.py @@ -0,0 +1,15 @@ +import numpy as np +from scipy.stats import norm + +def ProbabilityOfImprovement(gp, X, nr_test, nr_weights, kappa=2.576, seed=None, lower=-1.0, upper=1.0): + y_hat = gp.predict(X) + best_y = max(y_hat) + rng = np.random.default_rng(seed=seed) + X_test = rng.uniform(lower, upper, (nr_test, nr_weights)) + mu, sigma = gp.predict(X_test, return_std=True) + z = (mu - best_y - kappa) / sigma + pi = norm.cdf(z) + + idx = np.argmax(pi) + X_next = X_test[idx, :] + return X_next diff --git a/src/active_bo_ros/active_bo_ros/PolicyModel/GaussianRBFModel.py b/src/active_bo_ros/active_bo_ros/PolicyModel/GaussianRBFModel.py new file mode 100644 index 0000000..ed791ef --- /dev/null +++ b/src/active_bo_ros/active_bo_ros/PolicyModel/GaussianRBFModel.py @@ -0,0 +1,53 @@ +import numpy as np +class GaussianRBF: + def __init__(self, 
nr_weights, nr_steps, seed=None, lowerb=-1.0, upperb=1.0): + self.nr_weights = nr_weights + self.nr_steps = nr_steps + self.weights = None + self.policy = None + self.mean = np.linspace(0, self.nr_steps, self.nr_weights) + if nr_weights > 1: + self.std = self.mean[1] / (2 * np.sqrt(2 * np.log(2))) # Full width at half maximum + else: + self.std = self.nr_steps / 2 + + self.rng = np.random.default_rng(seed=seed) + self.low = lowerb + self.upper = upperb + + self.reset() + + def reset(self): + self.weights = np.zeros((self.nr_weights, 1)) + self.policy = np.zeros((self.nr_steps, 1)) + + def random_policy(self): + self.weights = self.rng.uniform(self.low, self.upper, self.nr_weights) + + def policy_rollout(self): + self.policy = np.zeros((self.nr_steps, 1)) + for i in range(self.nr_steps): + for j in range(self.nr_weights): + base_fun = np.exp(-0.5*(i - self.mean[j])**2 / self.std**2) + self.policy[i] += base_fun * self.weights[j] + + return self.policy + + # def plot_policy(self, finished=np.NAN): + # x = np.linspace(0, self.nr_steps, self.nr_steps) + # plt.plot(x, self.trajectory) + # if finished != np.NAN: + # plt.vlines(finished, -1, 1, colors='red') + # # for i in self.mean: + # # gaussian = np.exp(-0.5 * (x - i)**2 / self.std**2) + # # plt.plot(x, gaussian) + + +def main(): + policy = GaussianRBF(1, 50) + policy.random_policy() + policy.policy_rollout() + print(policy.weights) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/active_bo_ros/active_bo_ros/PolicyModel/__init__.py b/src/active_bo_ros/active_bo_ros/PolicyModel/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/active_bo_ros/active_bo_ros/policy_service.py b/src/active_bo_ros/active_bo_ros/policy_service.py new file mode 100644 index 0000000..73f0de6 --- /dev/null +++ b/src/active_bo_ros/active_bo_ros/policy_service.py @@ -0,0 +1,37 @@ +from active_bo_msgs.srv import WeightToPolicy + +import rclpy +from rclpy.node import Node + +from
active_bo_ros.PolicyModel.GaussianRBFModel import GaussianRBF +import numpy as np + +class PolicyService(Node): + def __init__(self): + super().__init__('policy_service') + self.srv = self.create_service(WeightToPolicy, 'policy_srv', self.policy_callback) + + @staticmethod + def policy_callback(request, response): + weights = request.weights + weight_len = len(weights) + nr_steps = request.nr_steps + + policy = GaussianRBF(weight_len, nr_steps) + policy.weights = weights + policy.policy_rollout() + + response.policy = policy.policy.flatten().tolist() + return response + +def main(args=None): + rclpy.init(args=args) + + policy_service = PolicyService() + + rclpy.spin(policy_service) + + rclpy.shutdown() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/active_bo_ros/launch/policy_service.launch.py b/src/active_bo_ros/launch/policy_service.launch.py new file mode 100755 index 0000000..2004daf --- /dev/null +++ b/src/active_bo_ros/launch/policy_service.launch.py @@ -0,0 +1,11 @@ +from launch import LaunchDescription +from launch_ros.actions import Node + +def generate_launch_description(): + return LaunchDescription([ + Node( + package='active_bo_ros', + executable='policy_srv', + name='policy_srv' + ), + ]) \ No newline at end of file diff --git a/src/active_bo_ros/package.xml b/src/active_bo_ros/package.xml index 960d325..2cbe236 100644 --- a/src/active_bo_ros/package.xml +++ b/src/active_bo_ros/package.xml @@ -7,6 +7,10 @@ cpsfeith TODO: License declaration + example_interfaces + active_bo_msgs + rclpy + ament_copyright ament_flake8 ament_pep257 diff --git a/src/active_bo_ros/setup.py b/src/active_bo_ros/setup.py index da75615..8154794 100644 --- a/src/active_bo_ros/setup.py +++ b/src/active_bo_ros/setup.py @@ -1,17 +1,20 @@ from setuptools import setup +import os +from glob import glob package_name = 'active_bo_ros' setup( name=package_name, version='0.0.0', - packages=[package_name], + packages=[package_name, 
package_name+'/PolicyModel'], data_files=[ ('share/ament_index/resource_index/packages', ['resource/' + package_name]), ('share/' + package_name, ['package.xml']), + (os.path.join('share', package_name), glob('launch/*.launch.py')), ], - install_requires=['setuptools'], + install_requires=['setuptools', 'numpy'], zip_safe=True, maintainer='cpsfeith', maintainer_email='nikolaus.feith@unileoben.ac.at', @@ -20,6 +23,7 @@ setup( tests_require=['pytest'], entry_points={ 'console_scripts': [ + 'policy_srv = active_bo_ros.policy_service:main' ], }, )