added improvement query
This commit is contained in:
parent
8eb472b065
commit
f927ce7ae6
@ -9,3 +9,4 @@ string acquisition_function
|
|||||||
float32 metric_parameter
|
float32 metric_parameter
|
||||||
uint16 metric_parameter_2
|
uint16 metric_parameter_2
|
||||||
bool save_result
|
bool save_result
|
||||||
|
bool overwrite
|
@ -166,7 +166,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not terminated:
|
if not terminated:
|
||||||
reward = 1.0
|
reward = 1.0 - math.pow(action, 2) * 0.1
|
||||||
elif self.steps_beyond_terminated is None:
|
elif self.steps_beyond_terminated is None:
|
||||||
# Pole just fell!
|
# Pole just fell!
|
||||||
self.steps_beyond_terminated = 0
|
self.steps_beyond_terminated = 0
|
||||||
|
@ -16,7 +16,7 @@ class ImprovementQuery:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
else:
|
else:
|
||||||
first = self.rewards[-self.period]
|
first = self.rewards[-self.period-1]
|
||||||
last = self.rewards[-1]
|
last = self.rewards[-1]
|
||||||
|
|
||||||
slope = (last - first) / self.period
|
slope = (last - first) / self.period
|
||||||
|
@ -104,7 +104,7 @@ class ActiveRLService(Node):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
self.get_logger().info('Active RL: Called!')
|
# self.get_logger().info('Active RL: Called!')
|
||||||
self.env.reset(seed=self.rl_seed)
|
self.env.reset(seed=self.rl_seed)
|
||||||
self.rl_pending = True
|
self.rl_pending = True
|
||||||
self.policy_sent = False
|
self.policy_sent = False
|
||||||
|
@ -94,6 +94,7 @@ class ActiveBOTopic(Node):
|
|||||||
self.last_query = 0
|
self.last_query = 0
|
||||||
self.user_asked = False
|
self.user_asked = False
|
||||||
self.last_user_reward = 0.0
|
self.last_user_reward = 0.0
|
||||||
|
self.overwrite = False
|
||||||
|
|
||||||
# Main loop timer object
|
# Main loop timer object
|
||||||
self.mainloop_timer_period = 0.1
|
self.mainloop_timer_period = 0.1
|
||||||
@ -118,6 +119,7 @@ class ActiveBOTopic(Node):
|
|||||||
self.env = None
|
self.env = None
|
||||||
self.active_bo_pending = False
|
self.active_bo_pending = False
|
||||||
self.BO = None
|
self.BO = None
|
||||||
|
self.overwrite = False
|
||||||
|
|
||||||
def active_bo_callback(self, msg):
|
def active_bo_callback(self, msg):
|
||||||
if not self.active_bo_pending:
|
if not self.active_bo_pending:
|
||||||
@ -135,6 +137,7 @@ class ActiveBOTopic(Node):
|
|||||||
self.bo_metric_parameter_2 = msg.metric_parameter_2
|
self.bo_metric_parameter_2 = msg.metric_parameter_2
|
||||||
self.save_result = msg.save_result
|
self.save_result = msg.save_result
|
||||||
self.seed_array = np.zeros((1, self.bo_runs))
|
self.seed_array = np.zeros((1, self.bo_runs))
|
||||||
|
self.overwrite = msg.overwrite
|
||||||
|
|
||||||
# initialize
|
# initialize
|
||||||
self.reward = np.zeros((self.bo_episodes+self.nr_init, self.bo_runs))
|
self.reward = np.zeros((self.bo_episodes+self.nr_init, self.bo_runs))
|
||||||
@ -375,7 +378,9 @@ class ActiveBOTopic(Node):
|
|||||||
self.active_rl_pub.publish(rl_msg)
|
self.active_rl_pub.publish(rl_msg)
|
||||||
|
|
||||||
self.current_episode += 1
|
self.current_episode += 1
|
||||||
# self.get_logger().info(f'Current Episode: {self.current_episode}')
|
self.reward[self.current_episode, self.current_run] = np.max(self.BO.Y)
|
||||||
|
self.get_logger().info(f'Current Episode: {self.current_episode},'
|
||||||
|
f' best reward: {self.reward[self.current_episode, self.current_run]}')
|
||||||
else:
|
else:
|
||||||
self.best_policy[:, self.current_run], \
|
self.best_policy[:, self.current_run], \
|
||||||
self.best_pol_reward[:, self.current_run], \
|
self.best_pol_reward[:, self.current_run], \
|
||||||
@ -383,7 +388,7 @@ class ActiveBOTopic(Node):
|
|||||||
|
|
||||||
# self.get_logger().info(f'best idx: {idx}')
|
# self.get_logger().info(f'best idx: {idx}')
|
||||||
|
|
||||||
self.reward[:, self.current_run] = self.BO.best_reward.T
|
# self.reward[:, self.current_run] = self.BO.best_reward.T
|
||||||
|
|
||||||
if self.current_run < self.bo_runs - 1:
|
if self.current_run < self.bo_runs - 1:
|
||||||
self.BO = None
|
self.BO = None
|
||||||
|
Loading…
Reference in New Issue
Block a user