added improvement query

Niko Feith 2023-06-14 15:38:10 +02:00
parent 8eb472b065
commit f927ce7ae6
5 changed files with 12 additions and 6 deletions

View File

@@ -8,4 +8,5 @@ uint16 nr_runs
 string acquisition_function
 float32 metric_parameter
 uint16 metric_parameter_2
 bool save_result
+bool overwrite
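Note: the new overwrite field extends the request definition above. A minimal sketch of a client populating it, assuming a ROS 2 message named ActiveBORequest in a hypothetical active_bo_msgs package (the diff does not show the file name, so both names are assumptions):

    # Sketch only: package and message names are assumed, not confirmed by the diff.
    from active_bo_msgs.msg import ActiveBORequest

    msg = ActiveBORequest()
    msg.nr_runs = 10
    msg.acquisition_function = 'Expected Improvement'
    msg.metric_parameter = 0.5
    msg.metric_parameter_2 = 10
    msg.save_result = True
    msg.overwrite = True  # new flag introduced by this commit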

View File

@@ -166,7 +166,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
         )
         if not terminated:
-            reward = 1.0
+            reward = 1.0 - math.pow(action, 2) * 0.1
         elif self.steps_beyond_terminated is None:
             # Pole just fell!
             self.steps_beyond_terminated = 0
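This replaces the constant survival bonus with a shaped reward that penalizes large control inputs, which presupposes a continuous-action CartPole variant. A quick sketch of the effect for a few scalar actions (assuming action is a float):

    import math

    for action in (0.0, 0.5, 1.0, 2.0):
        reward = 1.0 - math.pow(action, 2) * 0.1
        print(action, reward)
    # 0.0 -> 1.0, 0.5 -> 0.975, 1.0 -> 0.9, 2.0 -> 0.6

Larger forces now cost reward, so the agent is pushed toward balancing the pole with minimal effort.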

View File

@@ -16,7 +16,7 @@ class ImprovementQuery:
             return False
         else:
-            first = self.rewards[-self.period]
+            first = self.rewards[-self.period-1]
             last = self.rewards[-1]
             slope = (last - first) / self.period
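The index change is an off-by-one fix: to measure a slope over self.period steps, the first sample must lie self.period entries before the last one, i.e. at index -self.period-1; the old index -self.period spanned only self.period - 1 steps. A minimal reconstruction of the check, assuming the class holds a reward list and a period (the other fields are not shown in this hunk):

    class ImprovementQuery:
        def __init__(self, period, rewards):
            self.period = period
            self.rewards = rewards

        def slope(self):
            # need period + 1 samples to span a window of period steps
            if len(self.rewards) < self.period + 1:
                return False
            first = self.rewards[-self.period-1]
            last = self.rewards[-1]
            return (last - first) / self.period

    # rewards rising by 1.0 per step over a period of 3 -> slope 1.0
    print(ImprovementQuery(3, [0.0, 1.0, 2.0, 3.0]).slope())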

View File

@@ -104,7 +104,7 @@ class ActiveRLService(Node):
         else:
             raise NotImplementedError
-        self.get_logger().info('Active RL: Called!')
+        # self.get_logger().info('Active RL: Called!')
        self.env.reset(seed=self.rl_seed)
         self.rl_pending = True
         self.policy_sent = False
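Apart from silencing the per-call log line, the surrounding code reseeds the environment on every request, so repeated calls with the same rl_seed reproduce the same initial state. A small sketch of that behaviour with the Gymnasium-style reset API (the environment id here is an assumption):

    import gymnasium as gym

    env = gym.make('CartPole-v1')
    obs1, _ = env.reset(seed=42)
    obs2, _ = env.reset(seed=42)
    print((obs1 == obs2).all())  # True: same seed, same initial observation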

View File

@@ -94,6 +94,7 @@ class ActiveBOTopic(Node):
         self.last_query = 0
         self.user_asked = False
         self.last_user_reward = 0.0
+        self.overwrite = False

         # Main loop timer object
         self.mainloop_timer_period = 0.1
@@ -118,6 +119,7 @@
         self.env = None
         self.active_bo_pending = False
         self.BO = None
+        self.overwrite = False

     def active_bo_callback(self, msg):
         if not self.active_bo_pending:
@@ -135,6 +137,7 @@
             self.bo_metric_parameter_2 = msg.metric_parameter_2
             self.save_result = msg.save_result
             self.seed_array = np.zeros((1, self.bo_runs))
+            self.overwrite = msg.overwrite

             # initialize
             self.reward = np.zeros((self.bo_episodes+self.nr_init, self.bo_runs))
@@ -375,7 +378,9 @@
                 self.active_rl_pub.publish(rl_msg)
                 self.current_episode += 1

-                # self.get_logger().info(f'Current Episode: {self.current_episode}')
+                self.reward[self.current_episode, self.current_run] = np.max(self.BO.Y)
+                self.get_logger().info(f'Current Episode: {self.current_episode},'
+                                       f' best reward: {self.reward[self.current_episode, self.current_run]}')
             else:
                 self.best_policy[:, self.current_run], \
                 self.best_pol_reward[:, self.current_run], \
@@ -383,7 +388,7 @@
                 # self.get_logger().info(f'best idx: {idx}')
-                self.reward[:, self.current_run] = self.BO.best_reward.T
+                # self.reward[:, self.current_run] = self.BO.best_reward.T
                 if self.current_run < self.bo_runs - 1:
                     self.BO = None
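The last two hunks move the reward bookkeeping from a post-hoc copy of the optimizer's reward trace to an in-loop record of the running optimum: after each episode, the best objective value seen so far, np.max(self.BO.Y), is stored per run. A small sketch of the pattern, using Y as a stand-in for the optimizer's observed objective values (only the names from the diff are real; the rest is assumed):

    import numpy as np

    episodes, runs = 5, 1
    reward = np.zeros((episodes, runs))
    Y = []  # objective values observed by the optimizer, one per episode

    for episode in range(episodes):
        Y.append(np.random.uniform(0.0, 1.0))  # stand-in for one BO episode
        reward[episode, 0] = np.max(Y)         # running best so far

    print(reward[:, 0])  # non-decreasing by construction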