diff --git a/src/active_bo_msgs/CMakeLists.txt b/src/active_bo_msgs/CMakeLists.txt index df1d155..bbf7066 100644 --- a/src/active_bo_msgs/CMakeLists.txt +++ b/src/active_bo_msgs/CMakeLists.txt @@ -25,8 +25,9 @@ rosidl_generate_interfaces(${PROJECT_NAME} "srv/BO.srv" "srv/ActiveBO.srv" "srv/ActiveRL.srv" + "srv/ActiveRLEval.srv" "msg/ImageFeedback.msg" - "msg/ActiveRLEval.msg" + ) if(BUILD_TESTING) diff --git a/src/active_bo_ros/active_bo_ros/active_bo_service.py b/src/active_bo_ros/active_bo_ros/active_bo_service.py index e8d20f7..87902b2 100644 --- a/src/active_bo_ros/active_bo_ros/active_bo_service.py +++ b/src/active_bo_ros/active_bo_ros/active_bo_service.py @@ -13,10 +13,12 @@ import numpy as np class ActiveBOService(Node): def __init__(self): super().__init__('active_bo_service') - self.srv = self.create_service(ActiveBO, 'active_bo_srv', self.active_bo_callback) + self.active_bo_srv = self.create_service(ActiveBO, 'active_bo_srv', self.active_bo_callback) self.active_rl_client = self.create_client(ActiveRL, 'active_rl_srv') + self.rl_trigger_ = False + self.env = Continuous_MountainCarEnv() self.distance_penalty = 0 @@ -54,9 +56,18 @@ class ActiveBOService(Node): arl_request.old_policy = old_policy.tolist() arl_request.old_weights = old_weights.tolist() - arl_response = self.active_rl_client.call(arl_request) + future_rl = self.active_rl_client.call_async(arl_request) - BO.add_new_observation(arl_response.reward, arl_response.new_weights) + while rclpy.ok(): + rclpy.spin_once(self) + if future_rl.done(): + try: + arl_response = future_rl.result() + self.get_logger().info('active RL Response: %s' % arl_response) + BO.add_new_observation(arl_response.reward, arl_response.new_weights) + except Exception as e: + self.get_logger().error('active RL Service failed %r' % (e,)) + break # BO part else: diff --git a/src/active_bo_ros/active_bo_ros/active_rl_service.py b/src/active_bo_ros/active_bo_ros/active_rl_service.py index 1c07a31..17fe153 100644 --- a/src/active_bo_ros/active_bo_ros/active_rl_service.py +++ b/src/active_bo_ros/active_bo_ros/active_rl_service.py @@ -1,6 +1,6 @@ from active_bo_msgs.srv import ActiveRL from active_bo_msgs.msg import ImageFeedback -from active_bo_msgs.msg import ActiveRLEval +from active_bo_msgs.srv import ActiveRLEval import rclpy from rclpy.node import Node @@ -17,37 +17,41 @@ class ActiveRLService(Node): def __init__(self): super().__init__('active_rl_service') srv_callback_group = ReentrantCallbackGroup() - sub_callback_group = ReentrantCallbackGroup() + eval_callback_group = ReentrantCallbackGroup() - self.srv = self.create_service(ActiveRL, - 'active_rl_srv', - self.active_rl_callback, - callback_group=srv_callback_group) + self.rl_srv = self.create_service(ActiveRL, + 'active_rl_srv', + self.active_rl_callback, + callback_group=srv_callback_group) self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1) - self.eval_pub = self.create_publisher(ActiveRLEval, 'active_rl_eval_request', 1) - self.eval_sub = self.create_subscription(ActiveRLEval, - 'active_rl_eval_response', - self.active_rl_eval_callback, - 10, - callback_group=sub_callback_group) - self.eval_response_received = False - self.eval_response = None - self.eval_response_received_first = False + self.active_rl_eval_client = self.create_client(ActiveRLEval, + 'active_rl_eval_srv', + callback_group=eval_callback_group) + + # self.eval_pub = self.create_publisher(ActiveRLEval, 'active_rl_eval_request', 1) + # self.eval_sub = self.create_subscription(ActiveRLEval, + # 'active_rl_eval_response', + # self.active_rl_eval_callback, + # 10, + # callback_group=sub_callback_group) + # self.eval_response_received = False + # self.eval_response = None + # self.eval_response_received_first = False self.env = Continuous_MountainCarEnv(render_mode='rgb_array') self.distance_penalty = 0 - def active_rl_eval_callback(self, response): - # if not self.eval_response_received_first: - # self.eval_response_received_first = True - # self.get_logger().info('/active_rl_eval_response connected!') - # else: - # self.eval_response = response - # self.eval_response_received = True - self.eval_response = response - self.eval_response_received = True + # def active_rl_eval_callback(self, response): + # # if not self.eval_response_received_first: + # # self.eval_response_received_first = True + # # self.get_logger().info('/active_rl_eval_response connected!') + # # else: + # # self.eval_response = response + # # self.eval_response_received = True + # self.eval_response = response + # self.eval_response_received = True def active_rl_callback(self, request, response): @@ -58,9 +62,9 @@ class ActiveRLService(Node): old_policy = request.old_policy old_weights = request.old_weights - eval_request = ActiveRLEval() - eval_request.policy = old_policy - eval_request.weights = old_weights + eval_request = ActiveRLEval.Request() + eval_request.old_policy = old_policy + eval_request.old_weights = old_weights self.env.reset() @@ -91,16 +95,17 @@ class ActiveRLService(Node): break self.get_logger().info('Enter new solution!') - self.eval_pub.publish(eval_request) + # self.eval_pub.publish(eval_request) + # + # while not self.eval_response_received: + # rclpy.spin_once(self) - while not self.eval_response_received: - rclpy.spin_once(self) + eval_response = self.active_rl_eval_client.call(eval_request) + self.get_logger().info('Active RL Eval Srv started!') + + new_policy = eval_response.new_policy + new_weights = eval_response.new_weights - self.get_logger().info('Topic responded!') - new_policy = self.eval_response.policy - new_weights = self.eval_response.weights - self.eval_response_received = False - self.eval_response = None reward = 0 step_count = 0