cma-es node in dev

This commit is contained in:
Niko Feith 2024-02-29 14:37:36 +01:00
parent acf51f01c6
commit a23e5be64c
9 changed files with 249 additions and 3 deletions

View File

@ -3,12 +3,13 @@ FROM osrf/ros:humble-desktop-full-jammy
# Update and install dependencies # Update and install dependencies
RUN apt-get update && apt-get install -y \ RUN apt-get update && apt-get install -y \
python3-colcon-common-extensions python3-pip \ python3-colcon-common-extensions python3-pip
&& rm -rf /var/lib/apt/lists/*
COPY requirements.txt ./ COPY requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt RUN pip install --no-cache-dir -r requirements.txt
RUN rm -rf /var/lib/apt/lists/*
# Create a workspace # Create a workspace
WORKDIR /ros2_ws WORKDIR /ros2_ws
@ -21,6 +22,7 @@ RUN . /opt/ros/humble/setup.sh && \
RUN echo "source /opt/ros/humble/setup.bash" >> ~/.bashrc RUN echo "source /opt/ros/humble/setup.bash" >> ~/.bashrc
ENV PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/ros/humble/bin"
# Source the workspace # Source the workspace
CMD ["/bin/bash"] CMD ["/bin/bash"]

View File

@ -27,6 +27,25 @@ docker build -t interactive-robot-learning-framework .
docker-compose up -d docker-compose up -d
``` ```
## Setting up PyCharm Environment
Add a new Remote Interpreter based on the docker-compose file.
Then add the following paths to the interpreter paths. You can find them by clicking on the interpreter in the lower right corner of the PyCharm window.
Then select "Interpreter Settings" and, within the settings, go to "Show All..." Python interpreters (it's in the list of Python interpreters).
Subsequently, you can add a new interpreter with the "+" button, choose "Docker Compose", and select the "docker-compose.yaml" file from this repository.
The last step is to add the ROS 2 paths within your Docker container to the interpreter paths (it's one of the buttons next to the "+").
The following paths are necessary to develop in the Docker container; otherwise Python cannot find rclpy and the custom msg packages:
```bash
/opt/ros/humble/local/lib/python3.10/dist-packages
/opt/ros/humble/lib/python3.10/site-packages
/ros2_ws/install/interaction_msgs/local/lib/python3.10/dist-packages
```
If there is a
## Framework Structure ## Framework Structure
The Interactive Robot Learning Framework consists of several key ROS2 packages, each responsible for different aspects of robot learning and interaction. Below is an overview of each package: The Interactive Robot Learning Framework consists of several key ROS2 packages, each responsible for different aspects of robot learning and interaction. Below is an overview of each package:

View File

@ -7,4 +7,9 @@ lark~=1.1.1
scipy~=1.12.0 scipy~=1.12.0
scikit-learn~=1.4.0 scikit-learn~=1.4.0
transitions==0.9.0 transitions==0.9.0
movement-primitives[all]~=0.7.0 movement-primitives[all]~=0.7.0
cma~=3.3.0
rclpy~=3.3.11
PyYAML~=5.4.1
rpyutils~=0.2.1
turtlesim~=1.4.2

View File

@ -11,6 +11,7 @@ find_package(rosidl_default_generators REQUIRED)
rosidl_generate_interfaces(${PROJECT_NAME} rosidl_generate_interfaces(${PROJECT_NAME}
"srv/Query.srv" "srv/Query.srv"
"srv/Task.srv" "srv/Task.srv"
"srv/TaskEvaluation.srv"
"msg/OptimizerState.msg" "msg/OptimizerState.msg"
"msg/Opt2UI.msg" "msg/Opt2UI.msg"
"msg/UI2Opt.msg" "msg/UI2Opt.msg"

View File

View File

@ -0,0 +1,19 @@
bool user_input
# case if user_input is true
uint16 nr_of_population
float32[] user_mean # Length: number_of_dimensions * number_of_parameters_per_dimension
float32[] user_covariance_diag # same as user_mean
float32[] current_cma_mean # Length: number_of_dimensions * number_of_parameters_per_dimension
float32[] conditional_points # Length: (number_of_dimensions + time_stamp) * number_of_conditional_points
float32 weight_parameter # weight of the weighted average: 0 = don't trust the user, 1 = completely trust the user (either set by the user or decayed over time; still to be determined experimentally)
# case if user_input is false
uint16 number_of_population
uint16 number_of_dimensions # this is the number of ProMPs * 2 (Position and Velocity)
uint16 number_of_parameters_per_dimensions
float32[] parameter_array # Length: number_of_population * number_of_dimensions * number_of_parameters_per_dimension
---
# response
float32[] parameter_array # this is needed because in case of user input the parameters aren't known yet
float32[] score

View File

@ -0,0 +1,4 @@
initial_mean_centre: 0.0
initial_mean_std_dev: 0.2
initial_variance: 0.3
random_seed: ''

View File

@ -0,0 +1,193 @@
import os
import rclpy
from rclpy.node import Node
from transitions import Machine
import cma
import yaml
import numpy as np
# Msg/Srv/Action
from interaction_msgs.srv import Query
from interaction_msgs.srv import TaskEvaluation
from std_msgs.msg import Bool
class CMAESOptimizationNode(Node):
    """ROS2 node that runs a CMA-ES optimisation loop driven by a state machine.

    The node owns a ``cma.CMAEvolutionStrategy`` instance, asks the
    ``query_srv`` service whether the next step should involve the user, and
    watches two heartbeat topics (MR interface and task node) so that a
    silent peer moves the state machine to ``error_recovery``.

    NOTE(review): the interactive mode, optimizer update, termination check
    and error recovery states are still placeholders, and the non-interactive
    state does not send its request yet.
    """

    def __init__(self):
        super().__init__('cmaes_optimization_node')

        # Parameters
        # Must be declared before on_enter_initialization can read it;
        # get_parameter on an undeclared parameter raises.
        self.declare_parameter('cmaes_config_file_path', '')

        # CMA-ES Attributes
        self.cmaes = None
        self.number_of_dimensions = 3
        self.number_of_parameters_per_dimensions = 10
        # the number of weights is double the number of dims * params per dim
        # since its Position and Velocity
        self.number_of_weights = (
            2 * self.number_of_dimensions * self.number_of_parameters_per_dimensions
        )
        self.random_seed = None

        # Query Attributes
        self.query_metric = 'random'
        self.query_parameters = {}
        self.best_rewards_per_iteration = []

        # ROS2 Interfaces
        self.__future = None

        # Heartbeat Topics - to make sure that there is no deadlock
        self.heartbeat_timeout = 30  # secs
        self.heartbeat_timer = self.create_timer(1.0, self.check_heartbeats)

        # MR Heartbeat (timestamps are clock nanoseconds, None = not monitored)
        self.last_mr_heartbeat_time = None
        self.mr_heartbeat_sub = self.create_subscription(
            Bool, 'interaction/mr_heartbeat', self.mr_heartbeat_callback, 10)

        # Task Heartbeat
        self.last_task_heartbeat_time = None
        self.task_heartbeat_sub = self.create_subscription(
            Bool, 'interaction/task_heartbeat', self.task_heartbeat_callback, 10)

        # Topics

        # Service
        self.query_srv = self.create_client(Query, 'query_srv')

        # Define states
        self.states = [
            'idle',
            'initialization',
            'query_decision',
            'non_interactive_mode',
            'wait_for_evaluation',
            'interactive_mode',
            'wait_for_user_response',
            'prepare_user_data_for_evaluation',
            'wait_for_user_evaluation',
            'update_optimizer',
            'check_termination',
            'complete',
            'error_recovery',
        ]

        # Initialize state machine
        self.machine = Machine(model=self, states=self.states, initial='idle')

        # region Transitions
        # (trigger, source, dest)
        transitions_table = [
            ('order_received', 'idle', 'initialization'),
            ('initialization_complete', 'initialization', 'query_decision'),
            ('no_interaction', 'query_decision', 'non_interactive_mode'),
            ('data_send_for_evaluation', 'non_interactive_mode', 'wait_for_evaluation'),
            ('evaluation_response_received', 'wait_for_evaluation', 'update_optimizer'),
            ('interaction', 'query_decision', 'interactive_mode'),
            ('data_send_to_user', 'interactive_mode', 'wait_for_user_response'),
            ('user_response_received', 'wait_for_user_response',
             'prepare_user_data_for_evaluation'),
            ('send_user_data_to_evaluation', 'prepare_user_data_for_evaluation',
             'wait_for_user_evaluation'),
            ('received_user_data_for_evaluation', 'wait_for_user_evaluation',
             'update_optimizer'),
            ('optimizer_updated', 'update_optimizer', 'check_termination'),
            ('next_optimizer_step', 'check_termination', 'query_decision'),
            ('finished', 'check_termination', 'complete'),
            ('error_trigger', '*', 'error_recovery'),
            ('recovery_complete', 'error_recovery', 'idle'),
        ]
        for trigger, source, dest in transitions_table:
            self.machine.add_transition(trigger=trigger, source=source, dest=dest)
        # endregion

    # region State Functions
    def on_enter_initialization(self):
        """Load the YAML config and create the CMA-ES optimizer.

        Reads the file named by the ``cmaes_config_file_path`` parameter,
        draws the initial mean from a normal distribution and fires
        ``initialization_complete``. If the file cannot be read or parsed the
        error is logged and ``error_trigger`` is fired instead.
        """
        config_file_path = (
            self.get_parameter('cmaes_config_file_path').get_parameter_value().string_value
        )

        # Load YAML
        try:
            with open(config_file_path, 'r') as file:
                # safe_load returns None for an empty file
                config = yaml.safe_load(file) or {}
        except (OSError, yaml.YAMLError) as e:
            self.get_logger().error(
                f'Could not load CMA-ES config "{config_file_path}": {str(e)}')
            self.error_trigger()
            return

        # An empty or missing seed means "no fixed seed"
        seed_entry = config.get('random_seed', '')
        self.random_seed = None if seed_entry in ('', None) else int(seed_entry)

        mean_centre = config['initial_mean_centre']
        mean_std_dev = config['initial_mean_std_dev']
        initial_variance = config['initial_variance']

        # Keys consumed by this node are not CMA-ES options; everything else
        # in the file is forwarded to the optimizer unchanged.
        node_only_keys = {
            'random_seed',
            'initial_mean_centre',
            'initial_mean_std_dev',
            'initial_variance',
        }
        cma_options = {
            key: value for key, value in config.items() if key not in node_only_keys
        }
        cma_options['seed'] = self.random_seed

        random_gen = np.random.default_rng(seed=self.random_seed)
        # 1-D start vector with one entry per weight
        initial_mean = random_gen.normal(
            mean_centre, mean_std_dev, size=self.number_of_weights)

        self.cmaes = cma.CMAEvolutionStrategy(
            initial_mean, initial_variance, inopts=cma_options)

        # Trigger transition
        self.initialization_complete()

    def on_enter_query_decision(self):
        """Ask the query service whether the next step is interactive.

        The request is currently sent unmodified; the per-metric branches are
        placeholders. The answer is handled in ``handle_query_response``.
        """
        request = Query.Request()

        # TODO: fill the request depending on the query metric
        if self.query_metric == 'random':
            pass
        elif self.query_metric == 'regular':
            pass
        elif self.query_metric == 'improvement':
            pass

        self.__future = self.query_srv.call_async(request)
        self.__future.add_done_callback(self.handle_query_response)

    def on_enter_non_interactive_mode(self):
        """Prepare a ``TaskEvaluation`` request for a freshly sampled population.

        NOTE(review): the request is built but not sent yet and the sampled
        population is not copied into it - work in progress.
        """
        # Reset Task heartbeat to check if the other node crashed
        self.last_task_heartbeat_time = None

        request = TaskEvaluation.Request()
        request.user_input = False
        request.number_of_population = self.cmaes.popsize
        request.number_of_dimensions = self.number_of_dimensions * 2
        request.number_of_parameters_per_dimensions = self.number_of_parameters_per_dimensions

        # TODO: flatten into request.parameter_array and call the service
        population = self.cmaes.ask()

    def on_enter_interactive_mode(self):
        """Placeholder for the interactive state."""
        pass

    def on_enter_prepare_user_data_for_evaluation(self):
        """Placeholder for preparing user data for evaluation."""
        pass

    def on_enter_update_optimizer(self):
        """Placeholder for the optimizer update."""
        pass

    def on_enter_check_termination(self):
        """Placeholder for the termination check."""
        pass

    def on_enter_error_recovery(self):
        """Placeholder for error recovery."""
        pass
    # endregion

    # region Callback Functions
    @staticmethod
    def _heartbeat_expired(last_beat_ns, now_ns, timeout_ns):
        """Return True if a heartbeat was seen and is older than the timeout."""
        return last_beat_ns is not None and (now_ns - last_beat_ns) > timeout_ns

    def check_heartbeats(self):
        """Timer callback: fire ``error_trigger`` for every stale heartbeat.

        A heartbeat that has never been received (``None``) is not checked.
        After a timeout the timestamp is cleared so the error is reported
        once instead of on every timer tick.
        """
        current_time = self.get_clock().now().nanoseconds
        timeout_ns = self.heartbeat_timeout * 1e9

        if self._heartbeat_expired(self.last_mr_heartbeat_time, current_time, timeout_ns):
            self.get_logger().error("MR Interface heartbeat timed out!")
            self.last_mr_heartbeat_time = None
            self.error_trigger()

        if self._heartbeat_expired(self.last_task_heartbeat_time, current_time, timeout_ns):
            self.get_logger().error("Task Node heartbeat timed out!")
            self.last_task_heartbeat_time = None
            self.error_trigger()

    def mr_heartbeat_callback(self, _):
        """Record the arrival time of the latest MR interface heartbeat."""
        self.last_mr_heartbeat_time = self.get_clock().now().nanoseconds

    def task_heartbeat_callback(self, _):
        """Record the arrival time of the latest task node heartbeat."""
        self.last_task_heartbeat_time = self.get_clock().now().nanoseconds

    def handle_query_response(self, future):
        """Done-callback of the query service call.

        Fires ``interaction`` or ``no_interaction`` depending on the
        response, or ``error_trigger`` if the call itself failed.
        """
        # Only the service result is guarded, so a failure inside a state
        # callback is not misreported as a failed service call.
        try:
            response = future.result()
        except Exception as e:
            self.get_logger().error(f'Query service call failed: {str(e)}')
            self.error_trigger()
            return

        if response.interaction:
            self.interaction()
        else:
            self.no_interaction()
    # endregion

View File

@ -1,4 +1,6 @@
from setuptools import find_packages, setup from setuptools import find_packages, setup
import os
from glob import glob
package_name = 'interaction_optimizers' package_name = 'interaction_optimizers'
@ -10,6 +12,7 @@ setup(
('share/ament_index/resource_index/packages', ('share/ament_index/resource_index/packages',
['resource/' + package_name]), ['resource/' + package_name]),
('share/' + package_name, ['package.xml']), ('share/' + package_name, ['package.xml']),
(os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
], ],
install_requires=['setuptools'], install_requires=['setuptools'],
zip_safe=True, zip_safe=True,