diff --git a/src/active_bo_debugging/active_bo_debugging/__init__.py b/src/active_bo_debugging/active_bo_debugging/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/active_bo_debugging/active_bo_debugging/active_bo_proxy.py b/src/active_bo_debugging/active_bo_debugging/active_bo_proxy.py new file mode 100644 index 0000000..d7d2eb3 --- /dev/null +++ b/src/active_bo_debugging/active_bo_debugging/active_bo_proxy.py @@ -0,0 +1,84 @@ +import rclpy +from rclpy.node import Node +from active_bo_msgs.srv import ActiveBO + +import time + + +class ActiveBOProxy(Node): + def __init__(self): + super().__init__('active_rl_proxy') + + self.client = self.create_client(ActiveBO, 'abo_proxy') + + while not self.client.wait_for_service(timeout_sec=10.0): + self.get_logger().info('Waiting for the active_bo_srv') + + self.server = self.create_service(ActiveBO, 'active_bo_srv', self.proxy_callback) + + def proxy_callback(self, request, response): + + self.get_logger().info(f'Received request:') + self.get_logger().info(f' nr_weights: {request.nr_weights}') + self.get_logger().info(f' max_steps: {request.max_steps}') + self.get_logger().info(f' nr_episodes: {request.nr_episodes}') + self.get_logger().info(f' nr_runs: {request.nr_runs}') + self.get_logger().info(f' acquisition_function: {request.acquisition_function}') + self.get_logger().info(f' epsilon: {request.epsilon}') + + # Forward the request to the original service server + future = self.client.call_async(request) + + timeout = 20 + start_time = time.time() + while not future.done(): + rclpy.spin_once(self, timeout_sec=0.1) + # self.get_logger().info(f'{future.result()}') + if time.time() - start_time > timeout: + self.get_logger().error('Service call timed out.') + break + + # self.executor.spin_until_future_complete(future) + + if future.result() is not None: + response.new_weights = future.result().new_weights + response.final_step = future.result().final_step + response.reward = future.result().reward + + self.get_logger().info(f'Sending response:') + self.get_logger().info(f' best_policy: {response.best_policy}') + self.get_logger().info(f' best_weights: {response.best_weights}') + self.get_logger().info(f' reward_mean: {response.reward_mean}') + self.get_logger().info(f' reward_std: {response.reward_std}') + # result = self.client.call(request) + # if result is not None: + # response.new_weights = result.new_weights + # response.final_step = result.final_step + # response.reward = result.reward + # + # self.get_logger().info(f'Sending response:') + # self.get_logger().info(f' best_policy: {response.best_policy}') + # self.get_logger().info(f' best_weights: {response.best_weights}') + # self.get_logger().info(f' reward_mean: {response.reward_mean}') + # self.get_logger().info(f' reward_std: {response.reward_std}') + else: + self.get_logger().error('Failed to call the original service') + + return response + + +def main(args=None): + rclpy.init(args=args) + service_proxy = ActiveBOProxy() + + try: + rclpy.spin(service_proxy) + except KeyboardInterrupt: + pass + + service_proxy.destroy_node() + rclpy.shutdown() + + +if __name__ == '__main__': + main() diff --git a/src/active_bo_debugging/active_bo_debugging/active_rl_proxy.py b/src/active_bo_debugging/active_bo_debugging/active_rl_proxy.py new file mode 100644 index 0000000..0a3dbba --- /dev/null +++ b/src/active_bo_debugging/active_bo_debugging/active_rl_proxy.py @@ -0,0 +1,77 @@ +import rclpy +from rclpy.node import Node +from active_bo_msgs.srv import ActiveRL + +import time + + +class ActiveRLProxy(Node): + def __init__(self): + super().__init__('active_rl_proxy') + + self.client = self.create_client(ActiveRL, 'arl_proxy') + + while not self.client.wait_for_service(timeout_sec=10.0): + self.get_logger().info('Waiting for the active_rl_srv') + + self.server = self.create_service(ActiveRL, 'active_rl_srv', self.proxy_callback) + + def proxy_callback(self, request, response): + + self.get_logger().info(f'Received request:') + self.get_logger().info(f' old_policy: {request.old_policy}') + self.get_logger().info(f' old_weights: {request.old_weights}') + + # Forward the request to the original service server + future = self.client.call_async(request) + + timeout = 10 + start_time = time.time() + while not future.done(): + rclpy.spin_once(self, timeout_sec=0.1) + self.get_logger().info(f'{future.result()}') + if time.time() - start_time > timeout: + self.get_logger().error('Service call timed out.') + break + # self.executor.spin_until_future_complete(future) + + if future.result() is not None: + response.new_weights = future.result().new_weights + response.final_step = future.result().final_step + response.reward = future.result().reward + + self.get_logger().info(f'Sending response:') + self.get_logger().info(f' new_weigths: {response.new_weights}') + self.get_logger().info(f' final_step: {response.final_step}') + self.get_logger().info(f' reward: {response.reward}') + # result = self.client.call(request) + # if result is not None: + # response.new_weights = result.new_weights + # response.final_step = result.final_step + # response.reward = result.reward + # + # self.get_logger().info(f'Sending response:') + # self.get_logger().info(f' new_weigths: {response.new_weights}') + # self.get_logger().info(f' final_step: {response.final_step}') + # self.get_logger().info(f' reward: {response.reward}') + else: + self.get_logger().error('Failed to call the original service') + + return response + + +def main(args=None): + rclpy.init(args=args) + service_proxy = ActiveRLProxy() + + try: + rclpy.spin(service_proxy) + except KeyboardInterrupt: + pass + + service_proxy.destroy_node() + rclpy.shutdown() + + +if __name__ == '__main__': + main() diff --git a/src/active_bo_ros/active_bo_ros/active_rl_test_node.py b/src/active_bo_debugging/active_bo_debugging/active_rl_test_node.py similarity index 100% rename from src/active_bo_ros/active_bo_ros/active_rl_test_node.py rename to src/active_bo_debugging/active_bo_debugging/active_rl_test_node.py diff --git a/src/active_bo_debugging/launch/debug_bo.launch.py b/src/active_bo_debugging/launch/debug_bo.launch.py new file mode 100755 index 0000000..729dc99 --- /dev/null +++ b/src/active_bo_debugging/launch/debug_bo.launch.py @@ -0,0 +1,57 @@ +from launch import LaunchDescription +from launch.actions import IncludeLaunchDescription +from launch.launch_description_sources import PythonLaunchDescriptionSource +from launch_ros.actions import Node + +from ament_index_python import get_package_share_directory +import os + + +def generate_launch_description(): + websocket_launch = IncludeLaunchDescription( + PythonLaunchDescriptionSource( + os.path.join( + get_package_share_directory('active_bo_ros'), + 'rosbridge_server.launch.py' + ) + ) + ) + rl_launch = IncludeLaunchDescription( + PythonLaunchDescriptionSource( + os.path.join( + get_package_share_directory('active_bo_ros'), + 'rl_service.launch.py' + ) + ) + ) + bo_launch = IncludeLaunchDescription( + PythonLaunchDescriptionSource( + os.path.join( + get_package_share_directory('active_bo_ros'), + 'bo_service.launch.py' + ) + ) + ) + arl_launch = Node( + package='active_bo_ros', + executable='active_rl_srv', + name='active_rl_service', + remappings=[ + ('active_rl_srv', 'arl_proxy') + ], + ) + abo_launch = Node( + package='active_bo_ros', + executable='active_bo_srv', + name='active_bo_service', + remappings=[ + ('active_bo_srv', 'abo_proxy') + ], + ) + return LaunchDescription([ + websocket_launch, + rl_launch, + bo_launch, + arl_launch, + abo_launch + ]) diff --git a/src/active_bo_debugging/launch/debug_proxy.launch.py b/src/active_bo_debugging/launch/debug_proxy.launch.py new file mode 100755 index 0000000..2d1ac22 --- /dev/null +++ b/src/active_bo_debugging/launch/debug_proxy.launch.py @@ -0,0 +1,17 @@ +from launch import LaunchDescription +from launch_ros.actions import Node + + +def generate_launch_description(): + return LaunchDescription([ + Node( + package='active_bo_debugging', + executable='active_rl_proxy', + name='active_rl_proxy' + ), + Node( + package='active_bo_debugging', + executable='active_bo_proxy', + name='active_bo_proxy' + ), + ]) diff --git a/src/active_bo_debugging/package.xml b/src/active_bo_debugging/package.xml new file mode 100644 index 0000000..24f4b80 --- /dev/null +++ b/src/active_bo_debugging/package.xml @@ -0,0 +1,22 @@ + + + + active_bo_debugging + 0.0.0 + TODO: Package description + niko + TODO: License declaration + + example_interfaces + active_bo_msgs + rclpy + + ament_copyright + ament_flake8 + ament_pep257 + python3-pytest + + + ament_python + + diff --git a/src/active_bo_debugging/resource/active_bo_debugging b/src/active_bo_debugging/resource/active_bo_debugging new file mode 100644 index 0000000..e69de29 diff --git a/src/active_bo_debugging/setup.cfg b/src/active_bo_debugging/setup.cfg new file mode 100644 index 0000000..2d556ff --- /dev/null +++ b/src/active_bo_debugging/setup.cfg @@ -0,0 +1,4 @@ +[develop] +script_dir=$base/lib/active_bo_debugging +[install] +install_scripts=$base/lib/active_bo_debugging diff --git a/src/active_bo_debugging/setup.py b/src/active_bo_debugging/setup.py new file mode 100644 index 0000000..07f63db --- /dev/null +++ b/src/active_bo_debugging/setup.py @@ -0,0 +1,31 @@ +from setuptools import setup +import os +from glob import glob + +package_name = 'active_bo_debugging' + +setup( + name=package_name, + version='0.0.0', + packages=[package_name], + data_files=[ + ('share/ament_index/resource_index/packages', + ['resource/' + package_name]), + ('share/' + package_name, ['package.xml']), + (os.path.join('share', package_name), glob('launch/*.launch.py')), + ], + install_requires=['setuptools'], + zip_safe=True, + maintainer='niko', + maintainer_email='nikolaus.feith@unileoben.ac.at', + description='TODO: Package description', + license='TODO: License declaration', + tests_require=['pytest'], + entry_points={ + 'console_scripts': [ + 'active_rl_proxy = active_bo_debugging.active_rl_proxy:main', + 'active_bo_proxy = active_bo_debugging.active_bo_proxy:main', + 'active_rl_test = active_bo_ros.active_rl_test_node:main', + ], + }, +) diff --git a/src/active_bo_debugging/test/test_copyright.py b/src/active_bo_debugging/test/test_copyright.py new file mode 100644 index 0000000..97a3919 --- /dev/null +++ b/src/active_bo_debugging/test/test_copyright.py @@ -0,0 +1,25 @@ +# Copyright 2015 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ament_copyright.main import main +import pytest + + +# Remove the `skip` decorator once the source file(s) have a copyright header +@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.') +@pytest.mark.copyright +@pytest.mark.linter +def test_copyright(): + rc = main(argv=['.', 'test']) + assert rc == 0, 'Found errors' diff --git a/src/active_bo_debugging/test/test_flake8.py b/src/active_bo_debugging/test/test_flake8.py new file mode 100644 index 0000000..27ee107 --- /dev/null +++ b/src/active_bo_debugging/test/test_flake8.py @@ -0,0 +1,25 @@ +# Copyright 2017 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ament_flake8.main import main_with_errors +import pytest + + +@pytest.mark.flake8 +@pytest.mark.linter +def test_flake8(): + rc, errors = main_with_errors(argv=[]) + assert rc == 0, \ + 'Found %d code style errors / warnings:\n' % len(errors) + \ + '\n'.join(errors) diff --git a/src/active_bo_debugging/test/test_pep257.py b/src/active_bo_debugging/test/test_pep257.py new file mode 100644 index 0000000..b234a38 --- /dev/null +++ b/src/active_bo_debugging/test/test_pep257.py @@ -0,0 +1,23 @@ +# Copyright 2015 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ament_pep257.main import main +import pytest + + +@pytest.mark.linter +@pytest.mark.pep257 +def test_pep257(): + rc = main(argv=['.', 'test']) + assert rc == 0, 'Found code style errors / warnings' diff --git a/src/active_bo_ros/active_bo_ros/active_bo_service.py b/src/active_bo_ros/active_bo_ros/active_bo_service.py index 20ec1e8..40935bf 100644 --- a/src/active_bo_ros/active_bo_ros/active_bo_service.py +++ b/src/active_bo_ros/active_bo_ros/active_bo_service.py @@ -10,6 +10,7 @@ from active_bo_ros.BayesianOptimization.BayesianOptimization import BayesianOpti from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous_MountainCarEnv import numpy as np +import time class ActiveBOService(Node): @@ -69,10 +70,17 @@ class ActiveBOService(Node): future_rl = self.active_rl_client.call_async(arl_request) self.get_logger().info(str(future_rl)) + timeout = 10 + start_time = time.time() while not future_rl.done(): - rclpy.spin_once(self) - self.get_logger().info('waiting for response!') + rclpy.spin_once(self, timeout_sec=0.1) + # self.get_logger().info(f'{future_rl.result()}') + if time.time() - start_time > timeout: + self.get_logger().error('Service call timed out.') + break + # self.executor.spin_until_future_complete(future_rl) + # arl_response = self.active_rl_client.call(arl_request) self.get_logger().info('Received: Active RL') try: @@ -81,8 +89,6 @@ class ActiveBOService(Node): except Exception as e: self.get_logger().error('active RL Service failed %r' % (e,)) - future_rl = None - # BO part else: x_next = BO.next_observation() diff --git a/src/active_bo_ros/active_bo_ros/active_rl_service.py b/src/active_bo_ros/active_bo_ros/active_rl_service.py index 13d305c..6c3e8ca 100644 --- a/src/active_bo_ros/active_bo_ros/active_rl_service.py +++ b/src/active_bo_ros/active_bo_ros/active_rl_service.py @@ -11,49 +11,42 @@ from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous import numpy as np import time +import copy class ActiveRLService(Node): def __init__(self): super().__init__('active_rl_service') srv_callback_group = ReentrantCallbackGroup() - sub_callback_group = ReentrantCallbackGroup() + topic_callback_group = ReentrantCallbackGroup() self.srv = self.create_service(ActiveRL, 'active_rl_srv', self.active_rl_callback, callback_group=srv_callback_group) - self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1, callback_group=srv_callback_group) + self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1, callback_group=topic_callback_group) self.eval_pub = self.create_publisher(ActiveRLEval, 'active_rl_eval_request', 1, - callback_group=srv_callback_group) + callback_group=topic_callback_group) self.eval_sub = self.create_subscription(ActiveRLEval, 'active_rl_eval_response', self.active_rl_eval_callback, 1, - callback_group=srv_callback_group) + callback_group=topic_callback_group) + self.eval_response_received = False self.eval_response = None - self.eval_response_received_first = False self.env = Continuous_MountainCarEnv(render_mode='rgb_array') self.distance_penalty = 0 def active_rl_eval_callback(self, response): - # if not self.eval_response_received_first: - # self.eval_response_received_first = True - # self.get_logger().info('/active_rl_eval_response connected!') - # else: - # self.eval_response = response - # self.eval_response_received = True self.eval_response = response self.eval_response_received = True - - def active_rl_callback(self, request, response): self.get_logger().info('Active RL: Called') @@ -100,12 +93,14 @@ class ActiveRLService(Node): self.get_logger().info('Enter new solution!') self.eval_pub.publish(eval_request) - while not self.eval_response_received: - rclpy.spin_once(self) + while rclpy.ok(): + rclpy.spin_once(self, timeout_sec=0.1) + if self.eval_response_received: + break self.get_logger().info('Topic responded!') - new_policy = self.eval_response.policy - new_weights = self.eval_response.weights + new_policy = copy.deepcopy(self.eval_response.policy) + new_weights = copy.deepcopy(self.eval_response.weights) self.eval_response_received = False self.eval_response = None @@ -149,6 +144,8 @@ class ActiveRLService(Node): response.reward = reward response.final_step = step_count + self.get_logger().info(f'{response}') + return response diff --git a/src/active_bo_ros/setup.py b/src/active_bo_ros/setup.py index 14aa018..7b03d7f 100644 --- a/src/active_bo_ros/setup.py +++ b/src/active_bo_ros/setup.py @@ -32,7 +32,6 @@ setup( 'bo_srv = active_bo_ros.bo_service:main', 'active_bo_srv = active_bo_ros.active_bo_service:main', 'active_rl_srv = active_bo_ros.active_rl_service:main', - 'active_rl_test = active_bo_ros.active_rl_test_node:main', ], }, )