From 7c3ad608f7b4ce13ae402a619ca4d3c06026d209 Mon Sep 17 00:00:00 2001 From: Niko Date: Thu, 29 Feb 2024 15:31:43 +0100 Subject: [PATCH] added utils --- .../cmaes_optimization_node.py | 20 ++++++++++++++++++- src/interaction_utils/__init__.py | 0 src/interaction_utils/serialization.py | 9 +++++++++ 3 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 src/interaction_utils/__init__.py create mode 100644 src/interaction_utils/serialization.py diff --git a/src/interaction_optimizers/interaction_optimizers/cmaes_optimization_node.py b/src/interaction_optimizers/interaction_optimizers/cmaes_optimization_node.py index df24703..ece40bf 100644 --- a/src/interaction_optimizers/interaction_optimizers/cmaes_optimization_node.py +++ b/src/interaction_optimizers/interaction_optimizers/cmaes_optimization_node.py @@ -7,6 +7,7 @@ from transitions import Machine import cma import yaml import numpy as np +from src.interaction_utils.serialization import flatten_population, unflatten_population # Msg/Srv/Action from interaction_msgs.srv import Query @@ -33,6 +34,7 @@ class CMAESOptimizationNode(Node): self.best_rewards_per_iteration = [] # ROS2 Interfaces self.__future = None + self.__task_response = None # Heartbeat Topics - to make sure that there is no deadlock self.heartbeat_timeout = 30 # secs self.heartbeat_timer = self.create_timer(1.0, self.check_heartbeats) @@ -47,6 +49,7 @@ class CMAESOptimizationNode(Node): # Service self.query_srv = self.create_client(Query, 'query_srv') + self.task_srv = self.create_client(TaskEvaluation, 'task_srv') # Define states self.states = [ @@ -136,8 +139,14 @@ class CMAESOptimizationNode(Node): request.number_of_parameters_per_dimensions = self.number_of_parameters_per_dimensions population = self.cmaes.ask() + flat_population = flatten_population(population) + request.parameter_array = flat_population + self.__future = self.task_srv.call_async(request) + self.__future.add_done_callback(self.handle_task_response) + + self.data_send_for_evaluation() def on_enter_interactive_mode(self): pass @@ -187,7 +196,16 @@ class CMAESOptimizationNode(Node): self.non_interaction() except Exception as e: - self.get_logger().error(f'Query service call failed: {str(e)}') + self.get_logger().error(f'Query service call failed: {e}') + self.error_trigger() + + def handle_task_response(self, future): + try: + self.__task_response = future.result() + self.evaluation_response_received() + + except Exception as e: + self.get_logger().error(f'Task service call failed: {e}') self.error_trigger() # endregion \ No newline at end of file diff --git a/src/interaction_utils/__init__.py b/src/interaction_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/interaction_utils/serialization.py b/src/interaction_utils/serialization.py new file mode 100644 index 0000000..822c73e --- /dev/null +++ b/src/interaction_utils/serialization.py @@ -0,0 +1,9 @@ +import numpy as np + +def flatten_population(population): + flattened_parameters = np.array(population).reshape(-1) # Efficient flattening using NumPy + return flattened_parameters + +def unflatten_population(flattened_parameters, number_of_population, number_of_dimensions, number_of_parameters_per_dimension): + population = flattened_parameters.reshape(number_of_population, number_of_dimensions, number_of_parameters_per_dimension) + return population \ No newline at end of file