continue with optimizer node

This commit is contained in:
Niko Feith 2024-03-12 14:48:40 +01:00
parent 7c3ad608f7
commit aa04e0b833
6 changed files with 114 additions and 12 deletions

View File

@ -9,7 +9,4 @@ scikit-learn~=1.4.0
transitions==0.9.0 transitions==0.9.0
movement-primitives[all]~=0.7.0 movement-primitives[all]~=0.7.0
cma~=3.3.0 cma~=3.3.0
rclpy~=3.3.11 PyYAML~=5.4.1
PyYAML~=5.4.1
rpyutils~=0.2.1
turtlesim~=1.4.2

View File

@ -12,6 +12,8 @@ rosidl_generate_interfaces(${PROJECT_NAME}
"srv/Query.srv" "srv/Query.srv"
"srv/Task.srv" "srv/Task.srv"
"srv/TaskEvaluation.srv" "srv/TaskEvaluation.srv"
"srv/UserInterface.srv"
"srv/ParameterChange.srv"
"msg/OptimizerState.msg" "msg/OptimizerState.msg"
"msg/Opt2UI.msg" "msg/Opt2UI.msg"
"msg/UI2Opt.msg" "msg/UI2Opt.msg"

View File

@ -0,0 +1,6 @@
string[] parameter_name
string[] parameter_type # One of 'float', 'string', 'bool', potentially others
string[] parameter_value
---
bool success
string message # For error reporting

View File

@ -1,10 +1,10 @@
bool user_input bool user_input
# case if user_input is true # case if user_input is true
uint16 nr_of_population uint16 nr_of_population
float32[] user_mean # Length: number_of_dimensions * number_of_parameters_per_dimension float32[] user_parameters # Length: number_of_dimensions * number_of_parameters_per_dimension
float32[] user_covariance_diag # same as user_mean float32[] user_covariance_diag # Length: number_of_dimensions * number_of_parameters_per_dimension
float32[] current_cma_mean # Length: number_of_dimensions * number_of_parameters_per_dimension float32[] current_cma_mean # Length: number_of_dimensions * number_of_parameters_per_dimension
float32[] conditional_points # Length: (number_of_dimensions + time_stamp) * number_of_conditional_points float32[] conditional_points # Length: (number_of_dimensions + time_stamp[0,1]) * number_of_conditional_points
float32 weight_parameter # this parameter sets the weighted average 0 dont trust user 1 completly trust user (it is set by the user or it is decays over time i have to do some experiments on that) float32 weight_parameter # this parameter sets the weighted average 0 dont trust user 1 completly trust user (it is set by the user or it is decays over time i have to do some experiments on that)
# case if user_input is false # case if user_input is false
@ -12,7 +12,6 @@ uint16 number_of_population
uint16 number_of_dimensions # this is the number of ProMPs * 2 (Position and Velocity) uint16 number_of_dimensions # this is the number of ProMPs * 2 (Position and Velocity)
uint16 number_of_parameters_per_dimensions uint16 number_of_parameters_per_dimensions
float32[] parameter_array # Length: number_of_population * number_of_dimensions * number_of_parameters_per_dimension float32[] parameter_array # Length: number_of_population * number_of_dimensions * number_of_parameters_per_dimension
--- ---
# response # response
float32[] parameter_array # this is needed because in case of user input the parameters arent known yet float32[] parameter_array # this is needed because in case of user input the parameters arent known yet

View File

@ -0,0 +1,8 @@
float32[] current_cma_mean # Length: number_of_dimensions * number_of_parameters_per_dimension
float32[] best_parameters # Length: number_of_dimensions * number_of_parameters_per_dimension
float32[] current_user_covariance_diag # Length: number_of_dimensions * number_of_parameters_per_dimension
---
float32[] user_parameters # Length: number_of_dimensions * number_of_parameters_per_dimension
float32[] user_covariance_diag # Length: number_of_dimensions * number_of_parameters_per_dimension
float32[] conditional_points # Length: (number_of_dimensions + time_stamp[0,1]) * number_of_conditional_points
float32 weight_parameter # this parameter sets the weighted average 0 dont trust user 1 completly trust user (it is set by the user or it is decays over time i have to do some experiments on that)

View File

@ -2,6 +2,7 @@ import os
import rclpy import rclpy
from rclpy.node import Node from rclpy.node import Node
from rclpy.parameter import Parameter
from transitions import Machine from transitions import Machine
import cma import cma
@ -12,6 +13,8 @@ from src.interaction_utils.serialization import flatten_population, unflatten_po
# Msg/Srv/Action # Msg/Srv/Action
from interaction_msgs.srv import Query from interaction_msgs.srv import Query
from interaction_msgs.srv import TaskEvaluation from interaction_msgs.srv import TaskEvaluation
from interaction_msgs.srv import ParameterChange
from interaction_msgs.srv import UserInterface
from std_msgs.msg import Bool from std_msgs.msg import Bool
@ -21,6 +24,8 @@ class CMAESOptimizationNode(Node):
# CMA-ES Attributes # CMA-ES Attributes
self.cmaes = None self.cmaes = None
self.number_of_initial_episodes = self.declare_parameter('number_of_initial_episodes', 5)
self.episode = 0
self.number_of_dimensions = 3 self.number_of_dimensions = 3
self.number_of_parameters_per_dimensions = 10 self.number_of_parameters_per_dimensions = 10
# the number of weights is double the number of dims * params per dim since its Position and Velocity # the number of weights is double the number of dims * params per dim since its Position and Velocity
@ -31,7 +36,8 @@ class CMAESOptimizationNode(Node):
# Query Attributes # Query Attributes
self.query_metric = 'random' self.query_metric = 'random'
self.query_parameters = {} self.query_parameters = {}
self.best_rewards_per_iteration = [] self.best_parameters_per_interation = []
self.best_reward_per_iteration = []
# ROS2 Interfaces # ROS2 Interfaces
self.__future = None self.__future = None
self.__task_response = None self.__task_response = None
@ -48,6 +54,7 @@ class CMAESOptimizationNode(Node):
# Topics # Topics
# Service # Service
self.parameter_srv = self.create_service(ParameterChange, 'cmaes_parameter_srv', self.parameter_callback)
self.query_srv = self.create_client(Query, 'query_srv') self.query_srv = self.create_client(Query, 'query_srv')
self.task_srv = self.create_client(TaskEvaluation, 'task_srv') self.task_srv = self.create_client(TaskEvaluation, 'task_srv')
@ -125,6 +132,7 @@ class CMAESOptimizationNode(Node):
elif self.query_metric == 'improvement': elif self.query_metric == 'improvement':
pass pass
self.query_srv.wait_for_service(10)
self.__future = self.query_srv.call_async(request) self.__future = self.query_srv.call_async(request)
self.__future.add_done_callback(self.handle_query_response) self.__future.add_done_callback(self.handle_query_response)
@ -149,15 +157,41 @@ class CMAESOptimizationNode(Node):
self.data_send_for_evaluation() self.data_send_for_evaluation()
def on_enter_interactive_mode(self): def on_enter_interactive_mode(self):
pass request = UserInterface.Request()
request.current_cma_mean = flatten_population(self.cmaes.mean)
request.best_parameters
def on_enter_prepare_user_data_for_evaluation(self): def on_enter_prepare_user_data_for_evaluation(self):
pass pass
def on_enter_update_optimizer(self): def on_enter_update_optimizer(self):
pass population = unflatten_population(self.__task_response.parameter_array,
self.cmaes.popsize,
self.number_of_dimensions,
self.number_of_parameters_per_dimensions)
scores = self.__task_response.scores
self.cmaes.tell(population, scores)
# save best results
best_idx = max(enumerate(scores), key=lambda item: item[1])[0]
self.best_reward_per_iteration.append(scores[best_idx].copy())
self.best_parameters_per_interation.append(population[best_idx, :, :].copy())
self.__task_response = None
self.episode += 1
self.optimizer_updated()
def on_enter_check_termination(self): def on_enter_check_termination(self):
max_episodes = self.get_parameter('max_episode_count').get_parameter_value().integer_value
if self.episode >= max_episodes:
self.finished()
else:
self.next_optimizer_step()
def on_enter_complete(self):
pass pass
def on_enter_error_recovery(self): def on_enter_error_recovery(self):
@ -166,6 +200,59 @@ class CMAESOptimizationNode(Node):
# endregion # endregion
# region Callback Functions # region Callback Functions
def parameter_callback(self, request, response):
param_names = request.parameter_name
param_types = request.parameter_type
param_values_str = request.parameter_value
if len(param_names) != len(param_types) or len(param_types) != len(param_values_str):
response.success = False
response.message = 'Lists must have the same size'
return response
try:
# Parameter update loop
all_params = []
for i in range(len(param_names)):
param_name = param_names[i]
param_type = param_types[i]
param_value_str = param_values_str[i]
# Input Validation (adjust as needed)
if param_type not in ['float32', 'string', 'bool']:
response.success = False
response.message = 'Unsupported parameter type'
return response
# Value Conversion (based on param_type)
if param_type == 'float32':
param_value = float(param_value_str)
elif param_type == 'bool':
param_value = (param_value_str.lower() == 'true')
else: # 'string'
param_value = param_value_str
# Compose the all parameter list
param = Parameter(param_name, Parameter.Type.from_parameter_value(param_value), param_value)
all_params.append(param)
# Attempt to set the parameter
set_params_result = self.set_parameters(all_params)
response.success = True # Assume success unless set_parameters fails
for result in set_params_result:
if not result.successful:
response.success = False
response.message = result.reason
break
return response
except Exception as e:
self.get_logger().error(f'Parameter update failed: {str(e)}')
response.success = False
response.message = 'Parameter update failed, please check logs'
return response
def check_heartbeats(self): def check_heartbeats(self):
if self.last_mr_heartbeat_time: if self.last_mr_heartbeat_time:
current_time = self.get_clock().now().nanoseconds current_time = self.get_clock().now().nanoseconds
@ -190,6 +277,9 @@ class CMAESOptimizationNode(Node):
try: try:
response = future.result() response = future.result()
if self.episode < self.number_of_initial_episodes:
self.non_interaction()
if response.interaction: if response.interaction:
self.interaction() self.interaction()
else: else:
@ -208,4 +298,4 @@ class CMAESOptimizationNode(Node):
self.get_logger().error(f'Task service call failed: {e}') self.get_logger().error(f'Task service call failed: {e}')
self.error_trigger() self.error_trigger()
# endregion # endregion