added utils
This commit is contained in:
parent
a23e5be64c
commit
7c3ad608f7
@ -7,6 +7,7 @@ from transitions import Machine
|
|||||||
import cma
|
import cma
|
||||||
import yaml
|
import yaml
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from src.interaction_utils.serialization import flatten_population, unflatten_population
|
||||||
|
|
||||||
# Msg/Srv/Action
|
# Msg/Srv/Action
|
||||||
from interaction_msgs.srv import Query
|
from interaction_msgs.srv import Query
|
||||||
@ -33,6 +34,7 @@ class CMAESOptimizationNode(Node):
|
|||||||
self.best_rewards_per_iteration = []
|
self.best_rewards_per_iteration = []
|
||||||
# ROS2 Interfaces
|
# ROS2 Interfaces
|
||||||
self.__future = None
|
self.__future = None
|
||||||
|
self.__task_response = None
|
||||||
# Heartbeat Topics - to make sure that there is no deadlock
|
# Heartbeat Topics - to make sure that there is no deadlock
|
||||||
self.heartbeat_timeout = 30 # secs
|
self.heartbeat_timeout = 30 # secs
|
||||||
self.heartbeat_timer = self.create_timer(1.0, self.check_heartbeats)
|
self.heartbeat_timer = self.create_timer(1.0, self.check_heartbeats)
|
||||||
@ -47,6 +49,7 @@ class CMAESOptimizationNode(Node):
|
|||||||
|
|
||||||
# Service
|
# Service
|
||||||
self.query_srv = self.create_client(Query, 'query_srv')
|
self.query_srv = self.create_client(Query, 'query_srv')
|
||||||
|
self.task_srv = self.create_client(TaskEvaluation, 'task_srv')
|
||||||
|
|
||||||
# Define states
|
# Define states
|
||||||
self.states = [
|
self.states = [
|
||||||
@ -136,8 +139,14 @@ class CMAESOptimizationNode(Node):
|
|||||||
request.number_of_parameters_per_dimensions = self.number_of_parameters_per_dimensions
|
request.number_of_parameters_per_dimensions = self.number_of_parameters_per_dimensions
|
||||||
|
|
||||||
population = self.cmaes.ask()
|
population = self.cmaes.ask()
|
||||||
|
flat_population = flatten_population(population)
|
||||||
|
|
||||||
|
request.parameter_array = flat_population
|
||||||
|
|
||||||
|
self.__future = self.task_srv.call_async(request)
|
||||||
|
self.__future.add_done_callback(self.handle_task_response)
|
||||||
|
|
||||||
|
self.data_send_for_evaluation()
|
||||||
|
|
||||||
def on_enter_interactive_mode(self):
|
def on_enter_interactive_mode(self):
|
||||||
pass
|
pass
|
||||||
@ -187,7 +196,16 @@ class CMAESOptimizationNode(Node):
|
|||||||
self.non_interaction()
|
self.non_interaction()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.get_logger().error(f'Query service call failed: {str(e)}')
|
self.get_logger().error(f'Query service call failed: {e}')
|
||||||
|
self.error_trigger()
|
||||||
|
|
||||||
|
def handle_task_response(self, future):
|
||||||
|
try:
|
||||||
|
self.__task_response = future.result()
|
||||||
|
self.evaluation_response_received()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.get_logger().error(f'Task service call failed: {e}')
|
||||||
self.error_trigger()
|
self.error_trigger()
|
||||||
|
|
||||||
# endregion
|
# endregion
|
0
src/interaction_utils/__init__.py
Normal file
0
src/interaction_utils/__init__.py
Normal file
9
src/interaction_utils/serialization.py
Normal file
9
src/interaction_utils/serialization.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def flatten_population(population):
|
||||||
|
flattened_parameters = np.array(population).reshape(-1) # Efficient flattening using NumPy
|
||||||
|
return flattened_parameters
|
||||||
|
|
||||||
|
def unflatten_population(flattened_parameters, number_of_population, number_of_dimensions, number_of_parameters_per_dimension):
|
||||||
|
population = flattened_parameters.reshape(number_of_population, number_of_dimensions, number_of_parameters_per_dimension)
|
||||||
|
return population
|
Loading…
Reference in New Issue
Block a user