added improvement query
This commit is contained in:
parent
8eb472b065
commit
f927ce7ae6
@ -8,4 +8,5 @@ uint16 nr_runs
|
||||
string acquisition_function
|
||||
float32 metric_parameter
|
||||
uint16 metric_parameter_2
|
||||
bool save_result
|
||||
bool save_result
|
||||
bool overwrite
|
@ -166,7 +166,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
||||
)
|
||||
|
||||
if not terminated:
|
||||
reward = 1.0
|
||||
reward = 1.0 - math.pow(action, 2) * 0.1
|
||||
elif self.steps_beyond_terminated is None:
|
||||
# Pole just fell!
|
||||
self.steps_beyond_terminated = 0
|
||||
|
@ -16,7 +16,7 @@ class ImprovementQuery:
|
||||
return False
|
||||
|
||||
else:
|
||||
first = self.rewards[-self.period]
|
||||
first = self.rewards[-self.period-1]
|
||||
last = self.rewards[-1]
|
||||
|
||||
slope = (last - first) / self.period
|
||||
|
@ -104,7 +104,7 @@ class ActiveRLService(Node):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
self.get_logger().info('Active RL: Called!')
|
||||
# self.get_logger().info('Active RL: Called!')
|
||||
self.env.reset(seed=self.rl_seed)
|
||||
self.rl_pending = True
|
||||
self.policy_sent = False
|
||||
|
@ -94,6 +94,7 @@ class ActiveBOTopic(Node):
|
||||
self.last_query = 0
|
||||
self.user_asked = False
|
||||
self.last_user_reward = 0.0
|
||||
self.overwrite = False
|
||||
|
||||
# Main loop timer object
|
||||
self.mainloop_timer_period = 0.1
|
||||
@ -118,6 +119,7 @@ class ActiveBOTopic(Node):
|
||||
self.env = None
|
||||
self.active_bo_pending = False
|
||||
self.BO = None
|
||||
self.overwrite = False
|
||||
|
||||
def active_bo_callback(self, msg):
|
||||
if not self.active_bo_pending:
|
||||
@ -135,6 +137,7 @@ class ActiveBOTopic(Node):
|
||||
self.bo_metric_parameter_2 = msg.metric_parameter_2
|
||||
self.save_result = msg.save_result
|
||||
self.seed_array = np.zeros((1, self.bo_runs))
|
||||
self.overwrite = msg.overwrite
|
||||
|
||||
# initialize
|
||||
self.reward = np.zeros((self.bo_episodes+self.nr_init, self.bo_runs))
|
||||
@ -375,7 +378,9 @@ class ActiveBOTopic(Node):
|
||||
self.active_rl_pub.publish(rl_msg)
|
||||
|
||||
self.current_episode += 1
|
||||
# self.get_logger().info(f'Current Episode: {self.current_episode}')
|
||||
self.reward[self.current_episode, self.current_run] = np.max(self.BO.Y)
|
||||
self.get_logger().info(f'Current Episode: {self.current_episode},'
|
||||
f' best reward: {self.reward[self.current_episode, self.current_run]}')
|
||||
else:
|
||||
self.best_policy[:, self.current_run], \
|
||||
self.best_pol_reward[:, self.current_run], \
|
||||
@ -383,7 +388,7 @@ class ActiveBOTopic(Node):
|
||||
|
||||
# self.get_logger().info(f'best idx: {idx}')
|
||||
|
||||
self.reward[:, self.current_run] = self.BO.best_reward.T
|
||||
# self.reward[:, self.current_run] = self.BO.best_reward.T
|
||||
|
||||
if self.current_run < self.bo_runs - 1:
|
||||
self.BO = None
|
||||
|
Loading…
Reference in New Issue
Block a user