added improvement query

This commit is contained in:
Niko Feith 2023-06-13 13:23:33 +02:00
parent 62f0e15881
commit 8eb472b065
4 changed files with 25 additions and 44 deletions

View File

@ -7,4 +7,5 @@ uint16 nr_episodes
uint16 nr_runs
string acquisition_function
float32 metric_parameter
uint16 metric_parameter_2
bool save_result

View File

@ -2,17 +2,22 @@ import numpy as np
class ImprovementQuery:
def __init__(self, threshold, period):
def __init__(self, threshold, period, last_query, rewards):
self.threshold = threshold
self.period = period
self.last_query = last_query
self.rewards = rewards
def query(self, reward_array):
if reward_array.shape < self.period:
def query(self):
if self.rewards.shape[0] < self.period:
return False
elif self.rewards.shape[0] < self.last_query + self.period:
return False
else:
first = reward_array[-self.period]
last = reward_array[-1]
first = self.rewards[-self.period]
last = self.rewards[-1]
slope = (last - first) / self.period

View File

@ -78,10 +78,6 @@ class ActiveRLService(Node):
self.mainloop_callback,
callback_group=mainloop_callback_group)
# time messurements
self.begin_time = None
self.end_time = None
def reset_rl_request(self):
self.rl_env = None
self.rl_seed = None
@ -90,7 +86,6 @@ class ActiveRLService(Node):
self.interactive_run = 0
def active_rl_callback(self, msg):
self.begin_time = time.time()
self.rl_env = msg.env
self.rl_seed = msg.seed
self.display_run = msg.display_run
@ -221,10 +216,6 @@ class ActiveRLService(Node):
rl_response.final_step = self.rl_step
self.active_rl_pub.publish(rl_response)
self.end_time = time.time()
# self.get_logger().info(f'RL Time: {self.end_time-self.begin_time}, mode: {self.interactive_run}')
self.begin_time = None
self.end_time = None
self.env.reset(seed=self.rl_seed)
@ -262,11 +253,6 @@ class ActiveRLService(Node):
self.rl_reward = 0.0
self.rl_pending = False
self.end_time = time.time()
# self.get_logger().info(f'RL Time: {self.end_time - self.begin_time}, mode: {self.interactive_run}')
self.begin_time = None
self.end_time = None
elif self.interactive_run == 2:
env_reward, step_count = self.complete_run(self.rl_policy)
@ -277,12 +263,6 @@ class ActiveRLService(Node):
self.active_rl_pub.publish(rl_response)
self.end_time = time.time()
# self.get_logger().info(f'RL Time: {self.end_time - self.begin_time}, mode: {self.interactive_run}')
self.begin_time = None
self.end_time = None
self.reset_rl_request()
self.rl_pending = False

View File

@ -73,8 +73,6 @@ class ActiveBOTopic(Node):
self.rl_weights = None
self.rl_final_step = None
self.rl_reward = 0.0
self.user_asked = False
self.last_user_reward = 0.0
# State Publisher
self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1)
@ -92,21 +90,17 @@ class ActiveBOTopic(Node):
self.best_policy = None
self.best_weights = None
# User Query
self.last_query = 0
self.user_asked = False
self.last_user_reward = 0.0
# Main loop timer object
self.mainloop_timer_period = 0.1
self.mainloop = self.create_timer(self.mainloop_timer_period,
self.mainloop_callback,
callback_group=mainloop_callback_group)
# time messurements
self.init_begin = None
self.init_end = None
self.rl_begin = None
self.rl_end = None
self.user_query_begin = None
self.user_query_end = None
def reset_bo_request(self):
self.bo_env = None
self.bo_metric = None
@ -138,6 +132,7 @@ class ActiveBOTopic(Node):
self.bo_runs = msg.nr_runs
self.bo_acq_fcn = msg.acquisition_function
self.bo_metric_parameter = msg.metric_parameter
self.bo_metric_parameter_2 = msg.metric_parameter_2
self.save_result = msg.save_result
self.seed_array = np.zeros((1, self.bo_runs))
@ -185,10 +180,6 @@ class ActiveBOTopic(Node):
if self.init_pending:
self.init_step += 1
self.init_end = time.time()
self.get_logger().info(f'Init Time: {self.init_end-self.init_begin}')
self.init_begin = None
self.init_end = None
if self.init_step == self.nr_init:
self.init_step = 0
@ -224,7 +215,6 @@ class ActiveBOTopic(Node):
# self.get_logger().info(f'{self.rl_pending}')
if self.init_pending:
self.init_begin = time.time()
if self.bo_fixed_seed:
seed = self.seed
else:
@ -322,6 +312,12 @@ class ActiveBOTopic(Node):
elif self.bo_metric == "regular":
user_query = RegularQuery(self.bo_metric_parameter, self.current_episode)
elif self.bo_metric == "improvement":
user_query = ImprovementQuery(self.bo_metric_parameter,
self.bo_metric_parameter_2,
self.last_query,
self.reward[:self.current_episode, self.current_run])
elif self.bo_metric == "max acquisition":
user_query = MaxAcqQuery(self.bo_metric_parameter,
self.BO.GP,
@ -330,13 +326,11 @@ class ActiveBOTopic(Node):
acq=self.bo_acq_fcn,
X=self.BO.X)
elif self.bo_metric == "improvement":
user_query = ImprovementQuery(self.bo_metric_parameter, 10)
else:
raise NotImplementedError
if user_query.query():
self.last_query = self.current_episode
self.user_asked = True
active_rl_request = ActiveRL()
old_policy, y_max, old_weights, _ = self.BO.get_best_result()
@ -395,6 +389,7 @@ class ActiveBOTopic(Node):
self.BO = None
self.current_episode = 0
self.last_query = 0
if self.bo_fixed_seed:
self.seed_array[0, self.current_run] = self.seed
self.seed = int(np.random.randint(1, 2147483647, 1)[0])