From 3a8acb6807af9a0741444d92bad5c8968983216e Mon Sep 17 00:00:00 2001 From: Niko Date: Mon, 5 Jun 2023 14:56:14 +0200 Subject: [PATCH] debugging regular --- .../active_bo_ros/UserQuery/regular_query.py | 15 +++++++-------- .../active_bo_ros/active_bo_topic.py | 4 ++-- .../active_bo_ros/active_rl_topic.py | 2 +- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/active_bo_ros/active_bo_ros/UserQuery/regular_query.py b/src/active_bo_ros/active_bo_ros/UserQuery/regular_query.py index b8949e3..7349859 100644 --- a/src/active_bo_ros/active_bo_ros/UserQuery/regular_query.py +++ b/src/active_bo_ros/active_bo_ros/UserQuery/regular_query.py @@ -1,13 +1,12 @@ class RegularQuery: - def __init__(self, regular): - self.regular = regular - self.counter = 0 + def __init__(self, regular, episode): + self.regular = int(regular) + self.counter = episode def query(self): - if self.counter < self.regular: - self.counter += 1 - return False + + if self.counter % self.regular == 0 and self.counter != 0: + return True else: - self.counter = 0 - return True + return False diff --git a/src/active_bo_ros/active_bo_ros/active_bo_topic.py b/src/active_bo_ros/active_bo_ros/active_bo_topic.py index b9e17dd..b9284ff 100644 --- a/src/active_bo_ros/active_bo_ros/active_bo_topic.py +++ b/src/active_bo_ros/active_bo_ros/active_bo_topic.py @@ -212,7 +212,7 @@ class ActiveBOTopic(Node): home_dir = os.path.expanduser('~') file_path = os.path.join(home_dir, 'Documents/IntRLResults') filename = env + '-' + acq + '-' + self.bo_metric + '-' \ - + str(self.bo_metric_parameter) + '-' \ + + str(round(self.bo_metric_parameter, 2)) + '-' \ + str(self.bo_nr_weights) + '-' + str(time.time()) filename = filename.replace('.', '_') + '.csv' path = os.path.join(file_path, filename) @@ -251,7 +251,7 @@ class ActiveBOTopic(Node): user_query = RandomQuery(self.bo_metric_parameter) elif self.bo_metric == "regular": - user_query = RegularQuery(self.bo_metric_parameter) + user_query = RegularQuery(self.bo_metric_parameter, self.current_episode) elif self.bo_metric == "max acquisition": user_query = MaxAcqQuery(self.bo_metric_parameter, diff --git a/src/active_bo_ros/active_bo_ros/active_rl_topic.py b/src/active_bo_ros/active_bo_ros/active_rl_topic.py index 73c579a..9a5d56a 100644 --- a/src/active_bo_ros/active_bo_ros/active_rl_topic.py +++ b/src/active_bo_ros/active_bo_ros/active_rl_topic.py @@ -186,7 +186,7 @@ class ActiveRLService(Node): if done: rl_response = ActiveRLResponse() - rl_response.weights = self.rl_weights + rl_response.weights = self.eval_weights rl_response.reward = self.rl_reward rl_response.final_step = self.rl_step