From 10f938cc13fb4a543f7522467adb6d7c03267d90 Mon Sep 17 00:00:00 2001
From: Niko
Date: Mon, 27 Mar 2023 16:57:11 +0200
Subject: [PATCH] Works for one user input

---
 .../active_bo_ros/active_bo_service.py        | 39 ++++++---
 .../active_bo_ros/active_rl_service.py        | 79 +++++++++----------
 src/active_bo_ros/active_bo_ros/rl_service.py |  2 +
 3 files changed, 68 insertions(+), 52 deletions(-)

diff --git a/src/active_bo_ros/active_bo_ros/active_bo_service.py b/src/active_bo_ros/active_bo_ros/active_bo_service.py
index 87902b2..c1c1af5 100644
--- a/src/active_bo_ros/active_bo_ros/active_bo_service.py
+++ b/src/active_bo_ros/active_bo_ros/active_bo_service.py
@@ -4,6 +4,8 @@ from active_bo_msgs.srv import ActiveRL
 import rclpy
 from rclpy.node import Node
 
+from rclpy.callback_groups import ReentrantCallbackGroup
+
 from active_bo_ros.BayesianOptimization.BayesianOptimization import BayesianOptimization
 from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous_MountainCarEnv
 
@@ -13,11 +15,18 @@ import numpy as np
 
 class ActiveBOService(Node):
     def __init__(self):
         super().__init__('active_bo_service')
-        self.active_bo_srv = self.create_service(ActiveBO, 'active_bo_srv', self.active_bo_callback)
-        self.active_rl_client = self.create_client(ActiveRL, 'active_rl_srv')
+        bo_callback_group = ReentrantCallbackGroup()
+        rl_callback_group = ReentrantCallbackGroup()
 
-        self.rl_trigger_ = False
+        self.srv = self.create_service(ActiveBO,
+                                       'active_bo_srv',
+                                       self.active_bo_callback,
+                                       callback_group=bo_callback_group)
+
+        self.active_rl_client = self.create_client(ActiveRL,
+                                                   'active_rl_srv',
+                                                   callback_group=rl_callback_group)
 
         self.env = Continuous_MountainCarEnv()
         self.distance_penalty = 0
@@ -56,18 +65,22 @@ class ActiveBOService(Node):
                 arl_request.old_policy = old_policy.tolist()
                 arl_request.old_weights = old_weights.tolist()
 
+                self.get_logger().info('Calling: Active RL')
                 future_rl = self.active_rl_client.call_async(arl_request)
 
-                while rclpy.ok():
+                while not future_rl.done():
                     rclpy.spin_once(self)
-                    if future_rl.done():
-                        try:
-                            arl_response = future_rl.result()
-                            self.get_logger().info('active RL Response: %s' % arl_response)
-                            BO.add_new_observation(arl_response.reward, arl_response.new_weights)
-                        except Exception as e:
-                            self.get_logger().error('active RL Service failed %r' % (e,))
-                        break
+                    self.get_logger().info('waiting for response!')
+
+                self.get_logger().info('Received: Active RL')
+
+                try:
+                    arl_response = future_rl.result()
+                    BO.add_new_observation(arl_response.reward, arl_response.new_weights)
+                except Exception as e:
+                    self.get_logger().error('active RL Service failed %r' % (e,))
+
+                future_rl = None
 
             # BO part
             else:
@@ -96,6 +109,8 @@ def main(args=None):
 
     rclpy.spin(active_bo_service)
 
+    rclpy.shutdown()
+
 
 if __name__ == '__main__':
     main()
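Note: the ReentrantCallbackGroups above only let the 'active_bo_srv' callback and the 'active_rl_srv' client make progress concurrently when the node is driven by a multi-threaded executor; rclpy.spin() as used in main() runs the default single-threaded executor. A minimal sketch of an alternative main() under that assumption (not part of this patch, reusing the ActiveBOService class defined above):

    import rclpy
    from rclpy.executors import MultiThreadedExecutor

    def main(args=None):
        rclpy.init(args=args)
        node = ActiveBOService()

        # A ReentrantCallbackGroup only permits concurrent callbacks when the
        # executor actually has more than one thread.
        executor = MultiThreadedExecutor(num_threads=2)
        executor.add_node(node)
        try:
            executor.spin()
        finally:
            node.destroy_node()
            rclpy.shutdown()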
diff --git a/src/active_bo_ros/active_bo_ros/active_rl_service.py b/src/active_bo_ros/active_bo_ros/active_rl_service.py
index 17fe153..c5a8ee8 100644
--- a/src/active_bo_ros/active_bo_ros/active_rl_service.py
+++ b/src/active_bo_ros/active_bo_ros/active_rl_service.py
@@ -1,6 +1,6 @@
 from active_bo_msgs.srv import ActiveRL
 from active_bo_msgs.msg import ImageFeedback
-from active_bo_msgs.srv import ActiveRLEval
+from active_bo_msgs.msg import ActiveRLEval
 
 import rclpy
 from rclpy.node import Node
@@ -17,44 +17,42 @@ class ActiveRLService(Node):
     def __init__(self):
         super().__init__('active_rl_service')
         srv_callback_group = ReentrantCallbackGroup()
-        eval_callback_group = ReentrantCallbackGroup()
+        sub_callback_group = ReentrantCallbackGroup()
 
-        self.rl_srv = self.create_service(ActiveRL,
-                                          'active_rl_srv',
-                                          self.active_rl_callback,
-                                          callback_group=srv_callback_group)
+        self.srv = self.create_service(ActiveRL,
+                                       'active_rl_srv',
+                                       self.active_rl_callback,
+                                       callback_group=srv_callback_group)
 
         self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1)
 
-        self.active_rl_eval_client = self.create_client(ActiveRLEval,
-                                                        'active_rl_eval_srv',
-                                                        callback_group=eval_callback_group)
-
-        # self.eval_pub = self.create_publisher(ActiveRLEval, 'active_rl_eval_request', 1)
-        # self.eval_sub = self.create_subscription(ActiveRLEval,
-        #                                          'active_rl_eval_response',
-        #                                          self.active_rl_eval_callback,
-        #                                          10,
-        #                                          callback_group=sub_callback_group)
-        # self.eval_response_received = False
-        # self.eval_response = None
-        # self.eval_response_received_first = False
+        self.eval_pub = self.create_publisher(ActiveRLEval, 'active_rl_eval_request', 1)
+        self.eval_sub = self.create_subscription(ActiveRLEval,
+                                                 'active_rl_eval_response',
+                                                 self.active_rl_eval_callback,
+                                                 1,
+                                                 callback_group=sub_callback_group)
+        self.eval_response_received = False
+        self.eval_response = None
+        self.eval_response_received_first = False
 
         self.env = Continuous_MountainCarEnv(render_mode='rgb_array')
         self.distance_penalty = 0
 
 
-    # def active_rl_eval_callback(self, response):
-    #     # if not self.eval_response_received_first:
-    #     #     self.eval_response_received_first = True
-    #     #     self.get_logger().info('/active_rl_eval_response connected!')
-    #     # else:
-    #     #     self.eval_response = response
-    #     #     self.eval_response_received = True
-    #     self.eval_response = response
-    #     self.eval_response_received = True
+    def active_rl_eval_callback(self, response):
+        # if not self.eval_response_received_first:
+        #     self.eval_response_received_first = True
+        #     self.get_logger().info('/active_rl_eval_response connected!')
+        # else:
+        #     self.eval_response = response
+        #     self.eval_response_received = True
+        self.eval_response = response
+        self.eval_response_received = True
 
     def active_rl_callback(self, request, response):
+        self.get_logger().info('Active RL: Called')
+
         feedback_msg = ImageFeedback()
 
         reward = 0
@@ -62,9 +60,9 @@ class ActiveRLService(Node):
         old_policy = request.old_policy
         old_weights = request.old_weights
 
-        eval_request = ActiveRLEval.Request()
-        eval_request.old_policy = old_policy
-        eval_request.old_weights = old_weights
+        eval_request = ActiveRLEval()
+        eval_request.policy = old_policy
+        eval_request.weights = old_weights
 
         self.env.reset()
 
@@ -95,17 +93,16 @@ class ActiveRLService(Node):
                 break
 
         self.get_logger().info('Enter new solution!')
-        # self.eval_pub.publish(eval_request)
-        #
-        # while not self.eval_response_received:
-        #     rclpy.spin_once(self)
+        self.eval_pub.publish(eval_request)
 
 
-        eval_response = self.active_rl_eval_client.call(eval_request)
-        self.get_logger().info('Active RL Eval Srv started!')
-
-        new_policy = eval_response.new_policy
-        new_weights = eval_response.new_weights
+        while not self.eval_response_received:
+            rclpy.spin_once(self)
+        self.get_logger().info('Topic responded!')
+        new_policy = self.eval_response.policy
+        new_weights = self.eval_response.weights
+        self.eval_response_received = False
+        self.eval_response = None
 
         reward = 0
         step_count = 0
@@ -157,6 +154,8 @@ def main(args=None):
 
     rclpy.spin(active_rl_service)
 
+    rclpy.shutdown()
+
 
 if __name__ == '__main__':
     main()
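The service-based eval round trip is replaced here by a topic handshake: active_rl_service publishes the current policy on 'active_rl_eval_request', then polls with rclpy.spin_once() until its subscriber on 'active_rl_eval_response' sets eval_response_received. A throwaway responder for exercising that handshake could look like the sketch below (not part of this patch; it assumes the ActiveRLEval message only needs the policy and weights fields used above, and it simply echoes the request back unchanged):

    import rclpy
    from rclpy.node import Node
    from active_bo_msgs.msg import ActiveRLEval

    class EvalEcho(Node):
        def __init__(self):
            super().__init__('active_rl_eval_echo')
            # Answer on the topic active_rl_service polls for a response.
            self.pub = self.create_publisher(ActiveRLEval, 'active_rl_eval_response', 1)
            self.sub = self.create_subscription(ActiveRLEval,
                                                'active_rl_eval_request',
                                                self.request_callback,
                                                1)

        def request_callback(self, msg):
            # A real front end would let the user edit msg.policy and
            # msg.weights before replying; this stub returns them unchanged.
            self.pub.publish(msg)

    def main(args=None):
        rclpy.init(args=args)
        rclpy.spin(EvalEcho())
        rclpy.shutdown()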
diff --git a/src/active_bo_ros/active_bo_ros/rl_service.py b/src/active_bo_ros/active_bo_ros/rl_service.py
index cde3e10..67f26f6 100644
--- a/src/active_bo_ros/active_bo_ros/rl_service.py
+++ b/src/active_bo_ros/active_bo_ros/rl_service.py
@@ -71,5 +71,7 @@ def main(args=None):
 
     rclpy.spin(rl_service)
 
+    rclpy.shutdown()
+
 if __name__ == '__main__':
     main()
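All three entry points now call rclpy.shutdown() after rclpy.spin() returns. A slightly more defensive variant of main() for rl_service.py might look like the sketch below (an illustration under stated assumptions, not part of this patch: the node class name RLService is guessed, and rclpy.try_shutdown() is used because the SIGINT handler may already have shut the context down):

    import rclpy

    def main(args=None):
        rclpy.init(args=args)
        rl_service = RLService()  # assumed class name for the node in rl_service.py
        try:
            rclpy.spin(rl_service)
        except KeyboardInterrupt:
            pass
        finally:
            rl_service.destroy_node()
            # try_shutdown() is a no-op if the context was already shut down.
            rclpy.try_shutdown()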