diff --git a/src/active_bo_ros/active_bo_ros/interactive_bo_robot.py b/src/active_bo_ros/active_bo_ros/interactive_bo_robot.py
index 8b778b7..a248e8d 100644
--- a/src/active_bo_ros/active_bo_ros/interactive_bo_robot.py
+++ b/src/active_bo_ros/active_bo_ros/interactive_bo_robot.py
@@ -4,7 +4,6 @@ from active_bo_msgs.msg import ActiveBOResponse
 from active_bo_msgs.msg import ActiveRLRequest
 from active_bo_msgs.msg import ActiveRLResponse
 from active_bo_msgs.msg import ActiveBOState
-from active_bo_msgs.msg import DMP
 
 import rclpy
 from rclpy.node import Node
@@ -60,7 +59,7 @@ class ActiveBOTopic(Node):
         self.save_result = False
 
         # Active Reinforcement Learning Publisher, Subscriber and Message attributes
-        self.active_rl_pub = self.create_publisher(DMP,
+        self.active_rl_pub = self.create_publisher(ActiveRLRequest,
                                                    'active_rl_request',
                                                    1, callback_group=rl_callback_group)
         self.active_rl_sub = self.create_subscription(ActiveRLResponse,
@@ -143,7 +142,7 @@ class ActiveBOTopic(Node):
             self.overwrite = msg.overwrite
 
             # initialize
-            self.reward = np.ones((self.bo_runs, self.bo_episodes + self.nr_init - 1)) * -self.bo_steps
+            self.reward = np.ones((self.bo_runs, self.bo_episodes + self.nr_init - 1)) * -200
             self.best_pol_reward = np.ones((self.bo_runs, 1)) * -self.bo_steps
             self.best_policy = np.zeros((self.bo_runs, self.bo_steps, self.bo_nr_dims))
             self.best_weights = np.zeros((self.bo_runs, self.bo_nr_weights, self.bo_nr_dims))
@@ -204,25 +203,18 @@ class ActiveBOTopic(Node):
 
                 self.BO.reset_bo()
 
-                # self.BO.initialize()
                 self.init_pending = True
                 self.get_logger().info('BO Initialization is starting!')
-                # self.get_logger().info(f'{self.rl_pending}')
 
             if self.init_pending:
-                if self.bo_fixed_seed:
-                    seed = self.seed
-                else:
-                    seed = int(np.random.randint(1, 2147483647, 1)[0])
-
                 w = self.BO.policy_model.random_weights()
-
-                rl_msg = DMP()
+                rl_msg = ActiveRLRequest()
                 rl_msg.interactive_run = 2
-                rl_msg.p_x = w[:, 0]
-                rl_msg.p_y = w[:, 1]
+                rl_msg.weights = w.flatten('F').tolist()
+                rl_msg.policy = self.BO.policy_model.rollout().flatten('F').tolist()
+                rl_msg.nr_weights = self.bo_nr_weights
                 rl_msg.nr_steps = self.bo_steps
-                rl_msg.nr_bfs = self.bo_nr_weights
+                rl_msg.nr_dims = self.bo_nr_dims
 
                 self.active_rl_pub.publish(rl_msg)
 
@@ -271,16 +263,13 @@ class ActiveBOTopic(Node):
 
                 np.savetxt(path, data, delimiter=',')
 
-
-
-                w = self.best_weights[best_policy_idx, :, :]
-
-                rl_msg = DMP()
+                rl_msg = ActiveRLRequest()
                 rl_msg.interactive_run = 1
-                rl_msg.p_x = w[:, 0]
-                rl_msg.p_y = w[:, 1]
+                rl_msg.weights = self.best_weights[best_policy_idx, :, :].flatten('F').tolist()
+                rl_msg.policy = self.best_policy[best_policy_idx, :, :].flatten('F').tolist()
+                rl_msg.nr_weights = self.bo_nr_weights
                 rl_msg.nr_steps = self.bo_steps
-                rl_msg.nr_bfs = self.bo_nr_weights
+                rl_msg.nr_dims = self.bo_nr_dims
 
                 self.active_rl_pub.publish(rl_msg)
 
@@ -328,14 +317,13 @@ class ActiveBOTopic(Node):
 
                 self.BO.policy_model.set_weights(old_weights)
 
-                w = self.BO.policy_model.weights
-
-                rl_msg = DMP()
+                rl_msg = ActiveRLRequest()
                 rl_msg.interactive_run = 0
-                rl_msg.p_x = w[:, 0]
-                rl_msg.p_y = w[:, 1]
+                rl_msg.weights = self.BO.policy_model.weights.flatten('F').tolist()
+                rl_msg.policy = self.BO.policy_model.rollout().flatten('F').tolist()
+                rl_msg.nr_weights = self.bo_nr_weights
                 rl_msg.nr_steps = self.bo_steps
-                rl_msg.nr_bfs = self.bo_nr_weights
+                rl_msg.nr_dims = self.bo_nr_dims
 
                 self.active_rl_pub.publish(rl_msg)
 
@@ -348,14 +336,13 @@ class ActiveBOTopic(Node):
                 x_next = self.BO.next_observation()
                 self.BO.policy_model.set_weights(np.around(x_next, decimals=8))
 
-                w = self.BO.policy_model.weights
-
-                rl_msg = DMP()
+                rl_msg = ActiveRLRequest()
                 rl_msg.interactive_run = 1
-                rl_msg.p_x = w[:, 0]
-                rl_msg.p_y = w[:, 1]
+                rl_msg.weights = self.BO.policy_model.weights.flatten('F').tolist()
+                rl_msg.policy = self.BO.policy_model.rollout().flatten('F').tolist()
+                rl_msg.nr_weights = self.bo_nr_weights
                 rl_msg.nr_steps = self.bo_steps
-                rl_msg.nr_bfs = self.bo_nr_weights
+                rl_msg.nr_dims = self.bo_nr_dims
 
                 self.active_rl_pub.publish(rl_msg)
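The switch from the 2-D `DMP` fields (`p_x`, `p_y`) to `ActiveRLRequest` sends the weight and policy arrays as flat lists, flattened column-major with `flatten('F')`, and ships `nr_weights`, `nr_steps` and `nr_dims` alongside so a subscriber can restore the original shapes. A minimal sketch of the receiving side under that assumption; the helper name `unpack_rl_request` is hypothetical, the message fields are the ones set in the diff above:

```python
import numpy as np

from active_bo_msgs.msg import ActiveRLRequest


def unpack_rl_request(msg: ActiveRLRequest):
    """Undo the column-major flattening done by the publisher.

    flatten('F') on the sender must be paired with order='F' here;
    reshaping with the default order='C' would silently scramble
    the per-dimension columns.
    """
    weights = np.asarray(msg.weights).reshape((msg.nr_weights, msg.nr_dims), order='F')
    policy = np.asarray(msg.policy).reshape((msg.nr_steps, msg.nr_dims), order='F')
    return weights, policy
```

Flat arrays plus explicit shape metadata are used here because ROS 2 message fields only hold 1-D sequences; this keeps the message definition independent of the number of policy dimensions.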