From b902e074244afe513ca876084922af645cd61621 Mon Sep 17 00:00:00 2001 From: Niko Date: Mon, 12 Jun 2023 13:57:36 +0200 Subject: [PATCH] Improved simulation time --- .../active_bo_ros/active_rl_topic.py | 93 +++++++++++++++---- .../active_bo_ros/interactive_bo.py | 51 ++++++---- 2 files changed, 108 insertions(+), 36 deletions(-) diff --git a/src/active_bo_ros/active_bo_ros/active_rl_topic.py b/src/active_bo_ros/active_bo_ros/active_rl_topic.py index c8b5650..3d2aadf 100644 --- a/src/active_bo_ros/active_bo_ros/active_rl_topic.py +++ b/src/active_bo_ros/active_bo_ros/active_rl_topic.py @@ -78,6 +78,10 @@ class ActiveRLService(Node): self.mainloop_callback, callback_group=mainloop_callback_group) + # time messurements + self.begin_time = None + self.end_time = None + def reset_rl_request(self): self.rl_env = None self.rl_seed = None @@ -86,6 +90,7 @@ class ActiveRLService(Node): self.interactive_run = 0 def active_rl_callback(self, msg): + self.begin_time = time.time() self.rl_env = msg.env self.rl_seed = msg.seed self.display_run = msg.display_run @@ -155,6 +160,27 @@ class ActiveRLService(Node): return done + def complete_run(self, policy): + env_reward = 0.0 + step_count = 0 + + self.env.reset(seed=self.rl_seed) + + for i in range(len(policy)): + action = policy[i] + action_clipped = action.clip(min=-1.0, max=1.0) + output = self.env.step(action_clipped.astype(np.float64)) + + env_reward += output[1] + done = output[2] + step_count += 1 + + if done: + break + + self.env.reset(seed=self.rl_seed) + return env_reward, step_count + def mainloop_callback(self): if self.rl_pending: if self.interactive_run == 0: @@ -195,6 +221,10 @@ class ActiveRLService(Node): rl_response.final_step = self.rl_step self.active_rl_pub.publish(rl_response) + self.end_time = time.time() + self.get_logger().info(f'RL Time: {self.end_time-self.begin_time}, mode: {self.interactive_run}') + self.begin_time = None + self.end_time = None self.env.reset(seed=self.rl_seed) @@ -232,32 +262,57 @@ class ActiveRLService(Node): self.rl_reward = 0.0 self.rl_pending = False + self.end_time = time.time() + self.get_logger().info(f'RL Time: {self.end_time - self.begin_time}, mode: {self.interactive_run}') + self.begin_time = None + self.end_time = None + elif self.interactive_run == 2: - if not self.policy_sent: - self.rl_step = 0 - self.rl_reward = 0.0 - self.env.reset(seed=self.rl_seed) - self.policy_sent = True - done = self.next_image(self.rl_policy, self.display_run) + env_reward, step_count = self.complete_run(self.rl_policy) - if done: - rl_response = ActiveRLResponse() - rl_response.weights = self.rl_weights - rl_response.reward = self.rl_reward - rl_response.final_step = self.rl_step + rl_response = ActiveRLResponse() + rl_response.weights = self.rl_weights + rl_response.reward = env_reward + rl_response.final_step = step_count - self.active_rl_pub.publish(rl_response) + self.active_rl_pub.publish(rl_response) - # reset flags and attributes - self.reset_eval_request() - self.reset_rl_request() + self.end_time = time.time() + self.get_logger().info(f'RL Time: {self.end_time - self.begin_time}, mode: {self.interactive_run}') - self.rl_step = 0 - self.rl_reward = 0.0 - - self.rl_pending = False + self.begin_time = None + self.end_time = None + self.reset_rl_request() + self.rl_pending = False + # if not self.policy_sent: + # self.rl_step = 0 + # self.rl_reward = 0.0 + # self.env.reset(seed=self.rl_seed) + # self.policy_sent = True + # done = self.next_image(self.rl_policy, self.display_run) + # + # if done: + # rl_response = ActiveRLResponse() + # rl_response.weights = self.rl_weights + # rl_response.reward = self.rl_reward + # rl_response.final_step = self.rl_step + # + # self.active_rl_pub.publish(rl_response) + # self.end_time = time.time() + # self.get_logger().info(f'RL Time: {self.end_time - self.begin_time}, mode: {self.interactive_run}') + # self.begin_time = None + # self.end_time = None + # + # # reset flags and attributes + # self.reset_eval_request() + # self.reset_rl_request() + # + # self.rl_step = 0 + # self.rl_reward = 0.0 + # + # self.rl_pending = False def main(args=None): diff --git a/src/active_bo_ros/active_bo_ros/interactive_bo.py b/src/active_bo_ros/active_bo_ros/interactive_bo.py index 95bd276..70b8037 100644 --- a/src/active_bo_ros/active_bo_ros/interactive_bo.py +++ b/src/active_bo_ros/active_bo_ros/interactive_bo.py @@ -98,6 +98,15 @@ class ActiveBOTopic(Node): self.mainloop_callback, callback_group=mainloop_callback_group) + # time messurements + self.init_begin = None + self.init_end = None + + self.rl_begin = None + self.rl_end = None + self.user_query_begin = None + self.user_query_end = None + def reset_bo_request(self): self.bo_env = None self.bo_metric = None @@ -144,6 +153,18 @@ class ActiveBOTopic(Node): else: self.seed = None + # set rl environment + if self.bo_env == "Mountain Car": + self.env = Continuous_MountainCarEnv() + elif self.bo_env == "Cartpole": + self.env = CartPoleEnv() + elif self.bo_env == "Acrobot": + self.env = AcrobotEnv() + elif self.bo_env == "Pendulum": + self.env = PendulumEnv() + else: + raise NotImplementedError + def reset_rl_response(self): self.rl_weights = None self.rl_final_step = None @@ -163,6 +184,11 @@ class ActiveBOTopic(Node): if self.init_pending: self.init_step += 1 + self.init_end = time.time() + self.get_logger().info(f'Init Time: {self.init_end-self.init_begin}') + self.init_begin = None + self.init_end = None + if self.init_step == self.nr_init: self.init_step = 0 self.init_pending = False @@ -175,20 +201,12 @@ class ActiveBOTopic(Node): self.reset_rl_response() def mainloop_callback(self): - if self.active_bo_pending: + if not self.active_bo_pending: + return - # set rl environment - if self.env is None: - if self.bo_env == "Mountain Car": - self.env = Continuous_MountainCarEnv() - elif self.bo_env == "Cartpole": - self.env = CartPoleEnv() - elif self.bo_env == "Acrobot": - self.env = AcrobotEnv() - elif self.bo_env == "Pendulum": - self.env = PendulumEnv() - else: - raise NotImplementedError + else: + if self.rl_pending: + return if self.BO is None: self.BO = BayesianOptimization(self.env, @@ -204,8 +222,8 @@ class ActiveBOTopic(Node): self.get_logger().info('BO Initialization is starting!') # self.get_logger().info(f'{self.rl_pending}') - if self.init_pending and not self.rl_pending: - + if self.init_pending: + self.init_begin = time.time() if self.bo_fixed_seed: seed = self.seed else: @@ -222,7 +240,6 @@ class ActiveBOTopic(Node): self.active_rl_pub.publish(rl_msg) self.rl_pending = True - return if self.current_run == self.bo_runs: bo_response = ActiveBOResponse() @@ -293,7 +310,7 @@ class ActiveBOTopic(Node): self.reset_bo_request() else: - if self.rl_pending or self.init_pending: + if self.init_pending: return else: if self.current_episode < self.bo_episodes: