diff --git a/src/active_bo_msgs/CMakeLists.txt b/src/active_bo_msgs/CMakeLists.txt
index 414e6be..f9f39d0 100644
--- a/src/active_bo_msgs/CMakeLists.txt
+++ b/src/active_bo_msgs/CMakeLists.txt
@@ -30,6 +30,7 @@ rosidl_generate_interfaces(${PROJECT_NAME}
   "msg/ActiveRLResponse.msg"
   "msg/ActiveRL.msg"
   "msg/ImageFeedback.msg"
+  "msg/ActiveBOState.msg"
 )
diff --git a/src/active_bo_msgs/msg/ActiveBORequest.msg b/src/active_bo_msgs/msg/ActiveBORequest.msg
index 6cdc248..ad4bed8 100644
--- a/src/active_bo_msgs/msg/ActiveBORequest.msg
+++ b/src/active_bo_msgs/msg/ActiveBORequest.msg
@@ -6,4 +6,5 @@ uint16 max_steps
 uint16 nr_episodes
 uint16 nr_runs
 string acquisition_function
-float32 metric_parameter
\ No newline at end of file
+float32 metric_parameter
+bool save_result
\ No newline at end of file
diff --git a/src/active_bo_msgs/msg/ActiveBOState.msg b/src/active_bo_msgs/msg/ActiveBOState.msg
new file mode 100644
index 0000000..c6df26e
--- /dev/null
+++ b/src/active_bo_msgs/msg/ActiveBOState.msg
@@ -0,0 +1,4 @@
+uint16 current_run
+uint16 current_episode
+float32 best_reward
+float32 last_user_reward
\ No newline at end of file
diff --git a/src/active_bo_ros/active_bo_ros/active_bo_topic.py b/src/active_bo_ros/active_bo_ros/active_bo_topic.py
index d66e77d..1b38efd 100644
--- a/src/active_bo_ros/active_bo_ros/active_bo_topic.py
+++ b/src/active_bo_ros/active_bo_ros/active_bo_topic.py
@@ -2,6 +2,7 @@ from active_bo_msgs.msg import ActiveBORequest
 from active_bo_msgs.msg import ActiveBOResponse
 from active_bo_msgs.msg import ActiveRL
 from active_bo_msgs.msg import ActiveRLResponse
+from active_bo_msgs.msg import ActiveBOState
 
 import rclpy
 from rclpy.node import Node
@@ -22,6 +23,7 @@ from active_bo_ros.UserQuery.max_acq_query import MaxAcqQuery
 
 import numpy as np
 import time
+import os
 
 
 class ActiveBOTopic(Node):
@@ -47,14 +49,15 @@ class ActiveBOTopic(Node):
         self.bo_metric = None
         self.bo_fixed_seed = False
         self.bo_nr_weights = None
-        self.bo_steps = None
-        self.bo_episodes = None
-        self.bo_runs = None
+        self.bo_steps = 0
+        self.bo_episodes = 0
+        self.bo_runs = 0
         self.bo_acq_fcn = None
         self.bo_metric_parameter = None
         self.current_run = 0
         self.current_episode = 0
         self.seed = None
+        self.save_result = False
 
         # Active Reinforcement Learning Publisher, Subscriber and Message attributes
         self.active_rl_pub = self.create_publisher(ActiveRL,
@@ -68,7 +71,10 @@ class ActiveBOTopic(Node):
         self.active_rl_pending = False
         self.rl_weights = None
         self.rl_final_step = None
-        self.rl_reward = None
+        self.rl_reward = 0.0
+
+        # State Publisher
+        self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1)
 
         # RL Environments and BO
         self.env = None
@@ -76,6 +82,7 @@ class ActiveBOTopic(Node):
         self.BO = None
         self.nr_init = 3
         self.reward = None
+        self.best_reward = 0.0
         self.best_pol_reward = None
         self.best_policy = None
         self.best_weights = None
@@ -91,13 +98,14 @@ class ActiveBOTopic(Node):
         self.bo_metric = None
         self.bo_fixed_seed = False
         self.bo_nr_weights = None
-        self.bo_steps = None
-        self.bo_episodes = None
-        self.bo_runs = None
+        self.bo_steps = 0
+        self.bo_episodes = 0
+        self.bo_runs = 0
         self.bo_acq_fcn = None
         self.bo_metric_parameter = None
         self.current_run = 0
         self.current_episode = 0
+        self.save_result = False
 
     def active_bo_callback(self, msg):
         if not self.active_bo_pending:
@@ -112,6 +120,7 @@ class ActiveBOTopic(Node):
             self.bo_runs = msg.nr_runs
             self.bo_acq_fcn = msg.acquisition_function
             self.bo_metric_parameter = msg.metric_parameter
+            self.save_result = msg.save_result
 
             # initialize
             self.reward = np.zeros((self.bo_episodes, self.bo_runs))
@@ -128,7 +137,6 @@ class ActiveBOTopic(Node):
     def reset_rl_response(self):
         self.rl_weights = None
         self.rl_final_step = None
-        self.rl_reward = None
 
     def active_rl_callback(self, msg):
         if self.active_rl_pending:
@@ -180,6 +188,18 @@ class ActiveBOTopic(Node):
             bo_response.reward_mean = np.mean(self.reward, axis=1).tolist()
             bo_response.reward_std = np.std(self.reward, axis=1).tolist()
 
+            if self.save_result:
+                home_dir = os.path.expanduser('~')
+                file_path = os.path.join(home_dir, 'Documents/IntRLResults')
+                filename = self.bo_metric + '-'\
+                           + str(self.bo_metric_parameter) + '-' \
+                           + str(self.bo_nr_weights) + '-' + str(time.time())
+                filename = filename.replace('.', '_') + '.csv'
+                path = os.path.join(file_path, filename)
+                self.get_logger().info(path)
+
+                np.savetxt(path, self.reward, delimiter=',')
+
             self.get_logger().info('Responding: Active BO')
             self.active_bo_pub.publish(bo_response)
             self.reset_bo_request()
@@ -249,6 +269,18 @@ class ActiveBOTopic(Node):
                     self.current_run += 1
                     self.get_logger().info(f'Current Run: {self.current_run}')
 
+            # send the current states
+
+            if self.BO is not None and self.BO.Y is not None:
+                self.best_reward = np.max(self.BO.Y)
+
+            state_msg = ActiveBOState()
+            state_msg.current_run = self.current_run + 1 if self.current_run < self.bo_runs else self.bo_runs
+            state_msg.current_episode = self.current_episode + 1\
+                if self.current_episode < self.bo_episodes else self.bo_episodes
+            state_msg.best_reward = self.best_reward
+            state_msg.last_user_reward = self.rl_reward
+            self.state_pub.publish(state_msg)
 
 def main(args=None):
     rclpy.init(args=args)
diff --git a/src/active_bo_ros/active_bo_ros/active_rl_topic.py b/src/active_bo_ros/active_bo_ros/active_rl_topic.py
index 0fd524d..b2a6106 100644
--- a/src/active_bo_ros/active_bo_ros/active_rl_topic.py
+++ b/src/active_bo_ros/active_bo_ros/active_rl_topic.py
@@ -172,6 +172,7 @@ class ActiveRLService(Node):
                 if done:
                     self.best_pol_shown = True
                     self.rl_step = 0
+                    self.rl_reward = 0.0
 
             elif self.best_pol_shown:
                 if not self.eval_response_received:
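For reference, here is a minimal sketch (not part of the patch) of an rclpy subscriber for the new `active_bo_state` topic introduced above. The topic name, queue depth, and `ActiveBOState` fields come from the diff; the node name `bo_state_listener` and the callback are hypothetical.

```python
from active_bo_msgs.msg import ActiveBOState

import rclpy
from rclpy.node import Node


class BOStateListener(Node):
    def __init__(self):
        # Hypothetical node name, for illustration only.
        super().__init__('bo_state_listener')
        # Queue depth 1 mirrors the publisher created in active_bo_topic.py.
        self.create_subscription(ActiveBOState,
                                 'active_bo_state',
                                 self.state_callback,
                                 1)

    def state_callback(self, msg):
        # Fields as declared in msg/ActiveBOState.msg.
        self.get_logger().info(
            f'run {msg.current_run}, episode {msg.current_episode}: '
            f'best_reward={msg.best_reward:.3f}, '
            f'last_user_reward={msg.last_user_reward:.3f}')


def main(args=None):
    rclpy.init(args=args)
    node = BOStateListener()
    rclpy.spin(node)
    node.destroy_node()
    rclpy.shutdown()


if __name__ == '__main__':
    main()
```

Running this assumes the `active_bo_msgs` package has been built and sourced so that the generated `ActiveBOState` Python module is importable.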