From f927ce7ae63686d46b456f32b22326f974994ee7 Mon Sep 17 00:00:00 2001 From: Niko Date: Wed, 14 Jun 2023 15:38:10 +0200 Subject: [PATCH] added improvement query --- src/active_bo_msgs/msg/ActiveBORequest.msg | 3 ++- .../active_bo_ros/ReinforcementLearning/CartPole.py | 2 +- .../active_bo_ros/UserQuery/improvement_query.py | 2 +- src/active_bo_ros/active_bo_ros/active_rl_topic.py | 2 +- src/active_bo_ros/active_bo_ros/interactive_bo.py | 9 +++++++-- 5 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/active_bo_msgs/msg/ActiveBORequest.msg b/src/active_bo_msgs/msg/ActiveBORequest.msg index 65162f0..b8f5e1a 100644 --- a/src/active_bo_msgs/msg/ActiveBORequest.msg +++ b/src/active_bo_msgs/msg/ActiveBORequest.msg @@ -8,4 +8,5 @@ uint16 nr_runs string acquisition_function float32 metric_parameter uint16 metric_parameter_2 -bool save_result \ No newline at end of file +bool save_result +bool overwrite \ No newline at end of file diff --git a/src/active_bo_ros/active_bo_ros/ReinforcementLearning/CartPole.py b/src/active_bo_ros/active_bo_ros/ReinforcementLearning/CartPole.py index b2c9660..8ca5ce1 100644 --- a/src/active_bo_ros/active_bo_ros/ReinforcementLearning/CartPole.py +++ b/src/active_bo_ros/active_bo_ros/ReinforcementLearning/CartPole.py @@ -166,7 +166,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]): ) if not terminated: - reward = 1.0 + reward = 1.0 - math.pow(action, 2) * 0.1 elif self.steps_beyond_terminated is None: # Pole just fell! self.steps_beyond_terminated = 0 diff --git a/src/active_bo_ros/active_bo_ros/UserQuery/improvement_query.py b/src/active_bo_ros/active_bo_ros/UserQuery/improvement_query.py index 6278fe0..fc6c3b3 100644 --- a/src/active_bo_ros/active_bo_ros/UserQuery/improvement_query.py +++ b/src/active_bo_ros/active_bo_ros/UserQuery/improvement_query.py @@ -16,7 +16,7 @@ class ImprovementQuery: return False else: - first = self.rewards[-self.period] + first = self.rewards[-self.period-1] last = self.rewards[-1] slope = (last - first) / self.period diff --git a/src/active_bo_ros/active_bo_ros/active_rl_topic.py b/src/active_bo_ros/active_bo_ros/active_rl_topic.py index f425d35..1100efc 100644 --- a/src/active_bo_ros/active_bo_ros/active_rl_topic.py +++ b/src/active_bo_ros/active_bo_ros/active_rl_topic.py @@ -104,7 +104,7 @@ class ActiveRLService(Node): else: raise NotImplementedError - self.get_logger().info('Active RL: Called!') + # self.get_logger().info('Active RL: Called!') self.env.reset(seed=self.rl_seed) self.rl_pending = True self.policy_sent = False diff --git a/src/active_bo_ros/active_bo_ros/interactive_bo.py b/src/active_bo_ros/active_bo_ros/interactive_bo.py index a5c2434..3241910 100644 --- a/src/active_bo_ros/active_bo_ros/interactive_bo.py +++ b/src/active_bo_ros/active_bo_ros/interactive_bo.py @@ -94,6 +94,7 @@ class ActiveBOTopic(Node): self.last_query = 0 self.user_asked = False self.last_user_reward = 0.0 + self.overwrite = False # Main loop timer object self.mainloop_timer_period = 0.1 @@ -118,6 +119,7 @@ class ActiveBOTopic(Node): self.env = None self.active_bo_pending = False self.BO = None + self.overwrite = False def active_bo_callback(self, msg): if not self.active_bo_pending: @@ -135,6 +137,7 @@ class ActiveBOTopic(Node): self.bo_metric_parameter_2 = msg.metric_parameter_2 self.save_result = msg.save_result self.seed_array = np.zeros((1, self.bo_runs)) + self.overwrite = msg.overwrite # initialize self.reward = np.zeros((self.bo_episodes+self.nr_init, self.bo_runs)) @@ -375,7 +378,9 @@ class ActiveBOTopic(Node): self.active_rl_pub.publish(rl_msg) self.current_episode += 1 - # self.get_logger().info(f'Current Episode: {self.current_episode}') + self.reward[self.current_episode, self.current_run] = np.max(self.BO.Y) + self.get_logger().info(f'Current Episode: {self.current_episode},' + f' best reward: {self.reward[self.current_episode, self.current_run]}') else: self.best_policy[:, self.current_run], \ self.best_pol_reward[:, self.current_run], \ @@ -383,7 +388,7 @@ class ActiveBOTopic(Node): # self.get_logger().info(f'best idx: {idx}') - self.reward[:, self.current_run] = self.BO.best_reward.T + # self.reward[:, self.current_run] = self.BO.best_reward.T if self.current_run < self.bo_runs - 1: self.BO = None