diff --git a/src/active_bo_msgs/msg/ActiveBORequest.msg b/src/active_bo_msgs/msg/ActiveBORequest.msg index ad4bed8..65162f0 100644 --- a/src/active_bo_msgs/msg/ActiveBORequest.msg +++ b/src/active_bo_msgs/msg/ActiveBORequest.msg @@ -7,4 +7,5 @@ uint16 nr_episodes uint16 nr_runs string acquisition_function float32 metric_parameter +uint16 metric_parameter_2 bool save_result \ No newline at end of file diff --git a/src/active_bo_ros/active_bo_ros/UserQuery/improvement_query.py b/src/active_bo_ros/active_bo_ros/UserQuery/improvement_query.py index 19afd15..6278fe0 100644 --- a/src/active_bo_ros/active_bo_ros/UserQuery/improvement_query.py +++ b/src/active_bo_ros/active_bo_ros/UserQuery/improvement_query.py @@ -2,17 +2,22 @@ import numpy as np class ImprovementQuery: - def __init__(self, threshold, period): + def __init__(self, threshold, period, last_query, rewards): self.threshold = threshold self.period = period + self.last_query = last_query + self.rewards = rewards - def query(self, reward_array): - if reward_array.shape < self.period: + def query(self): + if self.rewards.shape[0] < self.period: + return False + + elif self.rewards.shape[0] < self.last_query + self.period: return False else: - first = reward_array[-self.period] - last = reward_array[-1] + first = self.rewards[-self.period] + last = self.rewards[-1] slope = (last - first) / self.period diff --git a/src/active_bo_ros/active_bo_ros/active_rl_topic.py b/src/active_bo_ros/active_bo_ros/active_rl_topic.py index 782ceea..f425d35 100644 --- a/src/active_bo_ros/active_bo_ros/active_rl_topic.py +++ b/src/active_bo_ros/active_bo_ros/active_rl_topic.py @@ -78,10 +78,6 @@ class ActiveRLService(Node): self.mainloop_callback, callback_group=mainloop_callback_group) - # time messurements - self.begin_time = None - self.end_time = None - def reset_rl_request(self): self.rl_env = None self.rl_seed = None @@ -90,7 +86,6 @@ class ActiveRLService(Node): self.interactive_run = 0 def active_rl_callback(self, msg): - self.begin_time = time.time() self.rl_env = msg.env self.rl_seed = msg.seed self.display_run = msg.display_run @@ -221,10 +216,6 @@ class ActiveRLService(Node): rl_response.final_step = self.rl_step self.active_rl_pub.publish(rl_response) - self.end_time = time.time() - # self.get_logger().info(f'RL Time: {self.end_time-self.begin_time}, mode: {self.interactive_run}') - self.begin_time = None - self.end_time = None self.env.reset(seed=self.rl_seed) @@ -262,11 +253,6 @@ class ActiveRLService(Node): self.rl_reward = 0.0 self.rl_pending = False - self.end_time = time.time() - # self.get_logger().info(f'RL Time: {self.end_time - self.begin_time}, mode: {self.interactive_run}') - self.begin_time = None - self.end_time = None - elif self.interactive_run == 2: env_reward, step_count = self.complete_run(self.rl_policy) @@ -277,12 +263,6 @@ class ActiveRLService(Node): self.active_rl_pub.publish(rl_response) - self.end_time = time.time() - # self.get_logger().info(f'RL Time: {self.end_time - self.begin_time}, mode: {self.interactive_run}') - - self.begin_time = None - self.end_time = None - self.reset_rl_request() self.rl_pending = False diff --git a/src/active_bo_ros/active_bo_ros/interactive_bo.py b/src/active_bo_ros/active_bo_ros/interactive_bo.py index 1c57c6b..a5c2434 100644 --- a/src/active_bo_ros/active_bo_ros/interactive_bo.py +++ b/src/active_bo_ros/active_bo_ros/interactive_bo.py @@ -73,8 +73,6 @@ class ActiveBOTopic(Node): self.rl_weights = None self.rl_final_step = None self.rl_reward = 0.0 - self.user_asked = False - self.last_user_reward = 0.0 # State Publisher self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1) @@ -92,21 +90,17 @@ class ActiveBOTopic(Node): self.best_policy = None self.best_weights = None + # User Query + self.last_query = 0 + self.user_asked = False + self.last_user_reward = 0.0 + # Main loop timer object self.mainloop_timer_period = 0.1 self.mainloop = self.create_timer(self.mainloop_timer_period, self.mainloop_callback, callback_group=mainloop_callback_group) - # time messurements - self.init_begin = None - self.init_end = None - - self.rl_begin = None - self.rl_end = None - self.user_query_begin = None - self.user_query_end = None - def reset_bo_request(self): self.bo_env = None self.bo_metric = None @@ -138,6 +132,7 @@ class ActiveBOTopic(Node): self.bo_runs = msg.nr_runs self.bo_acq_fcn = msg.acquisition_function self.bo_metric_parameter = msg.metric_parameter + self.bo_metric_parameter_2 = msg.metric_parameter_2 self.save_result = msg.save_result self.seed_array = np.zeros((1, self.bo_runs)) @@ -185,10 +180,6 @@ class ActiveBOTopic(Node): if self.init_pending: self.init_step += 1 - self.init_end = time.time() - self.get_logger().info(f'Init Time: {self.init_end-self.init_begin}') - self.init_begin = None - self.init_end = None if self.init_step == self.nr_init: self.init_step = 0 @@ -224,7 +215,6 @@ class ActiveBOTopic(Node): # self.get_logger().info(f'{self.rl_pending}') if self.init_pending: - self.init_begin = time.time() if self.bo_fixed_seed: seed = self.seed else: @@ -322,6 +312,12 @@ class ActiveBOTopic(Node): elif self.bo_metric == "regular": user_query = RegularQuery(self.bo_metric_parameter, self.current_episode) + elif self.bo_metric == "improvement": + user_query = ImprovementQuery(self.bo_metric_parameter, + self.bo_metric_parameter_2, + self.last_query, + self.reward[:self.current_episode, self.current_run]) + elif self.bo_metric == "max acquisition": user_query = MaxAcqQuery(self.bo_metric_parameter, self.BO.GP, @@ -330,13 +326,11 @@ class ActiveBOTopic(Node): acq=self.bo_acq_fcn, X=self.BO.X) - elif self.bo_metric == "improvement": - user_query = ImprovementQuery(self.bo_metric_parameter, 10) - else: raise NotImplementedError if user_query.query(): + self.last_query = self.current_episode self.user_asked = True active_rl_request = ActiveRL() old_policy, y_max, old_weights, _ = self.BO.get_best_result() @@ -395,6 +389,7 @@ class ActiveBOTopic(Node): self.BO = None self.current_episode = 0 + self.last_query = 0 if self.bo_fixed_seed: self.seed_array[0, self.current_run] = self.seed self.seed = int(np.random.randint(1, 2147483647, 1)[0])