diff --git a/Dockerfile b/Dockerfile index 3b29d01..f153b9b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,12 +3,13 @@ FROM osrf/ros:humble-desktop-full-jammy # Update and install dependencies RUN apt-get update && apt-get install -y \ - python3-colcon-common-extensions python3-pip \ - && rm -rf /var/lib/apt/lists/* + python3-colcon-common-extensions python3-pip COPY requirements.txt ./ RUN pip install --no-cache-dir -r requirements.txt +RUN rm -rf /var/lib/apt/lists/* + # Create a workspace WORKDIR /ros2_ws @@ -21,6 +22,7 @@ RUN . /opt/ros/humble/setup.sh && \ RUN echo "source /opt/ros/humble/setup.bash" >> ~/.bashrc +ENV PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/ros/humble/bin" # Source the workspace CMD ["/bin/bash"] diff --git a/README.md b/README.md index 5906c34..2d04b65 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,25 @@ docker build -t interactive-robot-learning-framework . docker-compose up -d ``` +## Setting up Pycharm Environment +Add a new Remote Interpreter based on the docker-compose file. +Then add the following paths to the interpreter paths. You can find them when clicking on the Interpreter in the lower right corner of the Pycharm window. +Then select "Interpreter Settings" within the setting go to "Show All..." Python interpreters (its in the list of the Python Interpreters). +Subsequently, you can add with the "+" a new interpreter, choose "Docker Compose" and select the "docker-compose.yaml" file from this repository. +The last step is to add the ros2 paths within your docker container to the interpreter paths (its one of the buttons next to the "+". + +The following paths are necessary to develop in the docker container otherwise python cannot find rclpy and the custom msg packages: + +```bash +/opt/ros/humble/local/lib/python3.10/dist-packages + +/opt/ros/humble/lib/python3.10/site-packages + +/ros2_ws/install/interaction_msgs/local/lib/python3.10/dist-packages +``` +If there is a + + ## Framework Structure The Interactive Robot Learning Framework consists of several key ROS2 packages, each responsible for different aspects of robot learning and interaction. Below is an overview of each package: diff --git a/requirements.txt b/requirements.txt index a898367..e00b53d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,9 @@ lark~=1.1.1 scipy~=1.12.0 scikit-learn~=1.4.0 transitions==0.9.0 -movement-primitives[all]~=0.7.0 \ No newline at end of file +movement-primitives[all]~=0.7.0 +cma~=3.3.0 +rclpy~=3.3.11 +PyYAML~=5.4.1 +rpyutils~=0.2.1 +turtlesim~=1.4.2 \ No newline at end of file diff --git a/src/interaction_msgs/CMakeLists.txt b/src/interaction_msgs/CMakeLists.txt index 5a3f60c..9ff2a86 100644 --- a/src/interaction_msgs/CMakeLists.txt +++ b/src/interaction_msgs/CMakeLists.txt @@ -11,6 +11,7 @@ find_package(rosidl_default_generators REQUIRED) rosidl_generate_interfaces(${PROJECT_NAME} "srv/Query.srv" "srv/Task.srv" + "srv/TaskEvaluation.srv" "msg/OptimizerState.msg" "msg/Opt2UI.msg" "msg/UI2Opt.msg" diff --git a/src/interaction_msgs/msg/Order.msg b/src/interaction_msgs/msg/Order.msg new file mode 100644 index 0000000..e69de29 diff --git a/src/interaction_msgs/srv/TaskEvaluation.srv b/src/interaction_msgs/srv/TaskEvaluation.srv new file mode 100644 index 0000000..8187ce5 --- /dev/null +++ b/src/interaction_msgs/srv/TaskEvaluation.srv @@ -0,0 +1,19 @@ +bool user_input +# case if user_input is true +uint16 nr_of_population +float32[] user_mean # Length: number_of_dimensions * number_of_parameters_per_dimension +float32[] user_covariance_diag # same as user_mean +float32[] current_cma_mean # Length: number_of_dimensions * number_of_parameters_per_dimension +float32[] conditional_points # Length: (number_of_dimensions + time_stamp) * number_of_conditional_points +float32 weight_parameter # this parameter sets the weighted average 0 dont trust user 1 completly trust user (it is set by the user or it is decays over time i have to do some experiments on that) + +# case if user_input is false +uint16 number_of_population +uint16 number_of_dimensions # this is the number of ProMPs * 2 (Position and Velocity) +uint16 number_of_parameters_per_dimensions +float32[] parameter_array # Length: number_of_population * number_of_dimensions * number_of_parameters_per_dimension + +--- +# response +float32[] parameter_array # this is needed because in case of user input the parameters arent known yet +float32[] score \ No newline at end of file diff --git a/src/interaction_optimizers/config/cma_es_init_parameters.yaml b/src/interaction_optimizers/config/cma_es_init_parameters.yaml new file mode 100644 index 0000000..ef95f1b --- /dev/null +++ b/src/interaction_optimizers/config/cma_es_init_parameters.yaml @@ -0,0 +1,4 @@ +initial_mean_centre: 0.0 +initial_mean_std_dev: 0.2 +initial_variance: 0.3 +random_seed: '' diff --git a/src/interaction_optimizers/interaction_optimizers/cmaes_optimization_node.py b/src/interaction_optimizers/interaction_optimizers/cmaes_optimization_node.py new file mode 100644 index 0000000..df24703 --- /dev/null +++ b/src/interaction_optimizers/interaction_optimizers/cmaes_optimization_node.py @@ -0,0 +1,193 @@ +import os + +import rclpy +from rclpy.node import Node + +from transitions import Machine +import cma +import yaml +import numpy as np + +# Msg/Srv/Action +from interaction_msgs.srv import Query +from interaction_msgs.srv import TaskEvaluation +from std_msgs.msg import Bool + + +class CMAESOptimizationNode(Node): + def __init__(self): + super().__init__('cmaes_optimization_node') + + # CMA-ES Attributes + self.cmaes = None + self.number_of_dimensions = 3 + self.number_of_parameters_per_dimensions = 10 + # the number of weights is double the number of dims * params per dim since its Position and Velocity + self.number_of_weights = 2 * self.number_of_dimensions * self.number_of_parameters_per_dimensions + + self.random_seed = None + + # Query Attributes + self.query_metric = 'random' + self.query_parameters = {} + self.best_rewards_per_iteration = [] + # ROS2 Interfaces + self.__future = None + # Heartbeat Topics - to make sure that there is no deadlock + self.heartbeat_timeout = 30 # secs + self.heartbeat_timer = self.create_timer(1.0, self.check_heartbeats) + + # MR Heartbeat + self.last_mr_heartbeat_time = None + self.mr_heartbeat_sub = self.create_subscription(Bool, 'interaction/mr_heartbeat', self.mr_heatbeat_callback) + self.last_task_heartbeat_time = None + self.task_heartbeat_sub = self.create_subscription(Bool, 'interaction/task_heartbeat', self.task_heartbeat_callback) + + # Topics + + # Service + self.query_srv = self.create_client(Query, 'query_srv') + + # Define states + self.states = [ + 'idle', + 'initialization', + 'query_decision', + 'non_interactive_mode', + 'wait_for_evaluation', + 'interactive_mode', + 'wait_for_user_response', + 'prepare_user_data_for_evaluation', + 'wait_for_user_evaluation', + 'update_optimizer', + 'check_termination', + 'complete', + 'error_recovery', + ] + + # Initialize state machine + self.machine = Machine(self, states=self.states, initial='idle') + + # region Transitions + self.machine.add_transition(trigger='order_received', source='idle', dest='initialization') + self.machine.add_transition(trigger='initialization_complete', source='initialization', dest='query_decision') + self.machine.add_transition(trigger='no_interaction', source='query_decision', dest='non_interactive_mode') + self.machine.add_transition(trigger='data_send_for_evaluation', source='non_interactive_mode', dest='wait_for_evaluation') + self.machine.add_transition(trigger='evaluation_response_received', source='wait_for_evaluation', dest='update_optimizer') + self.machine.add_transition(trigger='interaction', source='query_decision', dest='interactive_mode') + self.machine.add_transition(trigger='data_send_to_user', source='interactive_mode', dest='wait_for_user_response') + self.machine.add_transition(trigger='user_response_received', source='wait_for_user_response', dest='prepare_user_data_for_evaluation') + self.machine.add_transition(trigger='send_user_data_to_evaluation', source='prepare_user_data_for_evaluation', dest='wait_for_user_evaluation') + self.machine.add_transition(trigger='received_user_data_for_evaluation', source='wait_for_user_evaluation', dest='update_optimizer') + self.machine.add_transition(trigger='optimizer_updated', source='update_optimizer', dest='check_termination') + self.machine.add_transition(trigger='next_optimizer_step', source='check_termination', dest='query_decision') + self.machine.add_transition(trigger='finished', source='check_termination', dest='complete') + self.machine.add_transition(trigger='error_trigger', source='*', dest='error_recovery') + self.machine.add_transition(trigger='recovery_complete', source='error_recovery', dest='idle') + # endregion + + # region State Functions + def on_enter_initialization(self): + config_file_path = self.get_parameter('cmaes_config_file_path').get_parameter_value().string_value + + # Load YAML + with open(config_file_path, 'r') as file: + config = yaml.safe_load(file) + + if config['random_seed'] == '': + self.random_seed = None + else: + self.random_seed = int(config['random_seed']) + config['seed'] = self.random_seed + + mean_centre = config['initial_mean_centre'] + mean_std_dev = config['initial_mean_std_dev'] + + random_gen = np.random.default_rng(seed=self.random_seed) + + initial_mean = random_gen.normal(mean_centre, mean_std_dev, size=(self.number_of_weights, 1)) + initial_variance = config['initial_variance'] + self.cmaes = cma.CMAEvolutionStrategy(initial_mean, initial_variance, inopts=config) + + # Trigger transition + self.initialization_complete() + + def on_enter_query_decision(self): + request = Query.Request() + + if self.query_metric == 'random': + pass + elif self.query_metric == 'regular': + pass + elif self.query_metric == 'improvement': + pass + + self.__future = self.query_srv.call_async(request) + self.__future.add_done_callback(self.handle_query_response) + + def on_enter_non_interactive_mode(self): + # Reset Task heartbeat to check if the other node crashed + self.last_task_heartbeat_time = None + + request = TaskEvaluation.Request() + request.user_input = False + request.number_of_population = self.cmaes.popsize + request.number_of_dimensions = self.number_of_dimensions * 2 + request.number_of_parameters_per_dimensions = self.number_of_parameters_per_dimensions + + population = self.cmaes.ask() + + + + def on_enter_interactive_mode(self): + pass + + def on_enter_prepare_user_data_for_evaluation(self): + pass + + def on_enter_update_optimizer(self): + pass + + def on_enter_check_termination(self): + pass + + def on_enter_error_recovery(self): + pass + + # endregion + + # region Callback Functions + def check_heartbeats(self): + if self.last_mr_heartbeat_time: + current_time = self.get_clock().now().nanoseconds + if (current_time - self.last_mr_heartbeat_time) > (self.heartbeat_timeout * 1e9): + self.get_logger().error("MR Interface heartbeat timed out!") + self.error_trigger() + else: + pass + + if self.last_mr_heartbeat_time: + current_time = self.get_clock().now().nanoseconds + if (current_time - self.last_task_heartbeat_time) > (self.heartbeat_timeout * 1e9): + self.get_logger().error("Task Node heartbeat timed out!") + self.error_trigger() + else: + pass + + def mr_heartbeat_callback(self, _): + self.last_mr_heartbeat_time = self.get_clock().now().nanoseconds + + def handle_query_response(self, future): + try: + response = future.result() + + if response.interaction: + self.interaction() + else: + self.non_interaction() + + except Exception as e: + self.get_logger().error(f'Query service call failed: {str(e)}') + self.error_trigger() + + # endregion \ No newline at end of file diff --git a/src/interaction_optimizers/setup.py b/src/interaction_optimizers/setup.py index 243c133..5344ab5 100644 --- a/src/interaction_optimizers/setup.py +++ b/src/interaction_optimizers/setup.py @@ -1,4 +1,6 @@ from setuptools import find_packages, setup +import os +from glob import glob package_name = 'interaction_optimizers' @@ -10,6 +12,7 @@ setup( ('share/ament_index/resource_index/packages', ['resource/' + package_name]), ('share/' + package_name, ['package.xml']), + (os.path.join('share', package_name, 'config'), glob('config/*.yaml')), ], install_requires=['setuptools'], zip_safe=True,