tried out async methods -> not working

added proxy services to log the service communication
This commit is contained in:
Niko Feith 2023-04-03 15:22:14 +02:00
parent 37dcf957f4
commit e70459dc6e
16 changed files with 389 additions and 22 deletions

View File

@ -0,0 +1,84 @@
import rclpy
from rclpy.node import Node
from active_bo_msgs.srv import ActiveBO
import time
class ActiveBOProxy(Node):
def __init__(self):
super().__init__('active_rl_proxy')
self.client = self.create_client(ActiveBO, 'abo_proxy')
while not self.client.wait_for_service(timeout_sec=10.0):
self.get_logger().info('Waiting for the active_bo_srv')
self.server = self.create_service(ActiveBO, 'active_bo_srv', self.proxy_callback)
def proxy_callback(self, request, response):
self.get_logger().info(f'Received request:')
self.get_logger().info(f' nr_weights: {request.nr_weights}')
self.get_logger().info(f' max_steps: {request.max_steps}')
self.get_logger().info(f' nr_episodes: {request.nr_episodes}')
self.get_logger().info(f' nr_runs: {request.nr_runs}')
self.get_logger().info(f' acquisition_function: {request.acquisition_function}')
self.get_logger().info(f' epsilon: {request.epsilon}')
# Forward the request to the original service server
future = self.client.call_async(request)
timeout = 20
start_time = time.time()
while not future.done():
rclpy.spin_once(self, timeout_sec=0.1)
# self.get_logger().info(f'{future.result()}')
if time.time() - start_time > timeout:
self.get_logger().error('Service call timed out.')
break
# self.executor.spin_until_future_complete(future)
if future.result() is not None:
response.new_weights = future.result().new_weights
response.final_step = future.result().final_step
response.reward = future.result().reward
self.get_logger().info(f'Sending response:')
self.get_logger().info(f' best_policy: {response.best_policy}')
self.get_logger().info(f' best_weights: {response.best_weights}')
self.get_logger().info(f' reward_mean: {response.reward_mean}')
self.get_logger().info(f' reward_std: {response.reward_std}')
# result = self.client.call(request)
# if result is not None:
# response.new_weights = result.new_weights
# response.final_step = result.final_step
# response.reward = result.reward
#
# self.get_logger().info(f'Sending response:')
# self.get_logger().info(f' best_policy: {response.best_policy}')
# self.get_logger().info(f' best_weights: {response.best_weights}')
# self.get_logger().info(f' reward_mean: {response.reward_mean}')
# self.get_logger().info(f' reward_std: {response.reward_std}')
else:
self.get_logger().error('Failed to call the original service')
return response
def main(args=None):
rclpy.init(args=args)
service_proxy = ActiveBOProxy()
try:
rclpy.spin(service_proxy)
except KeyboardInterrupt:
pass
service_proxy.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,77 @@
import rclpy
from rclpy.node import Node
from active_bo_msgs.srv import ActiveRL
import time
class ActiveRLProxy(Node):
def __init__(self):
super().__init__('active_rl_proxy')
self.client = self.create_client(ActiveRL, 'arl_proxy')
while not self.client.wait_for_service(timeout_sec=10.0):
self.get_logger().info('Waiting for the active_rl_srv')
self.server = self.create_service(ActiveRL, 'active_rl_srv', self.proxy_callback)
def proxy_callback(self, request, response):
self.get_logger().info(f'Received request:')
self.get_logger().info(f' old_policy: {request.old_policy}')
self.get_logger().info(f' old_weights: {request.old_weights}')
# Forward the request to the original service server
future = self.client.call_async(request)
timeout = 10
start_time = time.time()
while not future.done():
rclpy.spin_once(self, timeout_sec=0.1)
self.get_logger().info(f'{future.result()}')
if time.time() - start_time > timeout:
self.get_logger().error('Service call timed out.')
break
# self.executor.spin_until_future_complete(future)
if future.result() is not None:
response.new_weights = future.result().new_weights
response.final_step = future.result().final_step
response.reward = future.result().reward
self.get_logger().info(f'Sending response:')
self.get_logger().info(f' new_weigths: {response.new_weights}')
self.get_logger().info(f' final_step: {response.final_step}')
self.get_logger().info(f' reward: {response.reward}')
# result = self.client.call(request)
# if result is not None:
# response.new_weights = result.new_weights
# response.final_step = result.final_step
# response.reward = result.reward
#
# self.get_logger().info(f'Sending response:')
# self.get_logger().info(f' new_weigths: {response.new_weights}')
# self.get_logger().info(f' final_step: {response.final_step}')
# self.get_logger().info(f' reward: {response.reward}')
else:
self.get_logger().error('Failed to call the original service')
return response
def main(args=None):
rclpy.init(args=args)
service_proxy = ActiveRLProxy()
try:
rclpy.spin(service_proxy)
except KeyboardInterrupt:
pass
service_proxy.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,57 @@
from launch import LaunchDescription
from launch.actions import IncludeLaunchDescription
from launch.launch_description_sources import PythonLaunchDescriptionSource
from launch_ros.actions import Node
from ament_index_python import get_package_share_directory
import os
def generate_launch_description():
websocket_launch = IncludeLaunchDescription(
PythonLaunchDescriptionSource(
os.path.join(
get_package_share_directory('active_bo_ros'),
'rosbridge_server.launch.py'
)
)
)
rl_launch = IncludeLaunchDescription(
PythonLaunchDescriptionSource(
os.path.join(
get_package_share_directory('active_bo_ros'),
'rl_service.launch.py'
)
)
)
bo_launch = IncludeLaunchDescription(
PythonLaunchDescriptionSource(
os.path.join(
get_package_share_directory('active_bo_ros'),
'bo_service.launch.py'
)
)
)
arl_launch = Node(
package='active_bo_ros',
executable='active_rl_srv',
name='active_rl_service',
remappings=[
('active_rl_srv', 'arl_proxy')
],
)
abo_launch = Node(
package='active_bo_ros',
executable='active_bo_srv',
name='active_bo_service',
remappings=[
('active_bo_srv', 'abo_proxy')
],
)
return LaunchDescription([
websocket_launch,
rl_launch,
bo_launch,
arl_launch,
abo_launch
])

View File

@ -0,0 +1,17 @@
from launch import LaunchDescription
from launch_ros.actions import Node
def generate_launch_description():
return LaunchDescription([
Node(
package='active_bo_debugging',
executable='active_rl_proxy',
name='active_rl_proxy'
),
Node(
package='active_bo_debugging',
executable='active_bo_proxy',
name='active_bo_proxy'
),
])

View File

@ -0,0 +1,22 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>active_bo_debugging</name>
<version>0.0.0</version>
<description>TODO: Package description</description>
<maintainer email="nikolaus.feith@unileoben.ac.at">niko</maintainer>
<license>TODO: License declaration</license>
<exec_depend>example_interfaces</exec_depend>
<exec_depend>active_bo_msgs</exec_depend>
<exec_depend>rclpy</exec_depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/active_bo_debugging
[install]
install_scripts=$base/lib/active_bo_debugging

View File

@ -0,0 +1,31 @@
from setuptools import setup
import os
from glob import glob
package_name = 'active_bo_debugging'
setup(
name=package_name,
version='0.0.0',
packages=[package_name],
data_files=[
('share/ament_index/resource_index/packages',
['resource/' + package_name]),
('share/' + package_name, ['package.xml']),
(os.path.join('share', package_name), glob('launch/*.launch.py')),
],
install_requires=['setuptools'],
zip_safe=True,
maintainer='niko',
maintainer_email='nikolaus.feith@unileoben.ac.at',
description='TODO: Package description',
license='TODO: License declaration',
tests_require=['pytest'],
entry_points={
'console_scripts': [
'active_rl_proxy = active_bo_debugging.active_rl_proxy:main',
'active_bo_proxy = active_bo_debugging.active_bo_proxy:main',
'active_rl_test = active_bo_ros.active_rl_test_node:main',
],
},
)

View File

@ -0,0 +1,25 @@
# Copyright 2015 Open Source Robotics Foundation, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ament_copyright.main import main
import pytest
# Remove the `skip` decorator once the source file(s) have a copyright header
@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.')
@pytest.mark.copyright
@pytest.mark.linter
def test_copyright():
rc = main(argv=['.', 'test'])
assert rc == 0, 'Found errors'

View File

@ -0,0 +1,25 @@
# Copyright 2017 Open Source Robotics Foundation, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ament_flake8.main import main_with_errors
import pytest
@pytest.mark.flake8
@pytest.mark.linter
def test_flake8():
rc, errors = main_with_errors(argv=[])
assert rc == 0, \
'Found %d code style errors / warnings:\n' % len(errors) + \
'\n'.join(errors)

View File

@ -0,0 +1,23 @@
# Copyright 2015 Open Source Robotics Foundation, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ament_pep257.main import main
import pytest
@pytest.mark.linter
@pytest.mark.pep257
def test_pep257():
rc = main(argv=['.', 'test'])
assert rc == 0, 'Found code style errors / warnings'

View File

@ -10,6 +10,7 @@ from active_bo_ros.BayesianOptimization.BayesianOptimization import BayesianOpti
from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous_MountainCarEnv from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous_MountainCarEnv
import numpy as np import numpy as np
import time
class ActiveBOService(Node): class ActiveBOService(Node):
@ -69,10 +70,17 @@ class ActiveBOService(Node):
future_rl = self.active_rl_client.call_async(arl_request) future_rl = self.active_rl_client.call_async(arl_request)
self.get_logger().info(str(future_rl)) self.get_logger().info(str(future_rl))
timeout = 10
start_time = time.time()
while not future_rl.done(): while not future_rl.done():
rclpy.spin_once(self) rclpy.spin_once(self, timeout_sec=0.1)
self.get_logger().info('waiting for response!') # self.get_logger().info(f'{future_rl.result()}')
if time.time() - start_time > timeout:
self.get_logger().error('Service call timed out.')
break
# self.executor.spin_until_future_complete(future_rl)
# arl_response = self.active_rl_client.call(arl_request)
self.get_logger().info('Received: Active RL') self.get_logger().info('Received: Active RL')
try: try:
@ -81,8 +89,6 @@ class ActiveBOService(Node):
except Exception as e: except Exception as e:
self.get_logger().error('active RL Service failed %r' % (e,)) self.get_logger().error('active RL Service failed %r' % (e,))
future_rl = None
# BO part # BO part
else: else:
x_next = BO.next_observation() x_next = BO.next_observation()

View File

@ -11,49 +11,42 @@ from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous
import numpy as np import numpy as np
import time import time
import copy
class ActiveRLService(Node): class ActiveRLService(Node):
def __init__(self): def __init__(self):
super().__init__('active_rl_service') super().__init__('active_rl_service')
srv_callback_group = ReentrantCallbackGroup() srv_callback_group = ReentrantCallbackGroup()
sub_callback_group = ReentrantCallbackGroup() topic_callback_group = ReentrantCallbackGroup()
self.srv = self.create_service(ActiveRL, self.srv = self.create_service(ActiveRL,
'active_rl_srv', 'active_rl_srv',
self.active_rl_callback, self.active_rl_callback,
callback_group=srv_callback_group) callback_group=srv_callback_group)
self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1, callback_group=srv_callback_group) self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1, callback_group=topic_callback_group)
self.eval_pub = self.create_publisher(ActiveRLEval, self.eval_pub = self.create_publisher(ActiveRLEval,
'active_rl_eval_request', 'active_rl_eval_request',
1, 1,
callback_group=srv_callback_group) callback_group=topic_callback_group)
self.eval_sub = self.create_subscription(ActiveRLEval, self.eval_sub = self.create_subscription(ActiveRLEval,
'active_rl_eval_response', 'active_rl_eval_response',
self.active_rl_eval_callback, self.active_rl_eval_callback,
1, 1,
callback_group=srv_callback_group) callback_group=topic_callback_group)
self.eval_response_received = False self.eval_response_received = False
self.eval_response = None self.eval_response = None
self.eval_response_received_first = False
self.env = Continuous_MountainCarEnv(render_mode='rgb_array') self.env = Continuous_MountainCarEnv(render_mode='rgb_array')
self.distance_penalty = 0 self.distance_penalty = 0
def active_rl_eval_callback(self, response): def active_rl_eval_callback(self, response):
# if not self.eval_response_received_first:
# self.eval_response_received_first = True
# self.get_logger().info('/active_rl_eval_response connected!')
# else:
# self.eval_response = response
# self.eval_response_received = True
self.eval_response = response self.eval_response = response
self.eval_response_received = True self.eval_response_received = True
def active_rl_callback(self, request, response): def active_rl_callback(self, request, response):
self.get_logger().info('Active RL: Called') self.get_logger().info('Active RL: Called')
@ -100,12 +93,14 @@ class ActiveRLService(Node):
self.get_logger().info('Enter new solution!') self.get_logger().info('Enter new solution!')
self.eval_pub.publish(eval_request) self.eval_pub.publish(eval_request)
while not self.eval_response_received: while rclpy.ok():
rclpy.spin_once(self) rclpy.spin_once(self, timeout_sec=0.1)
if self.eval_response_received:
break
self.get_logger().info('Topic responded!') self.get_logger().info('Topic responded!')
new_policy = self.eval_response.policy new_policy = copy.deepcopy(self.eval_response.policy)
new_weights = self.eval_response.weights new_weights = copy.deepcopy(self.eval_response.weights)
self.eval_response_received = False self.eval_response_received = False
self.eval_response = None self.eval_response = None
@ -149,6 +144,8 @@ class ActiveRLService(Node):
response.reward = reward response.reward = reward
response.final_step = step_count response.final_step = step_count
self.get_logger().info(f'{response}')
return response return response

View File

@ -32,7 +32,6 @@ setup(
'bo_srv = active_bo_ros.bo_service:main', 'bo_srv = active_bo_ros.bo_service:main',
'active_bo_srv = active_bo_ros.active_bo_service:main', 'active_bo_srv = active_bo_ros.active_bo_service:main',
'active_rl_srv = active_bo_ros.active_rl_service:main', 'active_rl_srv = active_bo_ros.active_rl_service:main',
'active_rl_test = active_bo_ros.active_rl_test_node:main',
], ],
}, },
) )