Active BO works

Niko Feith 2023-04-04 17:03:09 +02:00
parent baa81ca4a7
commit a17dc77234
5 changed files with 60 additions and 23 deletions

View File

@@ -85,14 +85,21 @@ class ActiveBOTopic(Node):
         self.current_episode = 0
 
     def active_bo_callback(self, msg):
-        self.get_logger().info('Active Bayesian Optimization request pending!')
-        self.active_bo_pending = True
-        self.bo_nr_weights = msg.nr_weights
-        self.bo_steps = msg.max_steps
-        self.bo_episodes = msg.nr_episodes
-        self.bo_runs = msg.nr_runs
-        self.bo_acq_fcn = msg.acquisition_function
-        self.bo_epsilon = msg.epsilon
+        if not self.active_bo_pending:
+            self.get_logger().info('Active Bayesian Optimization request pending!')
+            self.active_bo_pending = True
+            self.bo_nr_weights = msg.nr_weights
+            self.bo_steps = msg.max_steps
+            self.bo_episodes = msg.nr_episodes
+            self.bo_runs = msg.nr_runs
+            self.bo_acq_fcn = msg.acquisition_function
+            self.bo_epsilon = msg.epsilon
+
+            # initialize
+            self.reward = np.zeros((self.bo_episodes, self.bo_runs))
+            self.best_pol_reward = np.zeros((1, self.bo_runs))
+            self.best_policy = np.zeros((self.bo_steps, self.bo_runs))
+            self.best_weights = np.zeros((self.bo_nr_weights, self.bo_runs))
 
     def reset_rl_response(self):
         self.rl_weights = None
@@ -100,11 +107,12 @@ class ActiveBOTopic(Node):
         self.rl_reward = None
 
     def active_rl_callback(self, msg):
-        self.get_logger().info('Active Reinforcement Learning response pending!')
-        self.active_rl_pending = False
-        self.rl_weights = None
-        self.rl_final_step = None
-        self.rl_reward = None
+        if self.active_rl_pending:
+            self.get_logger().info('Active Reinforcement Learning response pending!')
+            self.active_rl_pending = False
+            self.rl_weights = msg.weights
+            self.rl_final_step = msg.final_step
+            self.rl_reward = msg.reward
 
     def mainloop_callback(self):
         if self.active_bo_pending:
@@ -115,14 +123,9 @@ class ActiveBOTopic(Node):
                                  acq=self.bo_acq_fcn,
                                  nr_weights=self.bo_nr_weights)
 
-            self.reward = np.zeros((self.bo_episodes, self.bo_runs))
-            self.best_pol_reward = np.zeros((1, self.bo_runs))
-            self.best_policy = np.zeros((self.bo_steps, self.bo_runs))
-            self.best_weights = np.zeros((self.bo_nr_weights, self.bo_runs))
-
             self.BO.initialize()
 
-        if self.current_run >= self.bo_runs:
+        if self.current_run == self.bo_runs:
             bo_response = ActiveBOResponse()
 
             best_policy_idx = np.argmax(self.best_pol_reward)
@@ -132,6 +135,7 @@ class ActiveBOTopic(Node):
             bo_response.reward_mean = np.mean(self.reward, axis=1).tolist()
             bo_response.reward_std = np.std(self.reward, axis=1).tolist()
 
+            self.get_logger().info('Responding: Active BO')
             self.active_bo_pub.publish(bo_response)
             self.reset_bo_request()
             self.active_bo_pending = False
@@ -162,16 +166,21 @@ class ActiveBOTopic(Node):
                     x_next = self.BO.next_observation()
                     self.BO.eval_new_observation(x_next)
                     self.current_episode += 1
+                    self.get_logger().info(f'Current Episode: {self.current_episode}')
                 else:
                     self.best_policy[:, self.current_run], \
                         self.best_pol_reward[:, self.current_run], \
                         self.best_weights[:, self.current_run] = self.BO.get_best_result()
+                    self.reward[:, self.current_run] = self.BO.best_reward.T
-                    self.current_episode += 1
-                else:
                     self.BO = None
                     self.current_episode = 0
                     self.current_run += 1
+                    self.get_logger().info(f'Current Run: {self.current_run}')
 
 
 def main(args=None):
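
The pattern running through this file's changes is a topic-based request/response handshake: active_bo_callback accepts a request only while none is pending, active_rl_callback accepts a response only while one is expected, and the timer-driven mainloop_callback advances the optimization one step per tick. Below is a minimal sketch of that pattern, assuming rclpy and a stand-in message type; the node name, topics, and message type are illustrative, not the package's real ActiveBO/ActiveRL interfaces.

import rclpy
from rclpy.node import Node
from std_msgs.msg import String  # stand-in for the real request/response types


class PendingTopicNode(Node):
    # Hypothetical minimal node; names and message type are illustrative only.

    def __init__(self):
        super().__init__('pending_topic_node')
        self.pending = False
        self.request = None
        self.sub = self.create_subscription(String, 'request', self.request_callback, 10)
        self.pub = self.create_publisher(String, 'response', 10)
        self.timer = self.create_timer(0.1, self.mainloop_callback)

    def request_callback(self, msg):
        # Accept a new request only while no other request is in flight,
        # mirroring the `if not self.active_bo_pending:` guard above.
        if not self.pending:
            self.request = msg
            self.pending = True

    def mainloop_callback(self):
        # The timer does the actual work; clearing the flag at the end
        # re-opens the callback for the next request.
        if self.pending:
            self.pub.publish(self.request)
            self.pending = False


def main():
    rclpy.init()
    rclpy.spin(PendingTopicNode())

Because the callback drops messages while a request is in flight, senders never corrupt a run that is already being processed; they simply retry once the response arrives.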

View File

@@ -63,7 +63,7 @@ class ActiveRLService(Node):
         self.best_pol_shown = False
 
         # Main loop timer object
-        self.mainloop_timer_period = 0.1
+        self.mainloop_timer_period = 0.05
         self.mainloop = self.create_timer(self.mainloop_timer_period,
                                           self.mainloop_callback,
                                           callback_group=mainloop_callback_group)
@@ -77,6 +77,7 @@ class ActiveRLService(Node):
         self.rl_weights = msg.weights
         self.get_logger().info('Active RL: Called!')
         self.env.reset()
+        self.active_rl_pending = True
 
     def reset_eval_request(self):
@@ -118,6 +119,7 @@ class ActiveRLService(Node):
         if not done and self.rl_step == len(policy):
             distance = -(self.env.goal_position - output[0][0])
             self.rl_reward += distance * self.distance_penalty
+            done = True
 
         return done
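
For intuition on the new `done = True` branch: when the step budget runs out before the episode ends on its own, the reward is reduced in proportion to the remaining distance to the goal and the rollout is terminated. A worked example with purely hypothetical numbers (goal_position, the final observation, and distance_penalty are not taken from this diff):

# hypothetical values, for illustration only
goal_position = 0.45     # target x-position of the environment
final_x = 0.10           # position when the step budget ran out
distance_penalty = 10.0  # scaling factor for the shortfall

distance = -(goal_position - final_x)        # -0.35, negative shortfall
reward_delta = distance * distance_penalty   # -3.5 added to rl_reward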
@@ -136,11 +138,15 @@ class ActiveRLService(Node):
                 self.eval_pub.publish(eval_request)
-                self.get_logger().info('Active RL: Called!')
+                self.get_logger().info('Active RL: Waiting for Eval!')
+
+                self.env.reset()
+
                 self.best_pol_shown = True
 
             elif self.best_pol_shown:
                 if not self.eval_response_received:
-                    self.get_logger().info('Active RL: Waiting for Eval!')
                     pass
 
                 if self.eval_response_received:
                     done = self.next_image(self.eval_policy)
@@ -153,6 +159,8 @@ class ActiveRLService(Node):
                     self.active_rl_pub.publish(rl_response)
+                    self.env.reset()
 
+                    # reset flags and attributes
                     self.reset_eval_request()
                     self.reset_rl_request()

View File

@@ -0,0 +1,17 @@
+from launch import LaunchDescription
+from launch_ros.actions import Node
+
+
+def generate_launch_description():
+    return LaunchDescription([
+        Node(
+            package='active_bo_ros',
+            executable='active_bo_topic',
+            name='active_bo_topic'
+        ),
+        Node(
+            package='active_bo_ros',
+            executable='active_rl_topic',
+            name='active_rl_topic'
+        ),
+    ])
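
This new launch file brings up both topic nodes in one go. Assuming launch files are installed by setup.py the same way as the package's existing ones, and that this file is named something like active_topic.launch.py (the filename is not visible in this diff; the name here is a placeholder), it would be started with:

ros2 launch active_bo_ros active_topic.launch.py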

View File

@@ -1,6 +1,7 @@
 from launch import LaunchDescription
 from launch_ros.actions import Node
 
 
 def generate_launch_description():
     return LaunchDescription([
         Node(

View File

@@ -32,6 +32,8 @@ setup(
             'bo_srv = active_bo_ros.bo_service:main',
             'active_bo_srv = active_bo_ros.active_bo_service:main',
             'active_rl_srv = active_bo_ros.active_rl_service:main',
+            'active_bo_topic = active_bo_ros.active_bo_topic:main',
+            'active_rl_topic = active_bo_ros.active_rl_topic:main',
         ],
     },
 )
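
The two new console_scripts entries expose the topic nodes as ros2 executables alongside the existing service versions. Assuming a standard colcon workspace, they become runnable after a rebuild:

colcon build --packages-select active_bo_ros
source install/setup.bash
ros2 run active_bo_ros active_bo_topic
ros2 run active_bo_ros active_rl_topic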