added utils

This commit is contained in:
Niko Feith 2024-02-29 15:31:43 +01:00
parent a23e5be64c
commit 7c3ad608f7
3 changed files with 28 additions and 1 deletions

View File

@ -7,6 +7,7 @@ from transitions import Machine
import cma
import yaml
import numpy as np
from src.interaction_utils.serialization import flatten_population, unflatten_population
# Msg/Srv/Action
from interaction_msgs.srv import Query
@ -33,6 +34,7 @@ class CMAESOptimizationNode(Node):
self.best_rewards_per_iteration = []
# ROS2 Interfaces
self.__future = None
self.__task_response = None
# Heartbeat Topics - to make sure that there is no deadlock
self.heartbeat_timeout = 30 # secs
self.heartbeat_timer = self.create_timer(1.0, self.check_heartbeats)
@ -47,6 +49,7 @@ class CMAESOptimizationNode(Node):
# Service
self.query_srv = self.create_client(Query, 'query_srv')
self.task_srv = self.create_client(TaskEvaluation, 'task_srv')
# Define states
self.states = [
@ -136,8 +139,14 @@ class CMAESOptimizationNode(Node):
request.number_of_parameters_per_dimensions = self.number_of_parameters_per_dimensions
population = self.cmaes.ask()
flat_population = flatten_population(population)
request.parameter_array = flat_population
self.__future = self.task_srv.call_async(request)
self.__future.add_done_callback(self.handle_task_response)
self.data_send_for_evaluation()
def on_enter_interactive_mode(self):
pass
@ -187,7 +196,16 @@ class CMAESOptimizationNode(Node):
self.non_interaction()
except Exception as e:
self.get_logger().error(f'Query service call failed: {str(e)}')
self.get_logger().error(f'Query service call failed: {e}')
self.error_trigger()
def handle_task_response(self, future):
try:
self.__task_response = future.result()
self.evaluation_response_received()
except Exception as e:
self.get_logger().error(f'Task service call failed: {e}')
self.error_trigger()
# endregion

View File

View File

@ -0,0 +1,9 @@
import numpy as np
def flatten_population(population):
flattened_parameters = np.array(population).reshape(-1) # Efficient flattening using NumPy
return flattened_parameters
def unflatten_population(flattened_parameters, number_of_population, number_of_dimensions, number_of_parameters_per_dimension):
population = flattened_parameters.reshape(number_of_population, number_of_dimensions, number_of_parameters_per_dimension)
return population