added improvement query
This commit is contained in:
parent
62f0e15881
commit
8eb472b065
@ -7,4 +7,5 @@ uint16 nr_episodes
|
||||
uint16 nr_runs
|
||||
string acquisition_function
|
||||
float32 metric_parameter
|
||||
uint16 metric_parameter_2
|
||||
bool save_result
|
@ -2,17 +2,22 @@ import numpy as np
|
||||
|
||||
|
||||
class ImprovementQuery:
|
||||
def __init__(self, threshold, period):
|
||||
def __init__(self, threshold, period, last_query, rewards):
|
||||
self.threshold = threshold
|
||||
self.period = period
|
||||
self.last_query = last_query
|
||||
self.rewards = rewards
|
||||
|
||||
def query(self, reward_array):
|
||||
if reward_array.shape < self.period:
|
||||
def query(self):
|
||||
if self.rewards.shape[0] < self.period:
|
||||
return False
|
||||
|
||||
elif self.rewards.shape[0] < self.last_query + self.period:
|
||||
return False
|
||||
|
||||
else:
|
||||
first = reward_array[-self.period]
|
||||
last = reward_array[-1]
|
||||
first = self.rewards[-self.period]
|
||||
last = self.rewards[-1]
|
||||
|
||||
slope = (last - first) / self.period
|
||||
|
||||
|
@ -78,10 +78,6 @@ class ActiveRLService(Node):
|
||||
self.mainloop_callback,
|
||||
callback_group=mainloop_callback_group)
|
||||
|
||||
# time messurements
|
||||
self.begin_time = None
|
||||
self.end_time = None
|
||||
|
||||
def reset_rl_request(self):
|
||||
self.rl_env = None
|
||||
self.rl_seed = None
|
||||
@ -90,7 +86,6 @@ class ActiveRLService(Node):
|
||||
self.interactive_run = 0
|
||||
|
||||
def active_rl_callback(self, msg):
|
||||
self.begin_time = time.time()
|
||||
self.rl_env = msg.env
|
||||
self.rl_seed = msg.seed
|
||||
self.display_run = msg.display_run
|
||||
@ -221,10 +216,6 @@ class ActiveRLService(Node):
|
||||
rl_response.final_step = self.rl_step
|
||||
|
||||
self.active_rl_pub.publish(rl_response)
|
||||
self.end_time = time.time()
|
||||
# self.get_logger().info(f'RL Time: {self.end_time-self.begin_time}, mode: {self.interactive_run}')
|
||||
self.begin_time = None
|
||||
self.end_time = None
|
||||
|
||||
self.env.reset(seed=self.rl_seed)
|
||||
|
||||
@ -262,11 +253,6 @@ class ActiveRLService(Node):
|
||||
self.rl_reward = 0.0
|
||||
self.rl_pending = False
|
||||
|
||||
self.end_time = time.time()
|
||||
# self.get_logger().info(f'RL Time: {self.end_time - self.begin_time}, mode: {self.interactive_run}')
|
||||
self.begin_time = None
|
||||
self.end_time = None
|
||||
|
||||
elif self.interactive_run == 2:
|
||||
env_reward, step_count = self.complete_run(self.rl_policy)
|
||||
|
||||
@ -277,12 +263,6 @@ class ActiveRLService(Node):
|
||||
|
||||
self.active_rl_pub.publish(rl_response)
|
||||
|
||||
self.end_time = time.time()
|
||||
# self.get_logger().info(f'RL Time: {self.end_time - self.begin_time}, mode: {self.interactive_run}')
|
||||
|
||||
self.begin_time = None
|
||||
self.end_time = None
|
||||
|
||||
self.reset_rl_request()
|
||||
self.rl_pending = False
|
||||
|
||||
|
@ -73,8 +73,6 @@ class ActiveBOTopic(Node):
|
||||
self.rl_weights = None
|
||||
self.rl_final_step = None
|
||||
self.rl_reward = 0.0
|
||||
self.user_asked = False
|
||||
self.last_user_reward = 0.0
|
||||
|
||||
# State Publisher
|
||||
self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1)
|
||||
@ -92,21 +90,17 @@ class ActiveBOTopic(Node):
|
||||
self.best_policy = None
|
||||
self.best_weights = None
|
||||
|
||||
# User Query
|
||||
self.last_query = 0
|
||||
self.user_asked = False
|
||||
self.last_user_reward = 0.0
|
||||
|
||||
# Main loop timer object
|
||||
self.mainloop_timer_period = 0.1
|
||||
self.mainloop = self.create_timer(self.mainloop_timer_period,
|
||||
self.mainloop_callback,
|
||||
callback_group=mainloop_callback_group)
|
||||
|
||||
# time messurements
|
||||
self.init_begin = None
|
||||
self.init_end = None
|
||||
|
||||
self.rl_begin = None
|
||||
self.rl_end = None
|
||||
self.user_query_begin = None
|
||||
self.user_query_end = None
|
||||
|
||||
def reset_bo_request(self):
|
||||
self.bo_env = None
|
||||
self.bo_metric = None
|
||||
@ -138,6 +132,7 @@ class ActiveBOTopic(Node):
|
||||
self.bo_runs = msg.nr_runs
|
||||
self.bo_acq_fcn = msg.acquisition_function
|
||||
self.bo_metric_parameter = msg.metric_parameter
|
||||
self.bo_metric_parameter_2 = msg.metric_parameter_2
|
||||
self.save_result = msg.save_result
|
||||
self.seed_array = np.zeros((1, self.bo_runs))
|
||||
|
||||
@ -185,10 +180,6 @@ class ActiveBOTopic(Node):
|
||||
|
||||
if self.init_pending:
|
||||
self.init_step += 1
|
||||
self.init_end = time.time()
|
||||
self.get_logger().info(f'Init Time: {self.init_end-self.init_begin}')
|
||||
self.init_begin = None
|
||||
self.init_end = None
|
||||
|
||||
if self.init_step == self.nr_init:
|
||||
self.init_step = 0
|
||||
@ -224,7 +215,6 @@ class ActiveBOTopic(Node):
|
||||
# self.get_logger().info(f'{self.rl_pending}')
|
||||
|
||||
if self.init_pending:
|
||||
self.init_begin = time.time()
|
||||
if self.bo_fixed_seed:
|
||||
seed = self.seed
|
||||
else:
|
||||
@ -322,6 +312,12 @@ class ActiveBOTopic(Node):
|
||||
elif self.bo_metric == "regular":
|
||||
user_query = RegularQuery(self.bo_metric_parameter, self.current_episode)
|
||||
|
||||
elif self.bo_metric == "improvement":
|
||||
user_query = ImprovementQuery(self.bo_metric_parameter,
|
||||
self.bo_metric_parameter_2,
|
||||
self.last_query,
|
||||
self.reward[:self.current_episode, self.current_run])
|
||||
|
||||
elif self.bo_metric == "max acquisition":
|
||||
user_query = MaxAcqQuery(self.bo_metric_parameter,
|
||||
self.BO.GP,
|
||||
@ -330,13 +326,11 @@ class ActiveBOTopic(Node):
|
||||
acq=self.bo_acq_fcn,
|
||||
X=self.BO.X)
|
||||
|
||||
elif self.bo_metric == "improvement":
|
||||
user_query = ImprovementQuery(self.bo_metric_parameter, 10)
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if user_query.query():
|
||||
self.last_query = self.current_episode
|
||||
self.user_asked = True
|
||||
active_rl_request = ActiveRL()
|
||||
old_policy, y_max, old_weights, _ = self.BO.get_best_result()
|
||||
@ -395,6 +389,7 @@ class ActiveBOTopic(Node):
|
||||
self.BO = None
|
||||
|
||||
self.current_episode = 0
|
||||
self.last_query = 0
|
||||
if self.bo_fixed_seed:
|
||||
self.seed_array[0, self.current_run] = self.seed
|
||||
self.seed = int(np.random.randint(1, 2147483647, 1)[0])
|
||||
|
Loading…
Reference in New Issue
Block a user