added utils
This commit is contained in:
parent
a23e5be64c
commit
7c3ad608f7
@ -7,6 +7,7 @@ from transitions import Machine
|
||||
import cma
|
||||
import yaml
|
||||
import numpy as np
|
||||
from src.interaction_utils.serialization import flatten_population, unflatten_population
|
||||
|
||||
# Msg/Srv/Action
|
||||
from interaction_msgs.srv import Query
|
||||
@ -33,6 +34,7 @@ class CMAESOptimizationNode(Node):
|
||||
self.best_rewards_per_iteration = []
|
||||
# ROS2 Interfaces
|
||||
self.__future = None
|
||||
self.__task_response = None
|
||||
# Heartbeat Topics - to make sure that there is no deadlock
|
||||
self.heartbeat_timeout = 30 # secs
|
||||
self.heartbeat_timer = self.create_timer(1.0, self.check_heartbeats)
|
||||
@ -47,6 +49,7 @@ class CMAESOptimizationNode(Node):
|
||||
|
||||
# Service
|
||||
self.query_srv = self.create_client(Query, 'query_srv')
|
||||
self.task_srv = self.create_client(TaskEvaluation, 'task_srv')
|
||||
|
||||
# Define states
|
||||
self.states = [
|
||||
@ -136,8 +139,14 @@ class CMAESOptimizationNode(Node):
|
||||
request.number_of_parameters_per_dimensions = self.number_of_parameters_per_dimensions
|
||||
|
||||
population = self.cmaes.ask()
|
||||
flat_population = flatten_population(population)
|
||||
|
||||
request.parameter_array = flat_population
|
||||
|
||||
self.__future = self.task_srv.call_async(request)
|
||||
self.__future.add_done_callback(self.handle_task_response)
|
||||
|
||||
self.data_send_for_evaluation()
|
||||
|
||||
def on_enter_interactive_mode(self):
|
||||
pass
|
||||
@ -187,7 +196,16 @@ class CMAESOptimizationNode(Node):
|
||||
self.non_interaction()
|
||||
|
||||
except Exception as e:
|
||||
self.get_logger().error(f'Query service call failed: {str(e)}')
|
||||
self.get_logger().error(f'Query service call failed: {e}')
|
||||
self.error_trigger()
|
||||
|
||||
def handle_task_response(self, future):
|
||||
try:
|
||||
self.__task_response = future.result()
|
||||
self.evaluation_response_received()
|
||||
|
||||
except Exception as e:
|
||||
self.get_logger().error(f'Task service call failed: {e}')
|
||||
self.error_trigger()
|
||||
|
||||
# endregion
|
0
src/interaction_utils/__init__.py
Normal file
0
src/interaction_utils/__init__.py
Normal file
9
src/interaction_utils/serialization.py
Normal file
9
src/interaction_utils/serialization.py
Normal file
@ -0,0 +1,9 @@
|
||||
import numpy as np
|
||||
|
||||
def flatten_population(population):
|
||||
flattened_parameters = np.array(population).reshape(-1) # Efficient flattening using NumPy
|
||||
return flattened_parameters
|
||||
|
||||
def unflatten_population(flattened_parameters, number_of_population, number_of_dimensions, number_of_parameters_per_dimension):
|
||||
population = flattened_parameters.reshape(number_of_population, number_of_dimensions, number_of_parameters_per_dimension)
|
||||
return population
|
Loading…
Reference in New Issue
Block a user