diff --git a/src/active_bo_ros/active_bo_ros/active_rl_topic.py b/src/active_bo_ros/active_bo_ros/active_rl_topic.py index 41df5b4..c8b5650 100644 --- a/src/active_bo_ros/active_bo_ros/active_rl_topic.py +++ b/src/active_bo_ros/active_bo_ros/active_rl_topic.py @@ -73,7 +73,7 @@ class ActiveRLService(Node): self.display_run = False # Main loop timer object - self.mainloop_timer_period = 0.05 + self.mainloop_timer_period = 0.1 self.mainloop = self.create_timer(self.mainloop_timer_period, self.mainloop_callback, callback_group=mainloop_callback_group) diff --git a/src/active_bo_ros/active_bo_ros/interactive_bo.py b/src/active_bo_ros/active_bo_ros/interactive_bo.py index 4255f99..95bd276 100644 --- a/src/active_bo_ros/active_bo_ros/interactive_bo.py +++ b/src/active_bo_ros/active_bo_ros/interactive_bo.py @@ -74,6 +74,7 @@ class ActiveBOTopic(Node): self.rl_final_step = None self.rl_reward = 0.0 self.user_asked = False + self.last_user_reward = 0.0 # State Publisher self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1) @@ -111,6 +112,8 @@ class ActiveBOTopic(Node): self.current_episode = 0 self.save_result = False self.seed_array = None + self.env = None + self.active_bo_pending = False def active_bo_callback(self, msg): if not self.active_bo_pending: @@ -164,6 +167,10 @@ class ActiveBOTopic(Node): self.init_step = 0 self.init_pending = False + if self.user_asked: + self.last_user_reward = self.rl_reward + self.user_asked = False + self.rl_pending = False self.reset_rl_response() @@ -171,16 +178,17 @@ class ActiveBOTopic(Node): if self.active_bo_pending: # set rl environment - if self.bo_env == "Mountain Car": - self.env = Continuous_MountainCarEnv() - elif self.bo_env == "Cartpole": - self.env = CartPoleEnv() - elif self.bo_env == "Acrobot": - self.env = AcrobotEnv() - elif self.bo_env == "Pendulum": - self.env = PendulumEnv() - else: - raise NotImplementedError + if self.env is None: + if self.bo_env == "Mountain Car": + self.env = Continuous_MountainCarEnv() + elif self.bo_env == "Cartpole": + self.env = CartPoleEnv() + elif self.bo_env == "Acrobot": + self.env = AcrobotEnv() + elif self.bo_env == "Pendulum": + self.env = PendulumEnv() + else: + raise NotImplementedError if self.BO is None: self.BO = BayesianOptimization(self.env, @@ -194,7 +202,7 @@ class ActiveBOTopic(Node): # self.BO.initialize() self.init_pending = True self.get_logger().info('BO Initialization is starting!') - self.get_logger().info(f'{self.rl_pending}') + # self.get_logger().info(f'{self.rl_pending}') if self.init_pending and not self.rl_pending: @@ -283,8 +291,6 @@ class ActiveBOTopic(Node): self.get_logger().info('Responding: Active BO') self.active_bo_pub.publish(bo_response) self.reset_bo_request() - self.active_bo_pending = False - self.BO = None else: if self.rl_pending or self.init_pending: @@ -388,9 +394,7 @@ class ActiveBOTopic(Node): state_msg.current_episode = self.current_episode + 1 \ if self.current_episode < self.bo_episodes else self.bo_episodes state_msg.best_reward = float(self.best_reward) - if self.user_asked: - state_msg.last_user_reward = float(self.rl_reward) - self.user_asked = False + state_msg.last_user_reward = self.last_user_reward self.state_pub.publish(state_msg) diff --git a/src/active_bo_ros/launch/launch_active_bo.launch.py b/src/active_bo_ros/launch/launch_active_bo.launch.py index 465dc21..cf94771 100755 --- a/src/active_bo_ros/launch/launch_active_bo.launch.py +++ b/src/active_bo_ros/launch/launch_active_bo.launch.py @@ -16,14 +16,14 @@ def generate_launch_description(): ) ) - rl_launch = IncludeLaunchDescription( - PythonLaunchDescriptionSource( - os.path.join( - get_package_share_directory('active_bo_ros'), - 'rl_service.launch.py' - ) - ) - ) + # rl_launch = IncludeLaunchDescription( + # PythonLaunchDescriptionSource( + # os.path.join( + # get_package_share_directory('active_bo_ros'), + # 'rl_service.launch.py' + # ) + # ) + # ) bo_launch = IncludeLaunchDescription( PythonLaunchDescriptionSource( @@ -36,6 +36,6 @@ def generate_launch_description(): return LaunchDescription([ websocket_launch, - rl_launch, + # rl_launch, bo_launch ])