cmaes node TODO:
error state handling finished state handling parameters forwared to the query node
This commit is contained in:
parent
aa04e0b833
commit
e0874c3d9d
@ -1,6 +1,7 @@
|
|||||||
bool user_input
|
bool user_input
|
||||||
|
uint16 number_of_population
|
||||||
|
|
||||||
# case if user_input is true
|
# case if user_input is true
|
||||||
uint16 nr_of_population
|
|
||||||
float32[] user_parameters # Length: number_of_dimensions * number_of_parameters_per_dimension
|
float32[] user_parameters # Length: number_of_dimensions * number_of_parameters_per_dimension
|
||||||
float32[] user_covariance_diag # Length: number_of_dimensions * number_of_parameters_per_dimension
|
float32[] user_covariance_diag # Length: number_of_dimensions * number_of_parameters_per_dimension
|
||||||
float32[] current_cma_mean # Length: number_of_dimensions * number_of_parameters_per_dimension
|
float32[] current_cma_mean # Length: number_of_dimensions * number_of_parameters_per_dimension
|
||||||
@ -8,7 +9,6 @@ float32[] conditional_points # Length: (number_of_dimensions + time_stamp[0,1])
|
|||||||
float32 weight_parameter # this parameter sets the weighted average 0 dont trust user 1 completly trust user (it is set by the user or it is decays over time i have to do some experiments on that)
|
float32 weight_parameter # this parameter sets the weighted average 0 dont trust user 1 completly trust user (it is set by the user or it is decays over time i have to do some experiments on that)
|
||||||
|
|
||||||
# case if user_input is false
|
# case if user_input is false
|
||||||
uint16 number_of_population
|
|
||||||
uint16 number_of_dimensions # this is the number of ProMPs * 2 (Position and Velocity)
|
uint16 number_of_dimensions # this is the number of ProMPs * 2 (Position and Velocity)
|
||||||
uint16 number_of_parameters_per_dimensions
|
uint16 number_of_parameters_per_dimensions
|
||||||
float32[] parameter_array # Length: number_of_population * number_of_dimensions * number_of_parameters_per_dimension
|
float32[] parameter_array # Length: number_of_population * number_of_dimensions * number_of_parameters_per_dimension
|
||||||
|
@ -25,11 +25,13 @@ class CMAESOptimizationNode(Node):
|
|||||||
# CMA-ES Attributes
|
# CMA-ES Attributes
|
||||||
self.cmaes = None
|
self.cmaes = None
|
||||||
self.number_of_initial_episodes = self.declare_parameter('number_of_initial_episodes', 5)
|
self.number_of_initial_episodes = self.declare_parameter('number_of_initial_episodes', 5)
|
||||||
|
self.initial_user_covariance = self.declare_parameter('initial_user_covariance', 1.0)
|
||||||
self.episode = 0
|
self.episode = 0
|
||||||
self.number_of_dimensions = 3
|
self.number_of_dimensions = 3
|
||||||
self.number_of_parameters_per_dimensions = 10
|
self.number_of_parameters_per_dimensions = 10
|
||||||
# the number of weights is double the number of dims * params per dim since its Position and Velocity
|
# the number of weights is double the number of dims * params per dim since its Position and Velocity
|
||||||
self.number_of_weights = 2 * self.number_of_dimensions * self.number_of_parameters_per_dimensions
|
self.number_of_weights = 2 * self.number_of_dimensions * self.number_of_parameters_per_dimensions
|
||||||
|
self.user_covariance = np.ones((self.number_of_weights,1)) * self.initial_user_covariance
|
||||||
|
|
||||||
self.random_seed = None
|
self.random_seed = None
|
||||||
|
|
||||||
@ -41,6 +43,7 @@ class CMAESOptimizationNode(Node):
|
|||||||
# ROS2 Interfaces
|
# ROS2 Interfaces
|
||||||
self.__future = None
|
self.__future = None
|
||||||
self.__task_response = None
|
self.__task_response = None
|
||||||
|
self.__user_response = None
|
||||||
# Heartbeat Topics - to make sure that there is no deadlock
|
# Heartbeat Topics - to make sure that there is no deadlock
|
||||||
self.heartbeat_timeout = 30 # secs
|
self.heartbeat_timeout = 30 # secs
|
||||||
self.heartbeat_timer = self.create_timer(1.0, self.check_heartbeats)
|
self.heartbeat_timer = self.create_timer(1.0, self.check_heartbeats)
|
||||||
@ -57,6 +60,7 @@ class CMAESOptimizationNode(Node):
|
|||||||
self.parameter_srv = self.create_service(ParameterChange, 'cmaes_parameter_srv', self.parameter_callback)
|
self.parameter_srv = self.create_service(ParameterChange, 'cmaes_parameter_srv', self.parameter_callback)
|
||||||
self.query_srv = self.create_client(Query, 'query_srv')
|
self.query_srv = self.create_client(Query, 'query_srv')
|
||||||
self.task_srv = self.create_client(TaskEvaluation, 'task_srv')
|
self.task_srv = self.create_client(TaskEvaluation, 'task_srv')
|
||||||
|
self.user_interface_srv = self.create_client(UserInterface, 'user_interface_srv')
|
||||||
|
|
||||||
# Define states
|
# Define states
|
||||||
self.states = [
|
self.states = [
|
||||||
@ -126,11 +130,11 @@ class CMAESOptimizationNode(Node):
|
|||||||
request = Query.Request()
|
request = Query.Request()
|
||||||
|
|
||||||
if self.query_metric == 'random':
|
if self.query_metric == 'random':
|
||||||
pass
|
pass #TODO
|
||||||
elif self.query_metric == 'regular':
|
elif self.query_metric == 'regular':
|
||||||
pass
|
pass #TODO
|
||||||
elif self.query_metric == 'improvement':
|
elif self.query_metric == 'improvement':
|
||||||
pass
|
pass #TODO
|
||||||
|
|
||||||
self.query_srv.wait_for_service(10)
|
self.query_srv.wait_for_service(10)
|
||||||
self.__future = self.query_srv.call_async(request)
|
self.__future = self.query_srv.call_async(request)
|
||||||
@ -157,12 +161,39 @@ class CMAESOptimizationNode(Node):
|
|||||||
self.data_send_for_evaluation()
|
self.data_send_for_evaluation()
|
||||||
|
|
||||||
def on_enter_interactive_mode(self):
|
def on_enter_interactive_mode(self):
|
||||||
|
# Reset Mixed Reality heartbeat to check if the other node crashed
|
||||||
|
self.last_mr_heartbeat_time = None
|
||||||
|
|
||||||
request = UserInterface.Request()
|
request = UserInterface.Request()
|
||||||
request.current_cma_mean = flatten_population(self.cmaes.mean)
|
request.current_cma_mean = flatten_population(self.cmaes.mean)
|
||||||
request.best_parameters
|
request.best_parameters = flatten_population(self.best_parameters_per_interation[-1])
|
||||||
|
request.current_user_covariance_diag = self.user_covariance
|
||||||
|
|
||||||
|
self.__future = self.user_interface_srv.call_async(request)
|
||||||
|
self.__future.add_done_callback(self.handle_user_response)
|
||||||
|
|
||||||
|
self.data_send_to_user()
|
||||||
|
|
||||||
def on_enter_prepare_user_data_for_evaluation(self):
|
def on_enter_prepare_user_data_for_evaluation(self):
|
||||||
pass
|
# Reset Task heartbeat to check if the other node crashed
|
||||||
|
self.last_task_heartbeat_time = None
|
||||||
|
|
||||||
|
# Update the user_covariance
|
||||||
|
self.user_covariance = self.__user_response.user_covariance_diag
|
||||||
|
|
||||||
|
request = TaskEvaluation.Request()
|
||||||
|
request.user_input = True
|
||||||
|
request.number_of_population = self.cmaes.popsize
|
||||||
|
request.user_parameters = self.__user_response.user_parameters
|
||||||
|
request.user_covariance_diag = self.user_covariance
|
||||||
|
request.current_cma_mean = flatten_population(self.cmaes.mean)
|
||||||
|
request.conditional_points = self.__user_response.conditional_points
|
||||||
|
request.weight_parameter = self.__user_response.weight_parameter
|
||||||
|
|
||||||
|
self.__future = self.task_srv.call_async(request)
|
||||||
|
self.__future.add_done_callback(self.handle_task_response)
|
||||||
|
|
||||||
|
self.send_user_data_to_evaluation()
|
||||||
|
|
||||||
def on_enter_update_optimizer(self):
|
def on_enter_update_optimizer(self):
|
||||||
population = unflatten_population(self.__task_response.parameter_array,
|
population = unflatten_population(self.__task_response.parameter_array,
|
||||||
@ -298,4 +329,12 @@ class CMAESOptimizationNode(Node):
|
|||||||
self.get_logger().error(f'Task service call failed: {e}')
|
self.get_logger().error(f'Task service call failed: {e}')
|
||||||
self.error_trigger()
|
self.error_trigger()
|
||||||
|
|
||||||
|
def handle_user_response(self, future):
|
||||||
|
try:
|
||||||
|
self.__user_response = future.result()
|
||||||
|
self.user_response_received()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.get_logger().error(f'Task service call failed: {e}')
|
||||||
|
self.error_trigger()
|
||||||
# endregion
|
# endregion
|
||||||
|
Loading…
Reference in New Issue
Block a user