diff --git a/src/active_bo_ros/active_bo_ros/active_bo_topic.py b/src/active_bo_ros/active_bo_ros/active_bo_topic.py index 5dd0da9..a996e32 100644 --- a/src/active_bo_ros/active_bo_ros/active_bo_topic.py +++ b/src/active_bo_ros/active_bo_ros/active_bo_topic.py @@ -15,6 +15,11 @@ from active_bo_ros.ReinforcementLearning.CartPole import CartPoleEnv from active_bo_ros.ReinforcementLearning.Pendulum import PendulumEnv from active_bo_ros.ReinforcementLearning.Acrobot import AcrobotEnv +from active_bo_ros.UserQuery.random_query import RandomQuery +from active_bo_ros.UserQuery.regular_query import RegularQuery +from active_bo_ros.UserQuery.improvement_query import ImprovementQuery +from active_bo_ros.UserQuery.max_acq_query import MaxAcqQuery + import numpy as np import time @@ -39,6 +44,7 @@ class ActiveBOTopic(Node): self.active_bo_pending = False self.bo_env = None + self.bo_metric = None self.bo_fixed_seed = False self.bo_nr_weights = None self.bo_steps = None @@ -82,6 +88,7 @@ class ActiveBOTopic(Node): def reset_bo_request(self): self.bo_env = None + self.bo_metric = None self.bo_fixed_seed = False self.bo_nr_weights = None self.bo_steps = None @@ -97,6 +104,7 @@ class ActiveBOTopic(Node): self.get_logger().info('Active Bayesian Optimization request pending!') self.active_bo_pending = True self.bo_env = msg.env + self.bo_metric = msg.metric self.bo_fixed_seed = msg.fixed_seed self.bo_nr_weights = msg.nr_weights self.bo_steps = msg.max_steps @@ -181,12 +189,24 @@ class ActiveBOTopic(Node): self.get_logger().error(f'Active Reinforcement Learning failed to add new observation: {e}') else: if self.current_episode < self.bo_episodes: - if np.random.uniform(0.0, 1.0, 1) < self.bo_metric_parameter: + # metrics + if self.bo_metric == "RandomQuery": + user_query = RandomQuery(self.bo_metric_parameter) + + else: + raise NotImplementedError + + if user_query.query(): active_rl_request = ActiveRL() old_policy, _, old_weights = self.BO.get_best_result() + if self.seed is None: + seed = int(np.random.randint(1, 2147483647, 1)[0]) + else: + seed = self.seed + active_rl_request.env = self.bo_env - active_rl_request.seed = self.seed + active_rl_request.seed = seed active_rl_request.policy = old_policy.tolist() active_rl_request.weights = old_weights.tolist() diff --git a/src/active_bo_ros/active_bo_ros/active_rl_topic.py b/src/active_bo_ros/active_bo_ros/active_rl_topic.py index 56d994f..0fd524d 100644 --- a/src/active_bo_ros/active_bo_ros/active_rl_topic.py +++ b/src/active_bo_ros/active_bo_ros/active_rl_topic.py @@ -35,7 +35,6 @@ class ActiveRLService(Node): self.active_rl_callback, 1, callback_group=rl_callback_group) - self.active_rl_pending = False self.rl_env = None self.rl_seed = None self.rl_policy = None @@ -66,7 +65,10 @@ class ActiveRLService(Node): # RL Environments self.env = None + # State Machine Variables self.best_pol_shown = False + self.policy_sent = False + self.active_rl_pending = False # Main loop timer object self.mainloop_timer_period = 0.05 @@ -100,6 +102,8 @@ class ActiveRLService(Node): self.get_logger().info('Active RL: Called!') self.env.reset(seed=self.rl_seed) self.active_rl_pending = True + self.policy_sent = False + self.rl_step = 0 def reset_eval_request(self): self.eval_policy = None @@ -148,9 +152,7 @@ class ActiveRLService(Node): def mainloop_callback(self): if self.active_rl_pending: if not self.best_pol_shown: - done = self.next_image(self.rl_policy) - - if done: + if not self.policy_sent: self.rl_step = 0 self.rl_reward = 0.0 self.env.reset(seed=self.rl_seed) @@ -163,7 +165,13 @@ class ActiveRLService(Node): self.get_logger().info('Active RL: Called!') self.get_logger().info('Active RL: Waiting for Eval!') + self.policy_sent = True + + done = self.next_image(self.rl_policy) + + if done: self.best_pol_shown = True + self.rl_step = 0 elif self.best_pol_shown: if not self.eval_response_received: diff --git a/src/active_bo_ros/setup.py b/src/active_bo_ros/setup.py index 5e1e7ce..ea186d7 100644 --- a/src/active_bo_ros/setup.py +++ b/src/active_bo_ros/setup.py @@ -11,7 +11,9 @@ setup( package_name + '/PolicyModel', package_name + '/ReinforcementLearning', package_name + '/AcquisitionFunctions', - package_name + '/BayesianOptimization'], + package_name + '/BayesianOptimization', + package_name + '/UserQuery', + ], data_files=[ ('share/ament_index/resource_index/packages', ['resource/' + package_name]),