added improvement query
This commit is contained in:
parent
62f0e15881
commit
8eb472b065
@ -7,4 +7,5 @@ uint16 nr_episodes
|
|||||||
uint16 nr_runs
|
uint16 nr_runs
|
||||||
string acquisition_function
|
string acquisition_function
|
||||||
float32 metric_parameter
|
float32 metric_parameter
|
||||||
|
uint16 metric_parameter_2
|
||||||
bool save_result
|
bool save_result
|
@ -2,17 +2,22 @@ import numpy as np
|
|||||||
|
|
||||||
|
|
||||||
class ImprovementQuery:
|
class ImprovementQuery:
|
||||||
def __init__(self, threshold, period):
|
def __init__(self, threshold, period, last_query, rewards):
|
||||||
self.threshold = threshold
|
self.threshold = threshold
|
||||||
self.period = period
|
self.period = period
|
||||||
|
self.last_query = last_query
|
||||||
|
self.rewards = rewards
|
||||||
|
|
||||||
def query(self, reward_array):
|
def query(self):
|
||||||
if reward_array.shape < self.period:
|
if self.rewards.shape[0] < self.period:
|
||||||
|
return False
|
||||||
|
|
||||||
|
elif self.rewards.shape[0] < self.last_query + self.period:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
else:
|
else:
|
||||||
first = reward_array[-self.period]
|
first = self.rewards[-self.period]
|
||||||
last = reward_array[-1]
|
last = self.rewards[-1]
|
||||||
|
|
||||||
slope = (last - first) / self.period
|
slope = (last - first) / self.period
|
||||||
|
|
||||||
|
@ -78,10 +78,6 @@ class ActiveRLService(Node):
|
|||||||
self.mainloop_callback,
|
self.mainloop_callback,
|
||||||
callback_group=mainloop_callback_group)
|
callback_group=mainloop_callback_group)
|
||||||
|
|
||||||
# time measurements
|
|
||||||
self.begin_time = None
|
|
||||||
self.end_time = None
|
|
||||||
|
|
||||||
def reset_rl_request(self):
|
def reset_rl_request(self):
|
||||||
self.rl_env = None
|
self.rl_env = None
|
||||||
self.rl_seed = None
|
self.rl_seed = None
|
||||||
@ -90,7 +86,6 @@ class ActiveRLService(Node):
|
|||||||
self.interactive_run = 0
|
self.interactive_run = 0
|
||||||
|
|
||||||
def active_rl_callback(self, msg):
|
def active_rl_callback(self, msg):
|
||||||
self.begin_time = time.time()
|
|
||||||
self.rl_env = msg.env
|
self.rl_env = msg.env
|
||||||
self.rl_seed = msg.seed
|
self.rl_seed = msg.seed
|
||||||
self.display_run = msg.display_run
|
self.display_run = msg.display_run
|
||||||
@ -221,10 +216,6 @@ class ActiveRLService(Node):
|
|||||||
rl_response.final_step = self.rl_step
|
rl_response.final_step = self.rl_step
|
||||||
|
|
||||||
self.active_rl_pub.publish(rl_response)
|
self.active_rl_pub.publish(rl_response)
|
||||||
self.end_time = time.time()
|
|
||||||
# self.get_logger().info(f'RL Time: {self.end_time-self.begin_time}, mode: {self.interactive_run}')
|
|
||||||
self.begin_time = None
|
|
||||||
self.end_time = None
|
|
||||||
|
|
||||||
self.env.reset(seed=self.rl_seed)
|
self.env.reset(seed=self.rl_seed)
|
||||||
|
|
||||||
@ -262,11 +253,6 @@ class ActiveRLService(Node):
|
|||||||
self.rl_reward = 0.0
|
self.rl_reward = 0.0
|
||||||
self.rl_pending = False
|
self.rl_pending = False
|
||||||
|
|
||||||
self.end_time = time.time()
|
|
||||||
# self.get_logger().info(f'RL Time: {self.end_time - self.begin_time}, mode: {self.interactive_run}')
|
|
||||||
self.begin_time = None
|
|
||||||
self.end_time = None
|
|
||||||
|
|
||||||
elif self.interactive_run == 2:
|
elif self.interactive_run == 2:
|
||||||
env_reward, step_count = self.complete_run(self.rl_policy)
|
env_reward, step_count = self.complete_run(self.rl_policy)
|
||||||
|
|
||||||
@ -277,12 +263,6 @@ class ActiveRLService(Node):
|
|||||||
|
|
||||||
self.active_rl_pub.publish(rl_response)
|
self.active_rl_pub.publish(rl_response)
|
||||||
|
|
||||||
self.end_time = time.time()
|
|
||||||
# self.get_logger().info(f'RL Time: {self.end_time - self.begin_time}, mode: {self.interactive_run}')
|
|
||||||
|
|
||||||
self.begin_time = None
|
|
||||||
self.end_time = None
|
|
||||||
|
|
||||||
self.reset_rl_request()
|
self.reset_rl_request()
|
||||||
self.rl_pending = False
|
self.rl_pending = False
|
||||||
|
|
||||||
|
@ -73,8 +73,6 @@ class ActiveBOTopic(Node):
|
|||||||
self.rl_weights = None
|
self.rl_weights = None
|
||||||
self.rl_final_step = None
|
self.rl_final_step = None
|
||||||
self.rl_reward = 0.0
|
self.rl_reward = 0.0
|
||||||
self.user_asked = False
|
|
||||||
self.last_user_reward = 0.0
|
|
||||||
|
|
||||||
# State Publisher
|
# State Publisher
|
||||||
self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1)
|
self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1)
|
||||||
@ -92,21 +90,17 @@ class ActiveBOTopic(Node):
|
|||||||
self.best_policy = None
|
self.best_policy = None
|
||||||
self.best_weights = None
|
self.best_weights = None
|
||||||
|
|
||||||
|
# User Query
|
||||||
|
self.last_query = 0
|
||||||
|
self.user_asked = False
|
||||||
|
self.last_user_reward = 0.0
|
||||||
|
|
||||||
# Main loop timer object
|
# Main loop timer object
|
||||||
self.mainloop_timer_period = 0.1
|
self.mainloop_timer_period = 0.1
|
||||||
self.mainloop = self.create_timer(self.mainloop_timer_period,
|
self.mainloop = self.create_timer(self.mainloop_timer_period,
|
||||||
self.mainloop_callback,
|
self.mainloop_callback,
|
||||||
callback_group=mainloop_callback_group)
|
callback_group=mainloop_callback_group)
|
||||||
|
|
||||||
# time measurements
|
|
||||||
self.init_begin = None
|
|
||||||
self.init_end = None
|
|
||||||
|
|
||||||
self.rl_begin = None
|
|
||||||
self.rl_end = None
|
|
||||||
self.user_query_begin = None
|
|
||||||
self.user_query_end = None
|
|
||||||
|
|
||||||
def reset_bo_request(self):
|
def reset_bo_request(self):
|
||||||
self.bo_env = None
|
self.bo_env = None
|
||||||
self.bo_metric = None
|
self.bo_metric = None
|
||||||
@ -138,6 +132,7 @@ class ActiveBOTopic(Node):
|
|||||||
self.bo_runs = msg.nr_runs
|
self.bo_runs = msg.nr_runs
|
||||||
self.bo_acq_fcn = msg.acquisition_function
|
self.bo_acq_fcn = msg.acquisition_function
|
||||||
self.bo_metric_parameter = msg.metric_parameter
|
self.bo_metric_parameter = msg.metric_parameter
|
||||||
|
self.bo_metric_parameter_2 = msg.metric_parameter_2
|
||||||
self.save_result = msg.save_result
|
self.save_result = msg.save_result
|
||||||
self.seed_array = np.zeros((1, self.bo_runs))
|
self.seed_array = np.zeros((1, self.bo_runs))
|
||||||
|
|
||||||
@ -185,10 +180,6 @@ class ActiveBOTopic(Node):
|
|||||||
|
|
||||||
if self.init_pending:
|
if self.init_pending:
|
||||||
self.init_step += 1
|
self.init_step += 1
|
||||||
self.init_end = time.time()
|
|
||||||
self.get_logger().info(f'Init Time: {self.init_end-self.init_begin}')
|
|
||||||
self.init_begin = None
|
|
||||||
self.init_end = None
|
|
||||||
|
|
||||||
if self.init_step == self.nr_init:
|
if self.init_step == self.nr_init:
|
||||||
self.init_step = 0
|
self.init_step = 0
|
||||||
@ -224,7 +215,6 @@ class ActiveBOTopic(Node):
|
|||||||
# self.get_logger().info(f'{self.rl_pending}')
|
# self.get_logger().info(f'{self.rl_pending}')
|
||||||
|
|
||||||
if self.init_pending:
|
if self.init_pending:
|
||||||
self.init_begin = time.time()
|
|
||||||
if self.bo_fixed_seed:
|
if self.bo_fixed_seed:
|
||||||
seed = self.seed
|
seed = self.seed
|
||||||
else:
|
else:
|
||||||
@ -322,6 +312,12 @@ class ActiveBOTopic(Node):
|
|||||||
elif self.bo_metric == "regular":
|
elif self.bo_metric == "regular":
|
||||||
user_query = RegularQuery(self.bo_metric_parameter, self.current_episode)
|
user_query = RegularQuery(self.bo_metric_parameter, self.current_episode)
|
||||||
|
|
||||||
|
elif self.bo_metric == "improvement":
|
||||||
|
user_query = ImprovementQuery(self.bo_metric_parameter,
|
||||||
|
self.bo_metric_parameter_2,
|
||||||
|
self.last_query,
|
||||||
|
self.reward[:self.current_episode, self.current_run])
|
||||||
|
|
||||||
elif self.bo_metric == "max acquisition":
|
elif self.bo_metric == "max acquisition":
|
||||||
user_query = MaxAcqQuery(self.bo_metric_parameter,
|
user_query = MaxAcqQuery(self.bo_metric_parameter,
|
||||||
self.BO.GP,
|
self.BO.GP,
|
||||||
@ -330,13 +326,11 @@ class ActiveBOTopic(Node):
|
|||||||
acq=self.bo_acq_fcn,
|
acq=self.bo_acq_fcn,
|
||||||
X=self.BO.X)
|
X=self.BO.X)
|
||||||
|
|
||||||
elif self.bo_metric == "improvement":
|
|
||||||
user_query = ImprovementQuery(self.bo_metric_parameter, 10)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
if user_query.query():
|
if user_query.query():
|
||||||
|
self.last_query = self.current_episode
|
||||||
self.user_asked = True
|
self.user_asked = True
|
||||||
active_rl_request = ActiveRL()
|
active_rl_request = ActiveRL()
|
||||||
old_policy, y_max, old_weights, _ = self.BO.get_best_result()
|
old_policy, y_max, old_weights, _ = self.BO.get_best_result()
|
||||||
@ -395,6 +389,7 @@ class ActiveBOTopic(Node):
|
|||||||
self.BO = None
|
self.BO = None
|
||||||
|
|
||||||
self.current_episode = 0
|
self.current_episode = 0
|
||||||
|
self.last_query = 0
|
||||||
if self.bo_fixed_seed:
|
if self.bo_fixed_seed:
|
||||||
self.seed_array[0, self.current_run] = self.seed
|
self.seed_array[0, self.current_run] = self.seed
|
||||||
self.seed = int(np.random.randint(1, 2147483647, 1)[0])
|
self.seed = int(np.random.randint(1, 2147483647, 1)[0])
|
||||||
|
Loading…
Reference in New Issue
Block a user