Fixed inconsistency

parent e99b131ee9
commit 3e278fb0fc
@@ -83,6 +83,7 @@ class ActiveRLService(Node):
         self.rl_seed = None
         self.rl_policy = None
         self.rl_weights = None
+        self.interactive_run = 0
 
     def active_rl_callback(self, msg):
         self.rl_env = msg.env
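The added `self.interactive_run = 0` gives the service a defined default run mode instead of an attribute that only exists after the first request. Judging from the hunks further down, 0 appears to mark a user-facing interactive run and 2 a background BO evaluation episode (`rl_msg.interactive_run = 2` is always paired with `display_run = False`).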
@@ -240,7 +241,7 @@ class ActiveRLService(Node):
 
         if done:
             rl_response = ActiveRLResponse()
-            rl_response.weights = self.eval_weights
+            rl_response.weights = self.rl_weights
             rl_response.reward = self.rl_reward
             rl_response.final_step = self.rl_step
 
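This is the inconsistency the commit title refers to: the completed-episode response reported `self.eval_weights`, an attribute the rest of this class apparently never sets, while the callback stores the received weights in `self.rl_weights`. The response now returns the weights the episode actually ran with.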
@@ -73,6 +73,7 @@ class ActiveBOTopic(Node):
         self.rl_weights = None
         self.rl_final_step = None
         self.rl_reward = 0.0
+        self.user_asked = False
 
         # State Publisher
         self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1)
@@ -113,7 +114,7 @@ class ActiveBOTopic(Node):
 
     def active_bo_callback(self, msg):
         if not self.active_bo_pending:
-            self.get_logger().info('Active Bayesian Optimization request pending!')
+            # self.get_logger().info('Active Bayesian Optimization request pending!')
             self.active_bo_pending = True
             self.bo_env = msg.env
             self.bo_metric = msg.metric
@@ -136,7 +137,7 @@ class ActiveBOTopic(Node):
             # set the seed
             if self.bo_fixed_seed:
                 self.seed = int(np.random.randint(1, 2147483647, 1)[0])
-                self.get_logger().info(str(self.seed))
+                # self.get_logger().info(str(self.seed))
             else:
                 self.seed = None
 
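A detail worth keeping in mind for the fixed-seed branch: `np.random.randint`'s upper bound is exclusive, so the drawn seed lies in [1, 2147483646], which stays inside the positive int32 range if the seed travels in an int32 message field (an assumption about the message definition, which is not shown in this diff).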
@@ -146,14 +147,14 @@ class ActiveBOTopic(Node):
 
     def active_rl_callback(self, msg):
         if self.rl_pending:
-            self.get_logger().info('Active Reinforcement Learning response received!')
-            self.rl_weights = msg.weights
+            # self.get_logger().info('Active Reinforcement Learning response received!')
+            self.rl_weights = np.array(msg.weights, dtype=np.float64)
             self.rl_final_step = msg.final_step
             self.rl_reward = msg.reward
 
             try:
                 self.BO.add_new_observation(self.rl_reward, self.rl_weights)
-                self.get_logger().info('Active Reinforcement Learning added new observation!')
+                # self.get_logger().info('Active Reinforcement Learning added new observation!')
             except Exception as e:
                 self.get_logger().error(f'Active Reinforcement Learning failed to add new observation: {e}')
 
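The switch to `np.array(msg.weights, dtype=np.float64)` matters because rclpy deserializes numeric array fields into `array.array` sequences rather than ndarrays, so the numpy code inside `BO.add_new_observation` would otherwise operate on a plain typecode-'d' sequence. A minimal self-contained sketch of the same conversion, with `raw` standing in for `msg.weights`:

    import array

    import numpy as np

    # rclpy hands a float64[] field to a callback as array.array('d', ...);
    # one explicit conversion gives the GP update a real ndarray of known dtype.
    raw = array.array('d', [0.1, -0.3, 0.7])   # stand-in for msg.weights
    weights = np.array(raw, dtype=np.float64)
    assert weights.dtype == np.float64 and weights.shape == (3,)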
@@ -188,8 +189,11 @@ class ActiveBOTopic(Node):
                                            acq=self.bo_acq_fcn,
                                            nr_weights=self.bo_nr_weights)
 
+            self.BO.reset_bo()
+
             # self.BO.initialize()
             self.init_pending = True
+            self.get_logger().info('BO Initialization is starting!')
 
         if self.init_pending and not self.rl_pending:
 
@@ -203,9 +207,13 @@ class ActiveBOTopic(Node):
             rl_msg.seed = seed
             rl_msg.display_run = False
             rl_msg.interactive_run = 2
-            rl_msg.weights = self.BO.policy_model.random_policy()
-            rl_msg.policy = self.BO.policy_model.rollout()
+            rl_msg.weights = self.BO.policy_model.random_policy().tolist()
+            rl_msg.policy = self.BO.policy_model.rollout().reshape(-1,).tolist()
 
+            self.active_rl_pub.publish(rl_msg)
+
             self.rl_pending = True
+
+            return
 
         if self.current_run == self.bo_runs:
             bo_response = ActiveBOResponse()
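The publish side is the mirror image of the callback fix above: rclpy validates assignments to message array fields and rejects numpy arrays, hence the new `.tolist()` calls, and `.reshape(-1,)` additionally flattens the rollout before conversion. A sketch under the assumption that the rollout comes back as a 2-D column array (its true shape is not visible in this diff):

    import numpy as np

    # Flatten a (T, 1) rollout and convert to native Python floats so a
    # float64[] message field will accept the assignment.
    rollout = np.random.rand(50, 1)            # stand-in for policy_model.rollout()
    policy_field = rollout.reshape(-1,).tolist()
    assert isinstance(policy_field, list) and isinstance(policy_field[0], float)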
@@ -214,7 +222,7 @@ class ActiveBOTopic(Node):
             bo_response.best_policy = self.best_policy[:, best_policy_idx].tolist()
             bo_response.best_weights = self.best_weights[:, best_policy_idx].tolist()
 
-            self.get_logger().info(f'Best Policy: {self.best_pol_reward}')
+            # self.get_logger().info(f'Best Policy: {self.best_pol_reward}')
 
             self.get_logger().info(f'{best_policy_idx}, {int(self.seed_array[0, best_policy_idx])}')
 
@@ -258,7 +266,7 @@ class ActiveBOTopic(Node):
 
             if self.bo_fixed_seed:
                 seed = int(self.seed_array[0, best_policy_idx])
-                self.get_logger().info(f'Used seed{seed}')
+                # self.get_logger().info(f'Used seed{seed}')
             else:
                 seed = int(np.random.randint(1, 2147483647, 1)[0])
 
@@ -277,12 +285,10 @@ class ActiveBOTopic(Node):
             self.BO = None
 
         else:
-            if self.rl_pending:
-                pass
+            if self.rl_pending or self.init_pending:
+                return
             else:
-                if self.init_pending:
-                    pass
-                elif self.current_episode < self.bo_episodes:
+                if self.current_episode < self.bo_episodes:
                     # metrics
                     if self.bo_metric == "random":
                         user_query = RandomQuery(self.bo_metric_parameter)
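The rewritten branch collapses the two `pass` placeholders into a single early exit: while either an RL episode (`rl_pending`) or the initialization round trip (`init_pending`) is outstanding, the callback returns instead of stepping the BO loop. Schematically (the method name below is a placeholder, not taken from the source):

    def mainloop_callback(self):               # hypothetical name
        if self.rl_pending or self.init_pending:
            return                             # an episode is still in flight
        if self.current_episode < self.bo_episodes:
            ...                                # query the user or step the BO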
@@ -305,11 +311,12 @@ class ActiveBOTopic(Node):
                         raise NotImplementedError
 
                     if user_query.query():
+                        self.user_asked = True
                         active_rl_request = ActiveRL()
                         old_policy, y_max, old_weights, _ = self.BO.get_best_result()
 
-                        self.get_logger().info(f'Best: {y_max}, w:{old_weights}')
-                        self.get_logger().info(f'Size of Y: {self.BO.Y.shape}, Size of X: {self.BO.X.shape}')
+                        # self.get_logger().info(f'Best: {y_max}, w:{old_weights}')
+                        # self.get_logger().info(f'Size of Y: {self.BO.Y.shape}, Size of X: {self.BO.X.shape}')
 
                         if self.bo_fixed_seed:
                             seed = self.seed
@@ -323,12 +330,13 @@ class ActiveBOTopic(Node):
                         active_rl_request.weights = old_weights.tolist()
                         active_rl_request.interactive_run = 0
 
-                        self.get_logger().info('Calling: Active RL')
+                        # self.get_logger().info('Calling: Active RL')
                         self.active_rl_pub.publish(active_rl_request)
                         self.rl_pending = True
 
                     else:
                         x_next = self.BO.next_observation()
+                        # self.get_logger().info('Next Observation BO!')
                         self.BO.policy_model.weights = np.around(x_next, decimals=8)
                         if self.bo_fixed_seed:
                             seed = self.seed
@@ -340,10 +348,12 @@ class ActiveBOTopic(Node):
                         rl_msg.seed = seed
                         rl_msg.display_run = False
                         rl_msg.interactive_run = 2
-                        rl_msg.weights = x_next
-                        rl_msg.policy = self.BO.policy_model.rollout()
+                        rl_msg.weights = x_next.tolist()
+                        rl_msg.policy = self.BO.policy_model.rollout().reshape(-1,).tolist()
                         self.rl_pending = True
 
+                        self.active_rl_pub.publish(rl_msg)
+
                     self.current_episode += 1
                     # self.get_logger().info(f'Current Episode: {self.current_episode}')
             else:
@@ -351,7 +361,7 @@ class ActiveBOTopic(Node):
                 self.best_pol_reward[:, self.current_run], \
                     self.best_weights[:, self.current_run], idx = self.BO.get_best_result()
 
-                self.get_logger().info(f'best idx: {idx}')
+                # self.get_logger().info(f'best idx: {idx}')
 
                 self.reward[:, self.current_run] = self.BO.best_reward.T
 
@@ -375,7 +385,9 @@ class ActiveBOTopic(Node):
         state_msg.current_episode = self.current_episode + 1 \
             if self.current_episode < self.bo_episodes else self.bo_episodes
         state_msg.best_reward = float(self.best_reward)
-        state_msg.last_user_reward = float(self.rl_reward)
+        if self.user_asked:
+            state_msg.last_user_reward = float(self.rl_reward)
+            self.user_asked = False
         self.state_pub.publish(state_msg)
 
 
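Together with `user_asked = True` set in the query branch above, this gate means `last_user_reward` on the state topic only updates after an episode the user actually evaluated; clearing the flag immediately after use keeps rewards from background BO episodes from overwriting the displayed value.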
src/active_bo_ros/launch/interactive_bo.launch.py (new executable file, 17 lines)
@@ -0,0 +1,17 @@
+from launch import LaunchDescription
+from launch_ros.actions import Node
+
+
+def generate_launch_description():
+    return LaunchDescription([
+        Node(
+            package='active_bo_ros',
+            executable='interactive_bo',
+            name='interactive_bo'
+        ),
+        Node(
+            package='active_bo_ros',
+            executable='active_rl_topic',
+            name='active_rl_topic'
+        ),
+    ])
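With the file above installed, both nodes come up together via the standard launch invocation:

    ros2 launch active_bo_ros interactive_bo.launch.py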
@@ -35,6 +35,7 @@ setup(
             'bo_torch_srv = active_bo_ros.bo_torch_service:main',
             'active_bo_topic = active_bo_ros.active_bo_topic:main',
             'active_rl_topic = active_bo_ros.active_rl_topic:main',
+            'interactive_bo = active_bo_ros.interactive_bo:main'
         ],
     },
 )
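The new console_scripts entry is what `executable='interactive_bo'` in the launch file resolves to; after a `colcon build` the node is also runnable standalone with `ros2 run active_bo_ros interactive_bo`. A trailing comma after the new entry would keep the next addition to this list a one-line diff.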