cmaes node TODO:

error state handling
finished state handling
parameters forwared to the query node
This commit is contained in:
Niko Feith 2024-03-13 15:04:15 +01:00
parent aa04e0b833
commit e0874c3d9d
2 changed files with 46 additions and 7 deletions

View File

@ -1,6 +1,7 @@
bool user_input bool user_input
uint16 number_of_population
# case if user_input is true # case if user_input is true
uint16 nr_of_population
float32[] user_parameters # Length: number_of_dimensions * number_of_parameters_per_dimension float32[] user_parameters # Length: number_of_dimensions * number_of_parameters_per_dimension
float32[] user_covariance_diag # Length: number_of_dimensions * number_of_parameters_per_dimension float32[] user_covariance_diag # Length: number_of_dimensions * number_of_parameters_per_dimension
float32[] current_cma_mean # Length: number_of_dimensions * number_of_parameters_per_dimension float32[] current_cma_mean # Length: number_of_dimensions * number_of_parameters_per_dimension
@ -8,7 +9,6 @@ float32[] conditional_points # Length: (number_of_dimensions + time_stamp[0,1])
float32 weight_parameter # this parameter sets the weighted average 0 dont trust user 1 completly trust user (it is set by the user or it is decays over time i have to do some experiments on that) float32 weight_parameter # this parameter sets the weighted average 0 dont trust user 1 completly trust user (it is set by the user or it is decays over time i have to do some experiments on that)
# case if user_input is false # case if user_input is false
uint16 number_of_population
uint16 number_of_dimensions # this is the number of ProMPs * 2 (Position and Velocity) uint16 number_of_dimensions # this is the number of ProMPs * 2 (Position and Velocity)
uint16 number_of_parameters_per_dimensions uint16 number_of_parameters_per_dimensions
float32[] parameter_array # Length: number_of_population * number_of_dimensions * number_of_parameters_per_dimension float32[] parameter_array # Length: number_of_population * number_of_dimensions * number_of_parameters_per_dimension

View File

@ -25,11 +25,13 @@ class CMAESOptimizationNode(Node):
# CMA-ES Attributes # CMA-ES Attributes
self.cmaes = None self.cmaes = None
self.number_of_initial_episodes = self.declare_parameter('number_of_initial_episodes', 5) self.number_of_initial_episodes = self.declare_parameter('number_of_initial_episodes', 5)
self.initial_user_covariance = self.declare_parameter('initial_user_covariance', 1.0)
self.episode = 0 self.episode = 0
self.number_of_dimensions = 3 self.number_of_dimensions = 3
self.number_of_parameters_per_dimensions = 10 self.number_of_parameters_per_dimensions = 10
# the number of weights is double the number of dims * params per dim since its Position and Velocity # the number of weights is double the number of dims * params per dim since its Position and Velocity
self.number_of_weights = 2 * self.number_of_dimensions * self.number_of_parameters_per_dimensions self.number_of_weights = 2 * self.number_of_dimensions * self.number_of_parameters_per_dimensions
self.user_covariance = np.ones((self.number_of_weights,1)) * self.initial_user_covariance
self.random_seed = None self.random_seed = None
@ -41,6 +43,7 @@ class CMAESOptimizationNode(Node):
# ROS2 Interfaces # ROS2 Interfaces
self.__future = None self.__future = None
self.__task_response = None self.__task_response = None
self.__user_response = None
# Heartbeat Topics - to make sure that there is no deadlock # Heartbeat Topics - to make sure that there is no deadlock
self.heartbeat_timeout = 30 # secs self.heartbeat_timeout = 30 # secs
self.heartbeat_timer = self.create_timer(1.0, self.check_heartbeats) self.heartbeat_timer = self.create_timer(1.0, self.check_heartbeats)
@ -57,6 +60,7 @@ class CMAESOptimizationNode(Node):
self.parameter_srv = self.create_service(ParameterChange, 'cmaes_parameter_srv', self.parameter_callback) self.parameter_srv = self.create_service(ParameterChange, 'cmaes_parameter_srv', self.parameter_callback)
self.query_srv = self.create_client(Query, 'query_srv') self.query_srv = self.create_client(Query, 'query_srv')
self.task_srv = self.create_client(TaskEvaluation, 'task_srv') self.task_srv = self.create_client(TaskEvaluation, 'task_srv')
self.user_interface_srv = self.create_client(UserInterface, 'user_interface_srv')
# Define states # Define states
self.states = [ self.states = [
@ -126,11 +130,11 @@ class CMAESOptimizationNode(Node):
request = Query.Request() request = Query.Request()
if self.query_metric == 'random': if self.query_metric == 'random':
pass pass #TODO
elif self.query_metric == 'regular': elif self.query_metric == 'regular':
pass pass #TODO
elif self.query_metric == 'improvement': elif self.query_metric == 'improvement':
pass pass #TODO
self.query_srv.wait_for_service(10) self.query_srv.wait_for_service(10)
self.__future = self.query_srv.call_async(request) self.__future = self.query_srv.call_async(request)
@ -157,12 +161,39 @@ class CMAESOptimizationNode(Node):
self.data_send_for_evaluation() self.data_send_for_evaluation()
def on_enter_interactive_mode(self): def on_enter_interactive_mode(self):
# Reset Mixed Reality heartbeat to check if the other node crashed
self.last_mr_heartbeat_time = None
request = UserInterface.Request() request = UserInterface.Request()
request.current_cma_mean = flatten_population(self.cmaes.mean) request.current_cma_mean = flatten_population(self.cmaes.mean)
request.best_parameters request.best_parameters = flatten_population(self.best_parameters_per_interation[-1])
request.current_user_covariance_diag = self.user_covariance
self.__future = self.user_interface_srv.call_async(request)
self.__future.add_done_callback(self.handle_user_response)
self.data_send_to_user()
def on_enter_prepare_user_data_for_evaluation(self): def on_enter_prepare_user_data_for_evaluation(self):
pass # Reset Task heartbeat to check if the other node crashed
self.last_task_heartbeat_time = None
# Update the user_covariance
self.user_covariance = self.__user_response.user_covariance_diag
request = TaskEvaluation.Request()
request.user_input = True
request.number_of_population = self.cmaes.popsize
request.user_parameters = self.__user_response.user_parameters
request.user_covariance_diag = self.user_covariance
request.current_cma_mean = flatten_population(self.cmaes.mean)
request.conditional_points = self.__user_response.conditional_points
request.weight_parameter = self.__user_response.weight_parameter
self.__future = self.task_srv.call_async(request)
self.__future.add_done_callback(self.handle_task_response)
self.send_user_data_to_evaluation()
def on_enter_update_optimizer(self): def on_enter_update_optimizer(self):
population = unflatten_population(self.__task_response.parameter_array, population = unflatten_population(self.__task_response.parameter_array,
@ -298,4 +329,12 @@ class CMAESOptimizationNode(Node):
self.get_logger().error(f'Task service call failed: {e}') self.get_logger().error(f'Task service call failed: {e}')
self.error_trigger() self.error_trigger()
def handle_user_response(self, future):
try:
self.__user_response = future.result()
self.user_response_received()
except Exception as e:
self.get_logger().error(f'Task service call failed: {e}')
self.error_trigger()
# endregion # endregion