Fixed inconsistency

This commit is contained in:
Niko Feith 2023-06-07 16:19:35 +02:00
parent e99b131ee9
commit 3e278fb0fc
4 changed files with 53 additions and 22 deletions

View File

@ -83,6 +83,7 @@ class ActiveRLService(Node):
self.rl_seed = None self.rl_seed = None
self.rl_policy = None self.rl_policy = None
self.rl_weights = None self.rl_weights = None
self.interactive_run = 0
def active_rl_callback(self, msg): def active_rl_callback(self, msg):
self.rl_env = msg.env self.rl_env = msg.env
@ -240,7 +241,7 @@ class ActiveRLService(Node):
if done: if done:
rl_response = ActiveRLResponse() rl_response = ActiveRLResponse()
rl_response.weights = self.eval_weights rl_response.weights = self.rl_weights
rl_response.reward = self.rl_reward rl_response.reward = self.rl_reward
rl_response.final_step = self.rl_step rl_response.final_step = self.rl_step

View File

@ -73,6 +73,7 @@ class ActiveBOTopic(Node):
self.rl_weights = None self.rl_weights = None
self.rl_final_step = None self.rl_final_step = None
self.rl_reward = 0.0 self.rl_reward = 0.0
self.user_asked = False
# State Publisher # State Publisher
self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1) self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1)
@ -113,7 +114,7 @@ class ActiveBOTopic(Node):
def active_bo_callback(self, msg): def active_bo_callback(self, msg):
if not self.active_bo_pending: if not self.active_bo_pending:
self.get_logger().info('Active Bayesian Optimization request pending!') # self.get_logger().info('Active Bayesian Optimization request pending!')
self.active_bo_pending = True self.active_bo_pending = True
self.bo_env = msg.env self.bo_env = msg.env
self.bo_metric = msg.metric self.bo_metric = msg.metric
@ -136,7 +137,7 @@ class ActiveBOTopic(Node):
# set the seed # set the seed
if self.bo_fixed_seed: if self.bo_fixed_seed:
self.seed = int(np.random.randint(1, 2147483647, 1)[0]) self.seed = int(np.random.randint(1, 2147483647, 1)[0])
self.get_logger().info(str(self.seed)) # self.get_logger().info(str(self.seed))
else: else:
self.seed = None self.seed = None
@ -146,14 +147,14 @@ class ActiveBOTopic(Node):
def active_rl_callback(self, msg): def active_rl_callback(self, msg):
if self.rl_pending: if self.rl_pending:
self.get_logger().info('Active Reinforcement Learning response received!') # self.get_logger().info('Active Reinforcement Learning response received!')
self.rl_weights = msg.weights self.rl_weights = np.array(msg.weights, dtype=np.float64)
self.rl_final_step = msg.final_step self.rl_final_step = msg.final_step
self.rl_reward = msg.reward self.rl_reward = msg.reward
try: try:
self.BO.add_new_observation(self.rl_reward, self.rl_weights) self.BO.add_new_observation(self.rl_reward, self.rl_weights)
self.get_logger().info('Active Reinforcement Learning added new observation!') # self.get_logger().info('Active Reinforcement Learning added new observation!')
except Exception as e: except Exception as e:
self.get_logger().error(f'Active Reinforcement Learning failed to add new observation: {e}') self.get_logger().error(f'Active Reinforcement Learning failed to add new observation: {e}')
@ -188,8 +189,11 @@ class ActiveBOTopic(Node):
acq=self.bo_acq_fcn, acq=self.bo_acq_fcn,
nr_weights=self.bo_nr_weights) nr_weights=self.bo_nr_weights)
self.BO.reset_bo()
# self.BO.initialize() # self.BO.initialize()
self.init_pending = True self.init_pending = True
self.get_logger().info('BO Initialization is starting!')
if self.init_pending and not self.rl_pending: if self.init_pending and not self.rl_pending:
@ -203,9 +207,13 @@ class ActiveBOTopic(Node):
rl_msg.seed = seed rl_msg.seed = seed
rl_msg.display_run = False rl_msg.display_run = False
rl_msg.interactive_run = 2 rl_msg.interactive_run = 2
rl_msg.weights = self.BO.policy_model.random_policy() rl_msg.weights = self.BO.policy_model.random_policy().tolist()
rl_msg.policy = self.BO.policy_model.rollout() rl_msg.policy = self.BO.policy_model.rollout().reshape(-1,).tolist()
self.active_rl_pub.publish(rl_msg)
self.rl_pending = True self.rl_pending = True
return
if self.current_run == self.bo_runs: if self.current_run == self.bo_runs:
bo_response = ActiveBOResponse() bo_response = ActiveBOResponse()
@ -214,7 +222,7 @@ class ActiveBOTopic(Node):
bo_response.best_policy = self.best_policy[:, best_policy_idx].tolist() bo_response.best_policy = self.best_policy[:, best_policy_idx].tolist()
bo_response.best_weights = self.best_weights[:, best_policy_idx].tolist() bo_response.best_weights = self.best_weights[:, best_policy_idx].tolist()
self.get_logger().info(f'Best Policy: {self.best_pol_reward}') # self.get_logger().info(f'Best Policy: {self.best_pol_reward}')
self.get_logger().info(f'{best_policy_idx}, {int(self.seed_array[0, best_policy_idx])}') self.get_logger().info(f'{best_policy_idx}, {int(self.seed_array[0, best_policy_idx])}')
@ -258,7 +266,7 @@ class ActiveBOTopic(Node):
if self.bo_fixed_seed: if self.bo_fixed_seed:
seed = int(self.seed_array[0, best_policy_idx]) seed = int(self.seed_array[0, best_policy_idx])
self.get_logger().info(f'Used seed{seed}') # self.get_logger().info(f'Used seed{seed}')
else: else:
seed = int(np.random.randint(1, 2147483647, 1)[0]) seed = int(np.random.randint(1, 2147483647, 1)[0])
@ -277,12 +285,10 @@ class ActiveBOTopic(Node):
self.BO = None self.BO = None
else: else:
if self.rl_pending: if self.rl_pending or self.init_pending:
pass return
else: else:
if self.init_pending: if self.current_episode < self.bo_episodes:
pass
elif self.current_episode < self.bo_episodes:
# metrics # metrics
if self.bo_metric == "random": if self.bo_metric == "random":
user_query = RandomQuery(self.bo_metric_parameter) user_query = RandomQuery(self.bo_metric_parameter)
@ -305,11 +311,12 @@ class ActiveBOTopic(Node):
raise NotImplementedError raise NotImplementedError
if user_query.query(): if user_query.query():
self.user_asked = True
active_rl_request = ActiveRL() active_rl_request = ActiveRL()
old_policy, y_max, old_weights, _ = self.BO.get_best_result() old_policy, y_max, old_weights, _ = self.BO.get_best_result()
self.get_logger().info(f'Best: {y_max}, w:{old_weights}') # self.get_logger().info(f'Best: {y_max}, w:{old_weights}')
self.get_logger().info(f'Size of Y: {self.BO.Y.shape}, Size of X: {self.BO.X.shape}') # self.get_logger().info(f'Size of Y: {self.BO.Y.shape}, Size of X: {self.BO.X.shape}')
if self.bo_fixed_seed: if self.bo_fixed_seed:
seed = self.seed seed = self.seed
@ -323,12 +330,13 @@ class ActiveBOTopic(Node):
active_rl_request.weights = old_weights.tolist() active_rl_request.weights = old_weights.tolist()
active_rl_request.interactive_run = 0 active_rl_request.interactive_run = 0
self.get_logger().info('Calling: Active RL') # self.get_logger().info('Calling: Active RL')
self.active_rl_pub.publish(active_rl_request) self.active_rl_pub.publish(active_rl_request)
self.rl_pending = True self.rl_pending = True
else: else:
x_next = self.BO.next_observation() x_next = self.BO.next_observation()
# self.get_logger().info('Next Observation BO!')
self.BO.policy_model.weights = np.around(x_next, decimals=8) self.BO.policy_model.weights = np.around(x_next, decimals=8)
if self.bo_fixed_seed: if self.bo_fixed_seed:
seed = self.seed seed = self.seed
@ -340,10 +348,12 @@ class ActiveBOTopic(Node):
rl_msg.seed = seed rl_msg.seed = seed
rl_msg.display_run = False rl_msg.display_run = False
rl_msg.interactive_run = 2 rl_msg.interactive_run = 2
rl_msg.weights = x_next rl_msg.weights = x_next.tolist()
rl_msg.policy = self.BO.policy_model.rollout() rl_msg.policy = self.BO.policy_model.rollout().reshape(-1,).tolist()
self.rl_pending = True self.rl_pending = True
self.active_rl_pub.publish(rl_msg)
self.current_episode += 1 self.current_episode += 1
# self.get_logger().info(f'Current Episode: {self.current_episode}') # self.get_logger().info(f'Current Episode: {self.current_episode}')
else: else:
@ -351,7 +361,7 @@ class ActiveBOTopic(Node):
self.best_pol_reward[:, self.current_run], \ self.best_pol_reward[:, self.current_run], \
self.best_weights[:, self.current_run], idx = self.BO.get_best_result() self.best_weights[:, self.current_run], idx = self.BO.get_best_result()
self.get_logger().info(f'best idx: {idx}') # self.get_logger().info(f'best idx: {idx}')
self.reward[:, self.current_run] = self.BO.best_reward.T self.reward[:, self.current_run] = self.BO.best_reward.T
@ -375,7 +385,9 @@ class ActiveBOTopic(Node):
state_msg.current_episode = self.current_episode + 1 \ state_msg.current_episode = self.current_episode + 1 \
if self.current_episode < self.bo_episodes else self.bo_episodes if self.current_episode < self.bo_episodes else self.bo_episodes
state_msg.best_reward = float(self.best_reward) state_msg.best_reward = float(self.best_reward)
if self.user_asked:
state_msg.last_user_reward = float(self.rl_reward) state_msg.last_user_reward = float(self.rl_reward)
self.user_asked = False
self.state_pub.publish(state_msg) self.state_pub.publish(state_msg)

View File

@ -0,0 +1,17 @@
from launch import LaunchDescription
from launch_ros.actions import Node
def generate_launch_description():
return LaunchDescription([
Node(
package='active_bo_ros',
executable='interactive_bo',
name='interactive_bo'
),
Node(
package='active_bo_ros',
executable='active_rl_topic',
name='active_rl_topic'
),
])

View File

@ -35,6 +35,7 @@ setup(
'bo_torch_srv = active_bo_ros.bo_torch_service:main', 'bo_torch_srv = active_bo_ros.bo_torch_service:main',
'active_bo_topic = active_bo_ros.active_bo_topic:main', 'active_bo_topic = active_bo_ros.active_bo_topic:main',
'active_rl_topic = active_bo_ros.active_rl_topic:main', 'active_rl_topic = active_bo_ros.active_rl_topic:main',
'interactive_bo = active_bo_ros.interactive_bo:main'
], ],
}, },
) )