From dcdd257ead748b64fa83cc67340bfda96d042a89 Mon Sep 17 00:00:00 2001 From: Niko Date: Tue, 20 Jun 2023 14:43:17 +0200 Subject: [PATCH] fixing the parameters is working --- src/active_bo_msgs/CMakeLists.txt | 4 +- .../msg/ActiveRLEvalRequest.msg | 2 + .../msg/ActiveRLEvalResponse.msg | 3 + .../msg/{ActiveRL.msg => ActiveRLRequest.msg} | 2 - src/active_bo_msgs/msg/ActiveRLResponse.msg | 1 + .../OverwriteParadigm/__init__.py | 0 .../OverwriteParadigm/fixed_params.py | 0 .../active_bo_ros/active_rl_topic.py | 58 +++++++------------ .../active_bo_ros/interactive_bo.py | 42 +++++++++----- 9 files changed, 57 insertions(+), 55 deletions(-) create mode 100644 src/active_bo_msgs/msg/ActiveRLEvalRequest.msg create mode 100644 src/active_bo_msgs/msg/ActiveRLEvalResponse.msg rename src/active_bo_msgs/msg/{ActiveRL.msg => ActiveRLRequest.msg} (56%) create mode 100644 src/active_bo_ros/active_bo_ros/OverwriteParadigm/__init__.py create mode 100644 src/active_bo_ros/active_bo_ros/OverwriteParadigm/fixed_params.py diff --git a/src/active_bo_msgs/CMakeLists.txt b/src/active_bo_msgs/CMakeLists.txt index f9f39d0..3b65401 100644 --- a/src/active_bo_msgs/CMakeLists.txt +++ b/src/active_bo_msgs/CMakeLists.txt @@ -28,7 +28,9 @@ rosidl_generate_interfaces(${PROJECT_NAME} "msg/ActiveBORequest.msg" "msg/ActiveBOResponse.msg" "msg/ActiveRLResponse.msg" - "msg/ActiveRL.msg" + "msg/ActiveRLRequest.msg" + "msg/ActiveRLEvalRequest.msg" + "msg/ActiveRLEvalResponse.msg" "msg/ImageFeedback.msg" "msg/ActiveBOState.msg" diff --git a/src/active_bo_msgs/msg/ActiveRLEvalRequest.msg b/src/active_bo_msgs/msg/ActiveRLEvalRequest.msg new file mode 100644 index 0000000..c7c6249 --- /dev/null +++ b/src/active_bo_msgs/msg/ActiveRLEvalRequest.msg @@ -0,0 +1,2 @@ +float64[] policy +float64[] weights \ No newline at end of file diff --git a/src/active_bo_msgs/msg/ActiveRLEvalResponse.msg b/src/active_bo_msgs/msg/ActiveRLEvalResponse.msg new file mode 100644 index 0000000..9a562fb --- /dev/null +++ b/src/active_bo_msgs/msg/ActiveRLEvalResponse.msg @@ -0,0 +1,3 @@ +bool[] overwrite_weight +float64[] policy +float64[] weights \ No newline at end of file diff --git a/src/active_bo_msgs/msg/ActiveRL.msg b/src/active_bo_msgs/msg/ActiveRLRequest.msg similarity index 56% rename from src/active_bo_msgs/msg/ActiveRL.msg rename to src/active_bo_msgs/msg/ActiveRLRequest.msg index d607cff..83b4301 100644 --- a/src/active_bo_msgs/msg/ActiveRL.msg +++ b/src/active_bo_msgs/msg/ActiveRLRequest.msg @@ -2,7 +2,5 @@ string env uint32 seed bool display_run uint8 interactive_run -bool overwrite_param -bool[] overwrite_weight float64[] policy float64[] weights \ No newline at end of file diff --git a/src/active_bo_msgs/msg/ActiveRLResponse.msg b/src/active_bo_msgs/msg/ActiveRLResponse.msg index 9565fc8..f844283 100644 --- a/src/active_bo_msgs/msg/ActiveRLResponse.msg +++ b/src/active_bo_msgs/msg/ActiveRLResponse.msg @@ -1,3 +1,4 @@ float64[] weights +bool[] overwrite_weight uint16 final_step float64 reward \ No newline at end of file diff --git a/src/active_bo_ros/active_bo_ros/OverwriteParadigm/__init__.py b/src/active_bo_ros/active_bo_ros/OverwriteParadigm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/active_bo_ros/active_bo_ros/OverwriteParadigm/fixed_params.py b/src/active_bo_ros/active_bo_ros/OverwriteParadigm/fixed_params.py new file mode 100644 index 0000000..e69de29 diff --git a/src/active_bo_ros/active_bo_ros/active_rl_topic.py b/src/active_bo_ros/active_bo_ros/active_rl_topic.py index 1100efc..0729c09 100644 --- a/src/active_bo_ros/active_bo_ros/active_rl_topic.py +++ b/src/active_bo_ros/active_bo_ros/active_rl_topic.py @@ -1,5 +1,7 @@ -from active_bo_msgs.msg import ActiveRL +from active_bo_msgs.msg import ActiveRLRequest from active_bo_msgs.msg import ActiveRLResponse +from active_bo_msgs.msg import ActiveRLEvalRequest +from active_bo_msgs.msg import ActiveRLEvalResponse from active_bo_msgs.msg import ImageFeedback @@ -19,7 +21,7 @@ import time import copy -class ActiveRLService(Node): +class ActiveRL(Node): def __init__(self): super().__init__('active_rl_service') rl_callback_group = ReentrantCallbackGroup() @@ -30,7 +32,7 @@ class ActiveRLService(Node): self.active_rl_pub = self.create_publisher(ActiveRLResponse, 'active_rl_response', 1, callback_group=rl_callback_group) - self.active_rl_sub = self.create_subscription(ActiveRL, + self.active_rl_sub = self.create_subscription(ActiveRLRequest, 'active_rl_request', self.active_rl_callback, 1, callback_group=rl_callback_group) @@ -48,11 +50,11 @@ class ActiveRLService(Node): 1, callback_group=topic_callback_group) # Active RL Evaluation Publisher, Subscriber and Message attributes - self.eval_pub = self.create_publisher(ActiveRL, + self.eval_pub = self.create_publisher(ActiveRLEvalRequest, 'active_rl_eval_request', 1, callback_group=topic_callback_group) - self.eval_sub = self.create_subscription(ActiveRL, + self.eval_sub = self.create_subscription(ActiveRLEvalResponse, 'active_rl_eval_response', self.active_rl_eval_callback, 1, @@ -61,6 +63,7 @@ class ActiveRLService(Node): self.eval_response_received = False self.eval_policy = None self.eval_weights = None + self.overwrite_weight = None # RL Environments self.env = None @@ -69,7 +72,7 @@ class ActiveRLService(Node): self.best_pol_shown = False self.policy_sent = False self.rl_pending = False - self.interactive_run = False + self.interactive_run = 0 self.display_run = False # Main loop timer object @@ -84,6 +87,7 @@ class ActiveRLService(Node): self.rl_policy = None self.rl_weights = None self.interactive_run = 0 + self.display_run = False def active_rl_callback(self, msg): self.rl_env = msg.env @@ -117,6 +121,7 @@ class ActiveRLService(Node): def active_rl_eval_callback(self, msg): self.eval_policy = np.array(msg.policy, dtype=np.float64) self.eval_weights = msg.weights + self.overwrite_weight = msg.overwrite_weight self.get_logger().info('Active RL Eval: Responded!') self.env.reset(seed=self.rl_seed) @@ -185,7 +190,7 @@ class ActiveRLService(Node): self.rl_reward = 0.0 self.env.reset(seed=self.rl_seed) - eval_request = ActiveRL() + eval_request = ActiveRLEvalRequest() eval_request.policy = self.rl_policy.tolist() eval_request.weights = self.rl_weights @@ -214,6 +219,7 @@ class ActiveRLService(Node): rl_response.weights = self.eval_weights rl_response.reward = self.rl_reward rl_response.final_step = self.rl_step + rl_response.overwrite_weight = self.overwrite_weight self.active_rl_pub.publish(rl_response) @@ -236,7 +242,7 @@ class ActiveRLService(Node): self.rl_reward = 0.0 self.env.reset(seed=self.rl_seed) - eval_request = ActiveRL() + eval_request = ActiveRLEvalRequest() eval_request.policy = self.rl_policy.tolist() eval_request.weights = self.rl_weights @@ -260,45 +266,23 @@ class ActiveRLService(Node): rl_response.weights = self.rl_weights rl_response.reward = env_reward rl_response.final_step = step_count + if self.overwrite_weight is None: + overwrite_weight = [False] * len(self.rl_weights) + else: + overwrite_weight = self.overwrite_weight + + rl_response.overwrite_weight = overwrite_weight self.active_rl_pub.publish(rl_response) self.reset_rl_request() self.rl_pending = False - # if not self.policy_sent: - # self.rl_step = 0 - # self.rl_reward = 0.0 - # self.env.reset(seed=self.rl_seed) - # self.policy_sent = True - # done = self.next_image(self.rl_policy, self.display_run) - # - # if done: - # rl_response = ActiveRLResponse() - # rl_response.weights = self.rl_weights - # rl_response.reward = self.rl_reward - # rl_response.final_step = self.rl_step - # - # self.active_rl_pub.publish(rl_response) - # self.end_time = time.time() - # self.get_logger().info(f'RL Time: {self.end_time - self.begin_time}, mode: {self.interactive_run}') - # self.begin_time = None - # self.end_time = None - # - # # reset flags and attributes - # self.reset_eval_request() - # self.reset_rl_request() - # - # self.rl_step = 0 - # self.rl_reward = 0.0 - # - # self.rl_pending = False - def main(args=None): rclpy.init(args=args) - active_rl_service = ActiveRLService() + active_rl_service = ActiveRL() rclpy.spin(active_rl_service) diff --git a/src/active_bo_ros/active_bo_ros/interactive_bo.py b/src/active_bo_ros/active_bo_ros/interactive_bo.py index 3241910..82a0536 100644 --- a/src/active_bo_ros/active_bo_ros/interactive_bo.py +++ b/src/active_bo_ros/active_bo_ros/interactive_bo.py @@ -1,6 +1,6 @@ from active_bo_msgs.msg import ActiveBORequest from active_bo_msgs.msg import ActiveBOResponse -from active_bo_msgs.msg import ActiveRL +from active_bo_msgs.msg import ActiveRLRequest from active_bo_msgs.msg import ActiveRLResponse from active_bo_msgs.msg import ActiveBOState @@ -54,6 +54,7 @@ class ActiveBOTopic(Node): self.bo_runs = 0 self.bo_acq_fcn = None self.bo_metric_parameter = None + self.bo_metric_parameter_2 = None self.current_run = 0 self.current_episode = 0 self.seed = None @@ -61,7 +62,7 @@ class ActiveBOTopic(Node): self.save_result = False # Active Reinforcement Learning Publisher, Subscriber and Message attributes - self.active_rl_pub = self.create_publisher(ActiveRL, + self.active_rl_pub = self.create_publisher(ActiveRLRequest, 'active_rl_request', 1, callback_group=rl_callback_group) self.active_rl_sub = self.create_subscription(ActiveRLResponse, @@ -73,6 +74,7 @@ class ActiveBOTopic(Node): self.rl_weights = None self.rl_final_step = None self.rl_reward = 0.0 + self.overwrite_weight = None # State Publisher self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1) @@ -112,6 +114,7 @@ class ActiveBOTopic(Node): self.bo_runs = 0 self.bo_acq_fcn = None self.bo_metric_parameter = None + self.bo_metric_parameter_2 = None self.current_run = 0 self.current_episode = 0 self.save_result = False @@ -172,6 +175,7 @@ class ActiveBOTopic(Node): if self.rl_pending: # self.get_logger().info('Active Reinforcement Learning response received!') self.rl_weights = np.array(msg.weights, dtype=np.float64) + self.overwrite_weight = np.array(msg.overwrite_weight, dtype=bool) self.rl_final_step = msg.final_step self.rl_reward = msg.reward @@ -193,7 +197,7 @@ class ActiveBOTopic(Node): self.user_asked = False self.rl_pending = False - self.reset_rl_response() + # self.reset_rl_response() def mainloop_callback(self): if not self.active_bo_pending: @@ -223,7 +227,7 @@ class ActiveBOTopic(Node): else: seed = int(np.random.randint(1, 2147483647, 1)[0]) - rl_msg = ActiveRL() + rl_msg = ActiveRLRequest() rl_msg.env = self.bo_env rl_msg.seed = seed rl_msg.display_run = False @@ -282,7 +286,7 @@ class ActiveBOTopic(Node): np.savetxt(path, data, delimiter=',') - active_rl_request = ActiveRL() + active_rl_request = ActiveRLRequest() if self.bo_fixed_seed: seed = int(self.seed_array[0, best_policy_idx]) @@ -292,10 +296,10 @@ class ActiveBOTopic(Node): active_rl_request.env = self.bo_env active_rl_request.seed = seed + active_rl_request.display_run = True + active_rl_request.interactive_run = 1 active_rl_request.policy = self.best_policy[:, best_policy_idx].tolist() active_rl_request.weights = self.best_weights[:, best_policy_idx].tolist() - active_rl_request.interactive_run = 1 - active_rl_request.display_run = True self.active_rl_pub.publish(active_rl_request) @@ -335,7 +339,7 @@ class ActiveBOTopic(Node): if user_query.query(): self.last_query = self.current_episode self.user_asked = True - active_rl_request = ActiveRL() + active_rl_request = ActiveRLRequest() old_policy, y_max, old_weights, _ = self.BO.get_best_result() # self.get_logger().info(f'Best: {y_max}, w:{old_weights}') @@ -349,9 +353,9 @@ class ActiveBOTopic(Node): active_rl_request.env = self.bo_env active_rl_request.seed = seed active_rl_request.display_run = True + active_rl_request.interactive_run = 0 active_rl_request.policy = old_policy.tolist() active_rl_request.weights = old_weights.tolist() - active_rl_request.interactive_run = 0 # self.get_logger().info('Calling: Active RL') self.active_rl_pub.publish(active_rl_request) @@ -359,6 +363,16 @@ class ActiveBOTopic(Node): else: x_next = self.BO.next_observation() + self.get_logger().info(f'x_next: {x_next}') + self.get_logger().info(f'overwrite: {self.overwrite_weight}') + self.get_logger().info(f'rl_weights: {self.rl_weights}') + + if self.overwrite: + if self.overwrite_weight is not None and self.rl_weights is not None: + x_next[self.overwrite_weight] = self.rl_weights[self.overwrite_weight] + self.get_logger().info(f'x_next: {x_next}') + self.get_logger().info(f'overwrite: {self.overwrite_weight}') + self.get_logger().info(f'rl_weights: {self.rl_weights}') # self.get_logger().info('Next Observation BO!') self.BO.policy_model.weights = np.around(x_next, decimals=8) if self.bo_fixed_seed: @@ -366,13 +380,14 @@ class ActiveBOTopic(Node): else: seed = int(np.random.randint(1, 2147483647, 1)[0]) - rl_msg = ActiveRL() + rl_msg = ActiveRLRequest() rl_msg.env = self.bo_env rl_msg.seed = seed rl_msg.display_run = False rl_msg.interactive_run = 2 - rl_msg.weights = x_next.tolist() rl_msg.policy = self.BO.policy_model.rollout().reshape(-1,).tolist() + rl_msg.weights = x_next.tolist() + self.rl_pending = True self.active_rl_pub.publish(rl_msg) @@ -381,15 +396,12 @@ class ActiveBOTopic(Node): self.reward[self.current_episode, self.current_run] = np.max(self.BO.Y) self.get_logger().info(f'Current Episode: {self.current_episode},' f' best reward: {self.reward[self.current_episode, self.current_run]}') + else: self.best_policy[:, self.current_run], \ self.best_pol_reward[:, self.current_run], \ self.best_weights[:, self.current_run], idx = self.BO.get_best_result() - # self.get_logger().info(f'best idx: {idx}') - - # self.reward[:, self.current_run] = self.BO.best_reward.T - if self.current_run < self.bo_runs - 1: self.BO = None