Active BO works
parent baa81ca4a7
commit a17dc77234
@@ -85,14 +85,21 @@ class ActiveBOTopic(Node):
         self.current_episode = 0
 
     def active_bo_callback(self, msg):
-        self.get_logger().info('Active Bayesian Optimization request pending!')
-        self.active_bo_pending = True
-        self.bo_nr_weights = msg.nr_weights
-        self.bo_steps = msg.max_steps
-        self.bo_episodes = msg.nr_episodes
-        self.bo_runs = msg.nr_runs
-        self.bo_acq_fcn = msg.acquisition_function
-        self.bo_epsilon = msg.epsilon
+        if not self.active_bo_pending:
+            self.get_logger().info('Active Bayesian Optimization request pending!')
+            self.active_bo_pending = True
+            self.bo_nr_weights = msg.nr_weights
+            self.bo_steps = msg.max_steps
+            self.bo_episodes = msg.nr_episodes
+            self.bo_runs = msg.nr_runs
+            self.bo_acq_fcn = msg.acquisition_function
+            self.bo_epsilon = msg.epsilon
+
+            # initialize
+            self.reward = np.zeros((self.bo_episodes, self.bo_runs))
+            self.best_pol_reward = np.zeros((1, self.bo_runs))
+            self.best_policy = np.zeros((self.bo_steps, self.bo_runs))
+            self.best_weights = np.zeros((self.bo_nr_weights, self.bo_runs))
 
     def reset_rl_response(self):
         self.rl_weights = None
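Note: the rewritten active_bo_callback follows a latch pattern: a request is only accepted while no other request is in flight, and the result buffers are allocated at accept time. A minimal standalone sketch of that pattern; only the active_bo_pending flag name comes from this commit, the class and method names are illustrative:

    # Sketch of the pending-flag latch used by active_bo_callback.
    class RequestLatch:
        def __init__(self):
            self.active_bo_pending = False

        def handle_request(self, msg):
            if not self.active_bo_pending:
                # accept: copy request fields, allocate result buffers
                self.active_bo_pending = True

        def finish(self):
            # called once the response is published, re-arming the latch
            self.active_bo_pending = False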
@@ -100,11 +107,12 @@ class ActiveBOTopic(Node):
         self.rl_reward = None
 
     def active_rl_callback(self, msg):
-        self.get_logger().info('Active Reinforcement Learning response pending!')
-        self.active_rl_pending = False
-        self.rl_weights = None
-        self.rl_final_step = None
-        self.rl_reward = None
+        if self.active_rl_pending:
+            self.get_logger().info('Active Reinforcement Learning response pending!')
+            self.active_rl_pending = False
+            self.rl_weights = msg.weights
+            self.rl_final_step = msg.final_step
+            self.rl_reward = msg.reward
 
     def mainloop_callback(self):
         if self.active_bo_pending:
@@ -115,14 +123,9 @@ class ActiveBOTopic(Node):
                                                acq=self.bo_acq_fcn,
                                                nr_weights=self.bo_nr_weights)
 
-                self.reward = np.zeros((self.bo_episodes, self.bo_runs))
-                self.best_pol_reward = np.zeros((1, self.bo_runs))
-                self.best_policy = np.zeros((self.bo_steps, self.bo_runs))
-                self.best_weights = np.zeros((self.bo_nr_weights, self.bo_runs))
-
                 self.BO.initialize()
 
-            if self.current_run >= self.bo_runs:
+            if self.current_run == self.bo_runs:
                 bo_response = ActiveBOResponse()
 
                 best_policy_idx = np.argmax(self.best_pol_reward)
@@ -132,6 +135,7 @@ class ActiveBOTopic(Node):
                 bo_response.reward_mean = np.mean(self.reward, axis=1).tolist()
                 bo_response.reward_std = np.std(self.reward, axis=1).tolist()
 
+                self.get_logger().info('Responding: Active BO')
                 self.active_bo_pub.publish(bo_response)
                 self.reset_bo_request()
                 self.active_bo_pending = False
@@ -162,16 +166,21 @@ class ActiveBOTopic(Node):
                 x_next = self.BO.next_observation()
                 self.BO.eval_new_observation(x_next)
 
+                self.current_episode += 1
+                self.get_logger().info(f'Current Episode: {self.current_episode}')
+            else:
                 self.best_policy[:, self.current_run], \
                     self.best_pol_reward[:, self.current_run], \
                     self.best_weights[:, self.current_run] = self.BO.get_best_result()
 
                 self.reward[:, self.current_run] = self.BO.best_reward.T
 
-                self.current_episode += 1
-            else:
+                self.BO = None
+
                 self.current_episode = 0
                 self.current_run += 1
+                self.get_logger().info(f'Current Run: {self.current_run}')
+
 
 
 def main(args=None):
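Note: across the hunks above, the timer-driven mainloop_callback implements what is logically a nested loop: run bo_episodes BO episodes, harvest the best result, tear down self.BO (now reset to None so it is rebuilt on the next tick), and advance to the next of bo_runs runs. A blocking sketch of the same control flow, for intuition only; the BO constructor's name and arguments are elided because they are not fully visible in this diff:

    # Equivalent nested-loop view of the per-tick bookkeeping.
    for current_run in range(bo_runs):
        BO = make_bo(...)                 # rebuilt once per run (constructor elided)
        BO.initialize()
        for current_episode in range(bo_episodes):
            x_next = BO.next_observation()
            BO.eval_new_observation(x_next)
        best_policy, best_pol_reward, best_weights = BO.get_best_result()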
@@ -63,7 +63,7 @@ class ActiveRLService(Node):
         self.best_pol_shown = False
 
         # Main loop timer object
-        self.mainloop_timer_period = 0.1
+        self.mainloop_timer_period = 0.05
         self.mainloop = self.create_timer(self.mainloop_timer_period,
                                           self.mainloop_callback,
                                           callback_group=mainloop_callback_group)
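Note on the timer change: rclpy's Node.create_timer takes the period in seconds, so dropping self.mainloop_timer_period from 0.1 to 0.05 doubles the polling rate of mainloop_callback from 10 Hz to 20 Hz. A minimal sketch of the call inside the node's __init__; the callback group's concrete type is not visible in this diff, so ReentrantCallbackGroup is an assumption:

    from rclpy.callback_groups import ReentrantCallbackGroup

    # inside ActiveRLService.__init__ (group type assumed, not from this diff)
    mainloop_callback_group = ReentrantCallbackGroup()
    self.mainloop_timer_period = 0.05            # seconds per tick -> 20 Hz
    self.mainloop = self.create_timer(self.mainloop_timer_period,
                                      self.mainloop_callback,
                                      callback_group=mainloop_callback_group)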
@@ -77,6 +77,7 @@ class ActiveRLService(Node):
         self.rl_weights = msg.weights
 
+        self.get_logger().info('Active RL: Called!')
         self.env.reset()
         self.active_rl_pending = True
 
     def reset_eval_request(self):
@@ -118,6 +119,7 @@ class ActiveRLService(Node):
         if not done and self.rl_step == len(policy):
             distance = -(self.env.goal_position - output[0][0])
+            self.rl_reward += distance * self.distance_penalty
             done = True
 
         return done
 
@@ -136,11 +138,15 @@ class ActiveRLService(Node):
 
             self.eval_pub.publish(eval_request)
             self.get_logger().info('Active RL: Called!')
+            self.get_logger().info('Active RL: Waiting for Eval!')
 
+            self.env.reset()
+
             self.best_pol_shown = True
 
         elif self.best_pol_shown:
             if not self.eval_response_received:
+                self.get_logger().info('Active RL: Waiting for Eval!')
                 pass
 
             if self.eval_response_received:
                 done = self.next_image(self.eval_policy)
@@ -153,6 +159,8 @@ class ActiveRLService(Node):
 
             self.active_rl_pub.publish(rl_response)
 
+            self.env.reset()
+
             # reset flags and attributes
             self.reset_eval_request()
             self.reset_rl_request()
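Note: the last two hunks complete a publish-then-poll handshake: the timer callback publishes the eval request once, sets best_pol_shown, then idles each tick until the eval subscriber sets eval_response_received, and the environment is reset after both the request and the response so each rollout starts from a clean state. A condensed sketch of that state machine; attribute and method names are taken from the diff, everything elided is not:

    # Condensed view of mainloop_callback's handshake states (sketch).
    def mainloop_callback(self):
        if not self.best_pol_shown:
            self.eval_pub.publish(eval_request)       # fire the request once
            self.env.reset()
            self.best_pol_shown = True
        elif self.eval_response_received:             # set by the eval subscriber
            done = self.next_image(self.eval_policy)  # step one frame per tick
            # ...publish rl_response when done, reset flags, env.reset()...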
src/active_bo_ros/launch/active_topic.launch.py (new executable file, 17 lines)
@@ -0,0 +1,17 @@
+from launch import LaunchDescription
+from launch_ros.actions import Node
+
+
+def generate_launch_description():
+    return LaunchDescription([
+        Node(
+            package='active_bo_ros',
+            executable='active_bo_topic',
+            name='active_bo_topic'
+        ),
+        Node(
+            package='active_bo_ros',
+            executable='active_rl_topic',
+            name='active_rl_topic'
+        ),
+    ])
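Once the package is rebuilt, both nodes can be brought up together with the new launch file (assuming a sourced colcon workspace and that the package's setup installs the launch directory, which this diff does not show):

    ros2 launch active_bo_ros active_topic.launch.py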
@@ -1,6 +1,7 @@
 from launch import LaunchDescription
 from launch_ros.actions import Node
 
+
 def generate_launch_description():
     return LaunchDescription([
         Node(
@@ -32,6 +32,8 @@ setup(
             'bo_srv = active_bo_ros.bo_service:main',
             'active_bo_srv = active_bo_ros.active_bo_service:main',
             'active_rl_srv = active_bo_ros.active_rl_service:main',
+            'active_bo_topic = active_bo_ros.active_bo_topic:main',
+            'active_rl_topic = active_bo_ros.active_rl_topic:main',
         ],
     },
 )
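These console_scripts entries are what expose the two new topic nodes as runnable executables; after colcon build and re-sourcing the workspace, they can also be started individually:

    ros2 run active_bo_ros active_bo_topic
    ros2 run active_bo_ros active_rl_topic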