continue with optimizer node
This commit is contained in:
parent
7c3ad608f7
commit
aa04e0b833
@ -9,7 +9,4 @@ scikit-learn~=1.4.0
|
|||||||
transitions==0.9.0
|
transitions==0.9.0
|
||||||
movement-primitives[all]~=0.7.0
|
movement-primitives[all]~=0.7.0
|
||||||
cma~=3.3.0
|
cma~=3.3.0
|
||||||
rclpy~=3.3.11
|
|
||||||
PyYAML~=5.4.1
|
PyYAML~=5.4.1
|
||||||
rpyutils~=0.2.1
|
|
||||||
turtlesim~=1.4.2
|
|
@ -12,6 +12,8 @@ rosidl_generate_interfaces(${PROJECT_NAME}
|
|||||||
"srv/Query.srv"
|
"srv/Query.srv"
|
||||||
"srv/Task.srv"
|
"srv/Task.srv"
|
||||||
"srv/TaskEvaluation.srv"
|
"srv/TaskEvaluation.srv"
|
||||||
|
"srv/UserInterface.srv"
|
||||||
|
"srv/ParameterChange.srv"
|
||||||
"msg/OptimizerState.msg"
|
"msg/OptimizerState.msg"
|
||||||
"msg/Opt2UI.msg"
|
"msg/Opt2UI.msg"
|
||||||
"msg/UI2Opt.msg"
|
"msg/UI2Opt.msg"
|
||||||
|
6
src/interaction_msgs/srv/ParameterChange.srv
Normal file
6
src/interaction_msgs/srv/ParameterChange.srv
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
string[] parameter_name
|
||||||
|
string[] parameter_type # One of 'float', 'string', 'bool', potentially others
|
||||||
|
string[] parameter_value
|
||||||
|
---
|
||||||
|
bool success
|
||||||
|
string message # For error reporting
|
@ -1,10 +1,10 @@
|
|||||||
bool user_input
|
bool user_input
|
||||||
# case if user_input is true
|
# case if user_input is true
|
||||||
uint16 nr_of_population
|
uint16 nr_of_population
|
||||||
float32[] user_mean # Length: number_of_dimensions * number_of_parameters_per_dimension
|
float32[] user_parameters # Length: number_of_dimensions * number_of_parameters_per_dimension
|
||||||
float32[] user_covariance_diag # same as user_mean
|
float32[] user_covariance_diag # Length: number_of_dimensions * number_of_parameters_per_dimension
|
||||||
float32[] current_cma_mean # Length: number_of_dimensions * number_of_parameters_per_dimension
|
float32[] current_cma_mean # Length: number_of_dimensions * number_of_parameters_per_dimension
|
||||||
float32[] conditional_points # Length: (number_of_dimensions + time_stamp) * number_of_conditional_points
|
float32[] conditional_points # Length: (number_of_dimensions + time_stamp[0,1]) * number_of_conditional_points
|
||||||
float32 weight_parameter # Weight of the weighted average: 0 = don't trust the user, 1 = completely trust the user (either set by the user or decaying over time; still to be determined experimentally)
|
float32 weight_parameter # Weight of the weighted average: 0 = don't trust the user, 1 = completely trust the user (either set by the user or decaying over time; still to be determined experimentally)
|
||||||
|
|
||||||
# case if user_input is false
|
# case if user_input is false
|
||||||
@ -12,7 +12,6 @@ uint16 number_of_population
|
|||||||
uint16 number_of_dimensions # this is the number of ProMPs * 2 (Position and Velocity)
|
uint16 number_of_dimensions # this is the number of ProMPs * 2 (Position and Velocity)
|
||||||
uint16 number_of_parameters_per_dimensions
|
uint16 number_of_parameters_per_dimensions
|
||||||
float32[] parameter_array # Length: number_of_population * number_of_dimensions * number_of_parameters_per_dimension
|
float32[] parameter_array # Length: number_of_population * number_of_dimensions * number_of_parameters_per_dimension
|
||||||
|
|
||||||
---
|
---
|
||||||
# response
|
# response
|
||||||
float32[] parameter_array # this is needed because in case of user input the parameters aren't known yet
|
float32[] parameter_array # this is needed because in case of user input the parameters aren't known yet
|
||||||
|
8
src/interaction_msgs/srv/UserInterface.srv
Normal file
8
src/interaction_msgs/srv/UserInterface.srv
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
float32[] current_cma_mean # Length: number_of_dimensions * number_of_parameters_per_dimension
|
||||||
|
float32[] best_parameters # Length: number_of_dimensions * number_of_parameters_per_dimension
|
||||||
|
float32[] current_user_covariance_diag # Length: number_of_dimensions * number_of_parameters_per_dimension
|
||||||
|
---
|
||||||
|
float32[] user_parameters # Length: number_of_dimensions * number_of_parameters_per_dimension
|
||||||
|
float32[] user_covariance_diag # Length: number_of_dimensions * number_of_parameters_per_dimension
|
||||||
|
float32[] conditional_points # Length: (number_of_dimensions + time_stamp[0,1]) * number_of_conditional_points
|
||||||
|
float32 weight_parameter # Weight of the weighted average: 0 = don't trust the user, 1 = completely trust the user (either set by the user or decaying over time; still to be determined experimentally)
|
@ -2,6 +2,7 @@ import os
|
|||||||
|
|
||||||
import rclpy
|
import rclpy
|
||||||
from rclpy.node import Node
|
from rclpy.node import Node
|
||||||
|
from rclpy.parameter import Parameter
|
||||||
|
|
||||||
from transitions import Machine
|
from transitions import Machine
|
||||||
import cma
|
import cma
|
||||||
@ -12,6 +13,8 @@ from src.interaction_utils.serialization import flatten_population, unflatten_po
|
|||||||
# Msg/Srv/Action
|
# Msg/Srv/Action
|
||||||
from interaction_msgs.srv import Query
|
from interaction_msgs.srv import Query
|
||||||
from interaction_msgs.srv import TaskEvaluation
|
from interaction_msgs.srv import TaskEvaluation
|
||||||
|
from interaction_msgs.srv import ParameterChange
|
||||||
|
from interaction_msgs.srv import UserInterface
|
||||||
from std_msgs.msg import Bool
|
from std_msgs.msg import Bool
|
||||||
|
|
||||||
|
|
||||||
@ -21,6 +24,8 @@ class CMAESOptimizationNode(Node):
|
|||||||
|
|
||||||
# CMA-ES Attributes
|
# CMA-ES Attributes
|
||||||
self.cmaes = None
|
self.cmaes = None
|
||||||
|
self.number_of_initial_episodes = self.declare_parameter('number_of_initial_episodes', 5)
|
||||||
|
self.episode = 0
|
||||||
self.number_of_dimensions = 3
|
self.number_of_dimensions = 3
|
||||||
self.number_of_parameters_per_dimensions = 10
|
self.number_of_parameters_per_dimensions = 10
|
||||||
# the number of weights is double the number of dims * params per dim since its Position and Velocity
|
# the number of weights is double the number of dims * params per dim since its Position and Velocity
|
||||||
@ -31,7 +36,8 @@ class CMAESOptimizationNode(Node):
|
|||||||
# Query Attributes
|
# Query Attributes
|
||||||
self.query_metric = 'random'
|
self.query_metric = 'random'
|
||||||
self.query_parameters = {}
|
self.query_parameters = {}
|
||||||
self.best_rewards_per_iteration = []
|
self.best_parameters_per_interation = []
|
||||||
|
self.best_reward_per_iteration = []
|
||||||
# ROS2 Interfaces
|
# ROS2 Interfaces
|
||||||
self.__future = None
|
self.__future = None
|
||||||
self.__task_response = None
|
self.__task_response = None
|
||||||
@ -48,6 +54,7 @@ class CMAESOptimizationNode(Node):
|
|||||||
# Topics
|
# Topics
|
||||||
|
|
||||||
# Service
|
# Service
|
||||||
|
self.parameter_srv = self.create_service(ParameterChange, 'cmaes_parameter_srv', self.parameter_callback)
|
||||||
self.query_srv = self.create_client(Query, 'query_srv')
|
self.query_srv = self.create_client(Query, 'query_srv')
|
||||||
self.task_srv = self.create_client(TaskEvaluation, 'task_srv')
|
self.task_srv = self.create_client(TaskEvaluation, 'task_srv')
|
||||||
|
|
||||||
@ -125,6 +132,7 @@ class CMAESOptimizationNode(Node):
|
|||||||
elif self.query_metric == 'improvement':
|
elif self.query_metric == 'improvement':
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
self.query_srv.wait_for_service(10)
|
||||||
self.__future = self.query_srv.call_async(request)
|
self.__future = self.query_srv.call_async(request)
|
||||||
self.__future.add_done_callback(self.handle_query_response)
|
self.__future.add_done_callback(self.handle_query_response)
|
||||||
|
|
||||||
@ -149,15 +157,41 @@ class CMAESOptimizationNode(Node):
|
|||||||
self.data_send_for_evaluation()
|
self.data_send_for_evaluation()
|
||||||
|
|
||||||
def on_enter_interactive_mode(self):
|
def on_enter_interactive_mode(self):
|
||||||
pass
|
request = UserInterface.Request()
|
||||||
|
request.current_cma_mean = flatten_population(self.cmaes.mean)
|
||||||
|
request.best_parameters
|
||||||
|
|
||||||
def on_enter_prepare_user_data_for_evaluation(self):
|
def on_enter_prepare_user_data_for_evaluation(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_enter_update_optimizer(self):
    """Feed the evaluated population back into CMA-ES and record the best sample.

    Reads the pending task response, rebuilds the population from its flat
    ``parameter_array``, passes population and scores to ``self.cmaes.tell``,
    stores the best score and its parameters, clears the pending response,
    advances the episode counter and fires the ``optimizer_updated`` trigger.
    """
    population = unflatten_population(
        self.__task_response.parameter_array,
        self.cmaes.popsize,
        self.number_of_dimensions,
        self.number_of_parameters_per_dimensions,
    )
    scores = self.__task_response.scores

    # NOTE(review): pycma's ``tell`` minimizes the supplied values, while the
    # best sample below is chosen with ``max`` (i.e. scores are treated as
    # rewards) -- confirm the sign convention used by the task node.
    self.cmaes.tell(population, scores)

    # Save best results. The best score is stored as a plain float: elements
    # of a ROS float sequence are Python floats and have no ``.copy()``, so
    # the previous ``scores[best_idx].copy()`` raised AttributeError.
    best_idx = max(range(len(scores)), key=scores.__getitem__)
    self.best_reward_per_iteration.append(float(scores[best_idx]))
    self.best_parameters_per_interation.append(population[best_idx, :, :].copy())

    self.__task_response = None
    self.episode += 1

    self.optimizer_updated()
|
||||||
|
|
||||||
def on_enter_check_termination(self):
    """Decide whether the optimization loop stops or continues.

    Looks up the integer node parameter ``max_episode_count`` and fires the
    ``finished`` trigger once ``self.episode`` has reached it; otherwise fires
    ``next_optimizer_step``.
    """
    episode_limit = (
        self.get_parameter('max_episode_count')
        .get_parameter_value()
        .integer_value
    )

    # Pick the state-machine trigger first, then fire it exactly once.
    trigger = self.finished if self.episode >= episode_limit else self.next_optimizer_step
    trigger()
|
||||||
|
|
||||||
|
def on_enter_complete(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_enter_error_recovery(self):
|
def on_enter_error_recovery(self):
|
||||||
@ -166,6 +200,59 @@ class CMAESOptimizationNode(Node):
|
|||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region Callback Functions
|
# region Callback Functions
|
||||||
|
def parameter_callback(self, request, response):
    """Service callback: update node parameters from a ``ParameterChange`` request.

    The request carries three parallel string lists (``parameter_name``,
    ``parameter_type``, ``parameter_value``). Each value string is converted
    according to its declared type and all parameters are applied with a
    single ``self.set_parameters`` call.

    :param request: ``ParameterChange`` request with the three parallel lists.
    :param response: ``ParameterChange`` response to fill in.
    :return: the same ``response`` object; ``success`` is False (with
        ``message`` set) when the lists differ in length, a type is not
        supported, a parameter is rejected, or any exception occurs.
    """
    names = request.parameter_name
    type_names = request.parameter_type
    raw_values = request.parameter_value

    if not (len(names) == len(type_names) == len(raw_values)):
        response.success = False
        response.message = 'Lists must have the same size'
        return response

    try:
        all_params = []
        for name, type_name, raw_value in zip(names, type_names, raw_values):
            # Value conversion based on the declared type. The service
            # definition documents 'float', so it is accepted alongside the
            # previously required 'float32'.
            if type_name in ('float', 'float32'):
                value = float(raw_value)
            elif type_name == 'bool':
                value = (raw_value.lower() == 'true')
            elif type_name == 'string':
                value = raw_value
            else:
                response.success = False
                response.message = 'Unsupported parameter type'
                return response

            all_params.append(
                Parameter(name, Parameter.Type.from_parameter_value(value), value)
            )

        # Attempt to set all parameters at once.
        results = self.set_parameters(all_params)

        response.success = True  # Assume success unless a parameter is rejected
        for result in results:
            if not result.successful:
                response.success = False
                response.message = result.reason
                break
        return response

    except Exception as e:
        # A service callback must always hand back a response, so any failure
        # (e.g. an unparsable float) is logged and reported to the caller.
        self.get_logger().error(f'Parameter update failed: {str(e)}')
        response.success = False
        response.message = 'Parameter update failed, please check logs'
        return response
|
||||||
|
|
||||||
def check_heartbeats(self):
|
def check_heartbeats(self):
|
||||||
if self.last_mr_heartbeat_time:
|
if self.last_mr_heartbeat_time:
|
||||||
current_time = self.get_clock().now().nanoseconds
|
current_time = self.get_clock().now().nanoseconds
|
||||||
@ -190,6 +277,9 @@ class CMAESOptimizationNode(Node):
|
|||||||
try:
|
try:
|
||||||
response = future.result()
|
response = future.result()
|
||||||
|
|
||||||
|
if self.episode < self.number_of_initial_episodes:
|
||||||
|
self.non_interaction()
|
||||||
|
|
||||||
if response.interaction:
|
if response.interaction:
|
||||||
self.interaction()
|
self.interaction()
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user