diff --git a/src/active_bo_msgs/msg/DMP.msg b/src/active_bo_msgs/msg/DMP.msg index ce764e5..039fc69 100644 --- a/src/active_bo_msgs/msg/DMP.msg +++ b/src/active_bo_msgs/msg/DMP.msg @@ -2,6 +2,7 @@ float64[7] start_point float64[7] end_point float64 time +int32 nr_steps int32 nr_bfs # weights for the dimensions of the Pose (3 Position, 4 Orientation) @@ -11,4 +12,7 @@ float64[] p_z float64[] o_x float64[] o_y float64[] o_z -float64[] o_w \ No newline at end of file +float64[] o_w + +# BO parameters +uint8 interactive_run diff --git a/src/active_bo_ros/active_bo_ros/interactive_bo_robot.py b/src/active_bo_ros/active_bo_ros/interactive_bo_robot.py new file mode 100644 index 0000000..8b778b7 --- /dev/null +++ b/src/active_bo_ros/active_bo_ros/interactive_bo_robot.py @@ -0,0 +1,419 @@ + +from active_bo_msgs.msg import ActiveBORequest +from active_bo_msgs.msg import ActiveBOResponse +from active_bo_msgs.msg import ActiveRLRequest +from active_bo_msgs.msg import ActiveRLResponse +from active_bo_msgs.msg import ActiveBOState +from active_bo_msgs.msg import DMP + +import rclpy +from rclpy.node import Node + +from rclpy.callback_groups import ReentrantCallbackGroup + +from active_bo_ros.BayesianOptimization.BO2D import BayesianOptimization + +from active_bo_ros.UserQuery.random_query import RandomQuery +from active_bo_ros.UserQuery.regular_query import RegularQuery +from active_bo_ros.UserQuery.improvement_query import ImprovementQuery +from active_bo_ros.UserQuery.max_acq_query import MaxAcqQuery + +import numpy as np +import time +import os + + +class ActiveBOTopic(Node): + def __init__(self): + super().__init__('active_bo_topic') + + bo_callback_group = ReentrantCallbackGroup() + rl_callback_group = ReentrantCallbackGroup() + mainloop_callback_group = ReentrantCallbackGroup() + + # Active Bayesian Optimization Publisher, Subscriber and Message attributes + self.active_bo_pub = self.create_publisher(ActiveBOResponse, + 'active_bo_response', + 1, callback_group=bo_callback_group) + + self.active_bo_sub = self.create_subscription(ActiveBORequest, + 'active_bo_request', + self.bo_callback, + 1, callback_group=bo_callback_group) + + self.active_bo_pending = False + self.bo_env = None + self.bo_metric = None + self.bo_fixed_seed = False + self.bo_nr_weights = None + self.bo_nr_dims = 2 + self.bo_steps = 0 + self.bo_episodes = 0 + self.bo_runs = 0 + self.bo_acq_fcn = None + self.bo_metric_parameter = None + self.bo_metric_parameter_2 = None + self.current_run = 0 + self.current_episode = 0 + self.seed = None + self.seed_array = None + self.save_result = False + + # Active Reinforcement Learning Publisher, Subscriber and Message attributes + self.active_rl_pub = self.create_publisher(DMP, + 'active_rl_request', + 1, callback_group=rl_callback_group) + self.active_rl_sub = self.create_subscription(ActiveRLResponse, + 'active_rl_response', + self.rl_callback, + 1, callback_group=rl_callback_group) + + self.rl_pending = False + self.rl_weights = None + self.rl_final_step = None + self.rl_reward = 0.0 + self.weight_preference = None + + # State Publisher + self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1) + + # RL Environments and BO + self.env = None + + self.BO = None + self.nr_init = 3 + self.init_step = 0 + self.init_pending = False + self.reward = None + self.best_reward = 0.0 + self.best_pol_reward = None + self.best_policy = None + self.best_weights = None + + # User Query + self.last_query = 0 + self.user_asked = False + self.last_user_reward = 0.0 + self.overwrite = False + + # Main loop timer object + self.mainloop_timer_period = 0.1 + self.mainloop = self.create_timer(self.mainloop_timer_period, + self.mainloop_callback, + callback_group=mainloop_callback_group) + + def reset_bo_request(self): + self.bo_env = None + self.bo_metric = None + self.bo_fixed_seed = False + self.bo_nr_weights = None + self.bo_nr_dims = None + self.bo_steps = 0 + self.bo_episodes = 0 + self.bo_runs = 0 + self.bo_acq_fcn = None + self.bo_metric_parameter = None + self.bo_metric_parameter_2 = None + self.current_run = 0 + self.current_episode = 0 + self.save_result = False + self.seed_array = None + self.env = None + self.active_bo_pending = False + self.BO = None + self.overwrite = False + + def bo_callback(self, msg): + if not self.active_bo_pending: + # self.get_logger().info('Active Bayesian Optimization request pending!') + self.active_bo_pending = True + self.bo_env = msg.env + self.bo_metric = msg.metric + self.bo_fixed_seed = msg.fixed_seed + self.bo_nr_weights = msg.nr_weights + self.bo_nr_dims = msg.nr_dims + self.bo_steps = msg.max_steps + self.bo_episodes = msg.nr_episodes + self.bo_runs = msg.nr_runs + self.bo_acq_fcn = msg.acquisition_function + self.bo_metric_parameter = msg.metric_parameter + self.bo_metric_parameter_2 = msg.metric_parameter_2 + self.save_result = msg.save_result + self.seed_array = np.zeros((self.bo_runs, 1)) + self.overwrite = msg.overwrite + + # initialize + self.reward = np.ones((self.bo_runs, self.bo_episodes + self.nr_init - 1)) * -self.bo_steps + self.best_pol_reward = np.ones((self.bo_runs, 1)) * -self.bo_steps + self.best_policy = np.zeros((self.bo_runs, self.bo_steps, self.bo_nr_dims)) + self.best_weights = np.zeros((self.bo_runs, self.bo_nr_weights, self.bo_nr_dims)) + + # set the seed + if self.bo_fixed_seed: + self.seed = int(np.random.randint(1, 2147483647, 1)[0]) + else: + self.seed = None + + def reset_rl_response(self): + self.rl_weights = None + self.rl_final_step = None + + def rl_callback(self, msg): + if self.rl_pending: + # self.get_logger().info('Active Reinforcement Learning response received!') + self.rl_weights = np.array(msg.weights, dtype=np.float64) + self.weight_preference = np.array(msg.weight_preference, dtype=bool) + self.rl_final_step = msg.final_step + self.rl_reward = msg.reward + + try: + self.BO.add_observation(self.rl_reward, self.rl_weights) + # self.get_logger().info('Active Reinforcement Learning added new observation!') + except Exception as e: + self.get_logger().error(f'Active Reinforcement Learning failed to add new observation: {e}') + + if self.init_pending: + self.init_step += 1 + + if self.init_step == self.nr_init: + self.init_step = 0 + self.init_pending = False + + if self.user_asked: + self.last_user_reward = self.rl_reward + if self.bo_acq_fcn == "Preference Expected Improvement": + self.BO.acq_fun.update_proposal_model(self.rl_weights, self.weight_preference) + self.user_asked = False + + self.rl_pending = False + # self.reset_rl_response() + + def mainloop_callback(self): + if not self.active_bo_pending: + return + + else: + if self.rl_pending: + return + + if self.BO is None: + self.BO = BayesianOptimization(self.bo_steps, + self.bo_nr_dims, + self.bo_nr_weights, + acq=self.bo_acq_fcn) + + self.BO.reset_bo() + + # self.BO.initialize() + self.init_pending = True + self.get_logger().info('BO Initialization is starting!') + # self.get_logger().info(f'{self.rl_pending}') + + if self.init_pending: + if self.bo_fixed_seed: + seed = self.seed + else: + seed = int(np.random.randint(1, 2147483647, 1)[0]) + + w = self.BO.policy_model.random_weights() + + rl_msg = DMP() + rl_msg.interactive_run = 2 + rl_msg.p_x = w[:, 0] + rl_msg.p_y = w[:, 1] + rl_msg.nr_steps = self.bo_steps + rl_msg.nr_bfs = self.bo_nr_weights + + self.active_rl_pub.publish(rl_msg) + + self.rl_pending = True + + if self.current_run == self.bo_runs: + bo_response = ActiveBOResponse() + + best_policy_idx = np.argmax(self.best_pol_reward) + bo_response.best_policy = self.best_policy[best_policy_idx, :, :].flatten('F').tolist() + bo_response.best_weights = self.best_weights[best_policy_idx, :, :].flatten('F').tolist() + bo_response.nr_weights = self.bo_nr_weights + bo_response.nr_steps = self.bo_steps + + bo_response.reward_mean = np.mean(self.reward, axis=0).tolist() + bo_response.reward_std = np.std(self.reward, axis=0).tolist() + + if self.save_result: + if self.bo_env == "Reacher": + env = 're' + elif self.bo_env == "Finger": + env = 'fin' + else: + raise NotImplementedError + + if self.bo_acq_fcn == "Expected Improvement": + acq = 'ei' + elif self.bo_acq_fcn == "Probability of Improvement": + acq = 'pi' + elif self.bo_acq_fcn == "Upper Confidence Bound": + acq = 'cb' + elif self.bo_acq_fcn == "Preference Expected Improvement": + acq = 'pei' + else: + raise NotImplementedError + + home_dir = os.path.expanduser('~') + file_path = os.path.join(home_dir, 'Documents/IntRLResults') + filename = env + '-' + acq + '-' + self.bo_metric + '-' \ + + str(round(self.bo_metric_parameter, 2)) + '-' \ + + str(self.bo_nr_weights) + '-' + str(time.time()) + filename = filename.replace('.', '_') + '.csv' + path = os.path.join(file_path, filename) + + data = self.reward + + np.savetxt(path, data, delimiter=',') + + + + w = self.best_weights[best_policy_idx, :, :] + + rl_msg = DMP() + rl_msg.interactive_run = 1 + rl_msg.p_x = w[:, 0] + rl_msg.p_y = w[:, 1] + rl_msg.nr_steps = self.bo_steps + rl_msg.nr_bfs = self.bo_nr_weights + + self.active_rl_pub.publish(rl_msg) + + self.rl_pending = True + + self.active_rl_pub.publish(rl_msg) + + self.get_logger().info('Responding: Active BO') + self.active_bo_pub.publish(bo_response) + self.reset_bo_request() + + else: + if self.init_pending: + return + else: + if self.current_episode < self.bo_episodes + self.nr_init - 1: + # metrics + if self.bo_metric == "random": + user_query = RandomQuery(self.bo_metric_parameter) + + elif self.bo_metric == "regular": + user_query = RegularQuery(self.bo_metric_parameter, self.current_episode) + + elif self.bo_metric == "improvement": + user_query = ImprovementQuery(self.bo_metric_parameter, + self.bo_metric_parameter_2, + self.last_query, + self.reward[self.current_run, :self.current_episode]) + + elif self.bo_metric == "max acquisition": + user_query = MaxAcqQuery(self.bo_metric_parameter, + self.BO.GP, + 100, + self.bo_nr_weights, + acq=self.bo_acq_fcn, + X=self.BO.X) + + else: + raise NotImplementedError + + if user_query.query(): + self.last_query = self.current_episode + self.user_asked = True + y_max, old_weights, _ = self.BO.get_best_result() + + self.BO.policy_model.set_weights(old_weights) + + w = self.BO.policy_model.weights + + rl_msg = DMP() + rl_msg.interactive_run = 0 + rl_msg.p_x = w[:, 0] + rl_msg.p_y = w[:, 1] + rl_msg.nr_steps = self.bo_steps + rl_msg.nr_bfs = self.bo_nr_weights + + self.active_rl_pub.publish(rl_msg) + + self.rl_pending = True + + self.active_rl_pub.publish(rl_msg) + self.rl_pending = True + + else: + x_next = self.BO.next_observation() + self.BO.policy_model.set_weights(np.around(x_next, decimals=8)) + + w = self.BO.policy_model.weights + + rl_msg = DMP() + rl_msg.interactive_run = 1 + rl_msg.p_x = w[:, 0] + rl_msg.p_y = w[:, 1] + rl_msg.nr_steps = self.bo_steps + rl_msg.nr_bfs = self.bo_nr_weights + + self.active_rl_pub.publish(rl_msg) + + self.rl_pending = True + + self.reward[self.current_run, self.current_episode] = np.max(self.BO.Y) + self.get_logger().info(f'Current Episode: {self.current_episode},' + f' best reward: {self.reward[self.current_run, self.current_episode]}') + self.current_episode += 1 + + else: + self.best_pol_reward[self.current_run, :], \ + self.best_weights[self.current_run, :, :], idx = self.BO.get_best_result() + + self.BO.policy_model.weights = self.best_weights[self.current_run, :, :] + self.best_policy[self.current_run, :, :] = self.BO.policy_model.rollout() + + if self.current_run < self.bo_runs - 1: + self.BO = None + + self.current_episode = 0 + self.last_query = 0 + if self.bo_fixed_seed: + self.seed_array[self.current_run, 0] = self.seed + self.seed = int(np.random.randint(1, 2147483647, 1)[0]) + # self.get_logger().info(f'{self.seed}') + self.current_run += 1 + self.get_logger().info(f'Current Run: {self.current_run}') + + # send the current states + + if self.BO is not None and self.BO.Y is not None: + self.best_reward = np.max(self.BO.Y) + + state_msg = ActiveBOState() + state_msg.current_run = self.current_run + 1 if self.current_run < self.bo_runs else self.bo_runs + state_msg.current_episode = self.current_episode \ + if self.current_episode < self.bo_episodes else self.bo_episodes + state_msg.best_reward = float(self.best_reward) + state_msg.last_user_reward = self.last_user_reward + self.state_pub.publish(state_msg) + + +def main(args=None): + rclpy.init(args=args) + + active_bo_topic = ActiveBOTopic() + + rclpy.spin(active_bo_topic) + + try: + rclpy.spin(active_bo_topic) + except KeyboardInterrupt: + pass + + active_bo_topic.destroy_node() + rclpy.shutdown() + + +if __name__ == '__main__': + main()