From 37dcf957f4b6b02f0befcb5fde333ff915f7d5c8 Mon Sep 17 00:00:00 2001 From: Niko Date: Wed, 29 Mar 2023 17:46:29 +0200 Subject: [PATCH] test for active rl --- src/active_bo_msgs/CMakeLists.txt | 1 + .../active_bo_ros/active_bo_service.py | 1 + .../active_bo_ros/active_rl_service.py | 12 +++-- .../active_bo_ros/active_rl_test_node.py | 45 +++++++++++++++++++ src/active_bo_ros/setup.py | 1 + 5 files changed, 56 insertions(+), 4 deletions(-) create mode 100644 src/active_bo_ros/active_bo_ros/active_rl_test_node.py diff --git a/src/active_bo_msgs/CMakeLists.txt b/src/active_bo_msgs/CMakeLists.txt index bbf7066..60fb753 100644 --- a/src/active_bo_msgs/CMakeLists.txt +++ b/src/active_bo_msgs/CMakeLists.txt @@ -26,6 +26,7 @@ rosidl_generate_interfaces(${PROJECT_NAME} "srv/ActiveBO.srv" "srv/ActiveRL.srv" "srv/ActiveRLEval.srv" + "msg/ActiveRLEval.msg" "msg/ImageFeedback.msg" ) diff --git a/src/active_bo_ros/active_bo_ros/active_bo_service.py b/src/active_bo_ros/active_bo_ros/active_bo_service.py index c1c1af5..20ec1e8 100644 --- a/src/active_bo_ros/active_bo_ros/active_bo_service.py +++ b/src/active_bo_ros/active_bo_ros/active_bo_service.py @@ -67,6 +67,7 @@ class ActiveBOService(Node): arl_request.old_weights = old_weights.tolist() self.get_logger().info('Calling: Active RL') future_rl = self.active_rl_client.call_async(arl_request) + self.get_logger().info(str(future_rl)) while not future_rl.done(): rclpy.spin_once(self) diff --git a/src/active_bo_ros/active_bo_ros/active_rl_service.py b/src/active_bo_ros/active_bo_ros/active_rl_service.py index b28a7ab..13d305c 100644 --- a/src/active_bo_ros/active_bo_ros/active_rl_service.py +++ b/src/active_bo_ros/active_bo_ros/active_rl_service.py @@ -13,7 +13,6 @@ import numpy as np import time - class ActiveRLService(Node): def __init__(self): super().__init__('active_rl_service') @@ -25,14 +24,17 @@ class ActiveRLService(Node): self.active_rl_callback, callback_group=srv_callback_group) - self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1) + self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1, callback_group=srv_callback_group) - self.eval_pub = self.create_publisher(ActiveRLEval, 'active_rl_eval_request', 1) + self.eval_pub = self.create_publisher(ActiveRLEval, + 'active_rl_eval_request', + 1, + callback_group=srv_callback_group) self.eval_sub = self.create_subscription(ActiveRLEval, 'active_rl_eval_response', self.active_rl_eval_callback, 1, - callback_group=sub_callback_group) + callback_group=srv_callback_group) self.eval_response_received = False self.eval_response = None self.eval_response_received_first = False @@ -50,6 +52,8 @@ class ActiveRLService(Node): self.eval_response = response self.eval_response_received = True + + def active_rl_callback(self, request, response): self.get_logger().info('Active RL: Called') diff --git a/src/active_bo_ros/active_bo_ros/active_rl_test_node.py b/src/active_bo_ros/active_bo_ros/active_rl_test_node.py new file mode 100644 index 0000000..570e1f8 --- /dev/null +++ b/src/active_bo_ros/active_bo_ros/active_rl_test_node.py @@ -0,0 +1,45 @@ +from active_bo_msgs.srv import ActiveRL + +import rclpy +from rclpy.node import Node + +import numpy as np +import time + + +class ActiveRLTest(Node): + def __init__(self): + super().__init__('active_rl_test') + + self.client = self.create_client(ActiveRL, 'active_rl_srv') + + self.main_loop = self.create_timer(20.0, self.main_callback) + + def main_callback(self): + + random_policy = np.random.uniform(-1.0, 1.0, 100).tolist() + random_weights = np.random.uniform(-1.0, 1.0, 5).tolist() + + self.get_logger().info(str(random_policy)) + + rl_request = ActiveRL.Request() + rl_request.old_policy = random_policy + rl_request.old_weights = random_weights + + future = self.client.call_async(rl_request) + + self.get_logger().info(str(future)) + + +def main(args=None): + rclpy.init(args=args) + + active_rl_test = ActiveRLTest() + + rclpy.spin(active_rl_test) + + rclpy.shutdown() + + +if __name__ == '__main__': + main() diff --git a/src/active_bo_ros/setup.py b/src/active_bo_ros/setup.py index 7b03d7f..14aa018 100644 --- a/src/active_bo_ros/setup.py +++ b/src/active_bo_ros/setup.py @@ -32,6 +32,7 @@ setup( 'bo_srv = active_bo_ros.bo_service:main', 'active_bo_srv = active_bo_ros.active_bo_service:main', 'active_rl_srv = active_bo_ros.active_rl_service:main', + 'active_rl_test = active_bo_ros.active_rl_test_node:main', ], }, )