cma-es node in dev
This commit is contained in:
parent
acf51f01c6
commit
a23e5be64c
@ -3,12 +3,13 @@ FROM osrf/ros:humble-desktop-full-jammy
|
|||||||
|
|
||||||
# Update and install dependencies
|
# Update and install dependencies
|
||||||
RUN apt-get update && apt-get install -y \
|
RUN apt-get update && apt-get install -y \
|
||||||
python3-colcon-common-extensions python3-pip \
|
python3-colcon-common-extensions python3-pip
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
COPY requirements.txt ./
|
COPY requirements.txt ./
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
RUN rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Create a workspace
|
# Create a workspace
|
||||||
WORKDIR /ros2_ws
|
WORKDIR /ros2_ws
|
||||||
|
|
||||||
@ -21,6 +22,7 @@ RUN . /opt/ros/humble/setup.sh && \
|
|||||||
|
|
||||||
RUN echo "source /opt/ros/humble/setup.bash" >> ~/.bashrc
|
RUN echo "source /opt/ros/humble/setup.bash" >> ~/.bashrc
|
||||||
|
|
||||||
|
ENV PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/ros/humble/bin"
|
||||||
|
|
||||||
# Source the workspace
|
# Source the workspace
|
||||||
CMD ["/bin/bash"]
|
CMD ["/bin/bash"]
|
||||||
|
19
README.md
19
README.md
@ -27,6 +27,25 @@ docker build -t interactive-robot-learning-framework .
|
|||||||
docker-compose up -d
|
docker-compose up -d
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Setting up Pycharm Environment
|
||||||
|
Add a new Remote Interpreter based on the docker-compose file.
|
||||||
|
Then add the following paths to the interpreter paths. You can find them when clicking on the Interpreter in the lower right corner of the Pycharm window.
|
||||||
|
Then select "Interpreter Settings" within the setting go to "Show All..." Python interpreters (its in the list of the Python Interpreters).
|
||||||
|
Subsequently, you can add with the "+" a new interpreter, choose "Docker Compose" and select the "docker-compose.yaml" file from this repository.
|
||||||
|
The last step is to add the ros2 paths within your docker container to the interpreter paths (its one of the buttons next to the "+".
|
||||||
|
|
||||||
|
The following paths are necessary to develop in the docker container otherwise python cannot find rclpy and the custom msg packages:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
/opt/ros/humble/local/lib/python3.10/dist-packages
|
||||||
|
|
||||||
|
/opt/ros/humble/lib/python3.10/site-packages
|
||||||
|
|
||||||
|
/ros2_ws/install/interaction_msgs/local/lib/python3.10/dist-packages
|
||||||
|
```
|
||||||
|
If there is a
|
||||||
|
|
||||||
|
|
||||||
## Framework Structure
|
## Framework Structure
|
||||||
|
|
||||||
The Interactive Robot Learning Framework consists of several key ROS2 packages, each responsible for different aspects of robot learning and interaction. Below is an overview of each package:
|
The Interactive Robot Learning Framework consists of several key ROS2 packages, each responsible for different aspects of robot learning and interaction. Below is an overview of each package:
|
||||||
|
@ -7,4 +7,9 @@ lark~=1.1.1
|
|||||||
scipy~=1.12.0
|
scipy~=1.12.0
|
||||||
scikit-learn~=1.4.0
|
scikit-learn~=1.4.0
|
||||||
transitions==0.9.0
|
transitions==0.9.0
|
||||||
movement-primitives[all]~=0.7.0
|
movement-primitives[all]~=0.7.0
|
||||||
|
cma~=3.3.0
|
||||||
|
rclpy~=3.3.11
|
||||||
|
PyYAML~=5.4.1
|
||||||
|
rpyutils~=0.2.1
|
||||||
|
turtlesim~=1.4.2
|
@ -11,6 +11,7 @@ find_package(rosidl_default_generators REQUIRED)
|
|||||||
rosidl_generate_interfaces(${PROJECT_NAME}
|
rosidl_generate_interfaces(${PROJECT_NAME}
|
||||||
"srv/Query.srv"
|
"srv/Query.srv"
|
||||||
"srv/Task.srv"
|
"srv/Task.srv"
|
||||||
|
"srv/TaskEvaluation.srv"
|
||||||
"msg/OptimizerState.msg"
|
"msg/OptimizerState.msg"
|
||||||
"msg/Opt2UI.msg"
|
"msg/Opt2UI.msg"
|
||||||
"msg/UI2Opt.msg"
|
"msg/UI2Opt.msg"
|
||||||
|
0
src/interaction_msgs/msg/Order.msg
Normal file
0
src/interaction_msgs/msg/Order.msg
Normal file
19
src/interaction_msgs/srv/TaskEvaluation.srv
Normal file
19
src/interaction_msgs/srv/TaskEvaluation.srv
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
bool user_input
|
||||||
|
# case if user_input is true
|
||||||
|
uint16 nr_of_population
|
||||||
|
float32[] user_mean # Length: number_of_dimensions * number_of_parameters_per_dimension
|
||||||
|
float32[] user_covariance_diag # same as user_mean
|
||||||
|
float32[] current_cma_mean # Length: number_of_dimensions * number_of_parameters_per_dimension
|
||||||
|
float32[] conditional_points # Length: (number_of_dimensions + time_stamp) * number_of_conditional_points
|
||||||
|
float32 weight_parameter # this parameter sets the weighted average 0 dont trust user 1 completly trust user (it is set by the user or it is decays over time i have to do some experiments on that)
|
||||||
|
|
||||||
|
# case if user_input is false
|
||||||
|
uint16 number_of_population
|
||||||
|
uint16 number_of_dimensions # this is the number of ProMPs * 2 (Position and Velocity)
|
||||||
|
uint16 number_of_parameters_per_dimensions
|
||||||
|
float32[] parameter_array # Length: number_of_population * number_of_dimensions * number_of_parameters_per_dimension
|
||||||
|
|
||||||
|
---
|
||||||
|
# response
|
||||||
|
float32[] parameter_array # this is needed because in case of user input the parameters arent known yet
|
||||||
|
float32[] score
|
@ -0,0 +1,4 @@
|
|||||||
|
initial_mean_centre: 0.0
|
||||||
|
initial_mean_std_dev: 0.2
|
||||||
|
initial_variance: 0.3
|
||||||
|
random_seed: ''
|
@ -0,0 +1,193 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
|
||||||
|
from transitions import Machine
|
||||||
|
import cma
|
||||||
|
import yaml
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Msg/Srv/Action
|
||||||
|
from interaction_msgs.srv import Query
|
||||||
|
from interaction_msgs.srv import TaskEvaluation
|
||||||
|
from std_msgs.msg import Bool
|
||||||
|
|
||||||
|
|
||||||
|
class CMAESOptimizationNode(Node):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__('cmaes_optimization_node')
|
||||||
|
|
||||||
|
# CMA-ES Attributes
|
||||||
|
self.cmaes = None
|
||||||
|
self.number_of_dimensions = 3
|
||||||
|
self.number_of_parameters_per_dimensions = 10
|
||||||
|
# the number of weights is double the number of dims * params per dim since its Position and Velocity
|
||||||
|
self.number_of_weights = 2 * self.number_of_dimensions * self.number_of_parameters_per_dimensions
|
||||||
|
|
||||||
|
self.random_seed = None
|
||||||
|
|
||||||
|
# Query Attributes
|
||||||
|
self.query_metric = 'random'
|
||||||
|
self.query_parameters = {}
|
||||||
|
self.best_rewards_per_iteration = []
|
||||||
|
# ROS2 Interfaces
|
||||||
|
self.__future = None
|
||||||
|
# Heartbeat Topics - to make sure that there is no deadlock
|
||||||
|
self.heartbeat_timeout = 30 # secs
|
||||||
|
self.heartbeat_timer = self.create_timer(1.0, self.check_heartbeats)
|
||||||
|
|
||||||
|
# MR Heartbeat
|
||||||
|
self.last_mr_heartbeat_time = None
|
||||||
|
self.mr_heartbeat_sub = self.create_subscription(Bool, 'interaction/mr_heartbeat', self.mr_heatbeat_callback)
|
||||||
|
self.last_task_heartbeat_time = None
|
||||||
|
self.task_heartbeat_sub = self.create_subscription(Bool, 'interaction/task_heartbeat', self.task_heartbeat_callback)
|
||||||
|
|
||||||
|
# Topics
|
||||||
|
|
||||||
|
# Service
|
||||||
|
self.query_srv = self.create_client(Query, 'query_srv')
|
||||||
|
|
||||||
|
# Define states
|
||||||
|
self.states = [
|
||||||
|
'idle',
|
||||||
|
'initialization',
|
||||||
|
'query_decision',
|
||||||
|
'non_interactive_mode',
|
||||||
|
'wait_for_evaluation',
|
||||||
|
'interactive_mode',
|
||||||
|
'wait_for_user_response',
|
||||||
|
'prepare_user_data_for_evaluation',
|
||||||
|
'wait_for_user_evaluation',
|
||||||
|
'update_optimizer',
|
||||||
|
'check_termination',
|
||||||
|
'complete',
|
||||||
|
'error_recovery',
|
||||||
|
]
|
||||||
|
|
||||||
|
# Initialize state machine
|
||||||
|
self.machine = Machine(self, states=self.states, initial='idle')
|
||||||
|
|
||||||
|
# region Transitions
|
||||||
|
self.machine.add_transition(trigger='order_received', source='idle', dest='initialization')
|
||||||
|
self.machine.add_transition(trigger='initialization_complete', source='initialization', dest='query_decision')
|
||||||
|
self.machine.add_transition(trigger='no_interaction', source='query_decision', dest='non_interactive_mode')
|
||||||
|
self.machine.add_transition(trigger='data_send_for_evaluation', source='non_interactive_mode', dest='wait_for_evaluation')
|
||||||
|
self.machine.add_transition(trigger='evaluation_response_received', source='wait_for_evaluation', dest='update_optimizer')
|
||||||
|
self.machine.add_transition(trigger='interaction', source='query_decision', dest='interactive_mode')
|
||||||
|
self.machine.add_transition(trigger='data_send_to_user', source='interactive_mode', dest='wait_for_user_response')
|
||||||
|
self.machine.add_transition(trigger='user_response_received', source='wait_for_user_response', dest='prepare_user_data_for_evaluation')
|
||||||
|
self.machine.add_transition(trigger='send_user_data_to_evaluation', source='prepare_user_data_for_evaluation', dest='wait_for_user_evaluation')
|
||||||
|
self.machine.add_transition(trigger='received_user_data_for_evaluation', source='wait_for_user_evaluation', dest='update_optimizer')
|
||||||
|
self.machine.add_transition(trigger='optimizer_updated', source='update_optimizer', dest='check_termination')
|
||||||
|
self.machine.add_transition(trigger='next_optimizer_step', source='check_termination', dest='query_decision')
|
||||||
|
self.machine.add_transition(trigger='finished', source='check_termination', dest='complete')
|
||||||
|
self.machine.add_transition(trigger='error_trigger', source='*', dest='error_recovery')
|
||||||
|
self.machine.add_transition(trigger='recovery_complete', source='error_recovery', dest='idle')
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region State Functions
|
||||||
|
def on_enter_initialization(self):
|
||||||
|
config_file_path = self.get_parameter('cmaes_config_file_path').get_parameter_value().string_value
|
||||||
|
|
||||||
|
# Load YAML
|
||||||
|
with open(config_file_path, 'r') as file:
|
||||||
|
config = yaml.safe_load(file)
|
||||||
|
|
||||||
|
if config['random_seed'] == '':
|
||||||
|
self.random_seed = None
|
||||||
|
else:
|
||||||
|
self.random_seed = int(config['random_seed'])
|
||||||
|
config['seed'] = self.random_seed
|
||||||
|
|
||||||
|
mean_centre = config['initial_mean_centre']
|
||||||
|
mean_std_dev = config['initial_mean_std_dev']
|
||||||
|
|
||||||
|
random_gen = np.random.default_rng(seed=self.random_seed)
|
||||||
|
|
||||||
|
initial_mean = random_gen.normal(mean_centre, mean_std_dev, size=(self.number_of_weights, 1))
|
||||||
|
initial_variance = config['initial_variance']
|
||||||
|
self.cmaes = cma.CMAEvolutionStrategy(initial_mean, initial_variance, inopts=config)
|
||||||
|
|
||||||
|
# Trigger transition
|
||||||
|
self.initialization_complete()
|
||||||
|
|
||||||
|
def on_enter_query_decision(self):
|
||||||
|
request = Query.Request()
|
||||||
|
|
||||||
|
if self.query_metric == 'random':
|
||||||
|
pass
|
||||||
|
elif self.query_metric == 'regular':
|
||||||
|
pass
|
||||||
|
elif self.query_metric == 'improvement':
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.__future = self.query_srv.call_async(request)
|
||||||
|
self.__future.add_done_callback(self.handle_query_response)
|
||||||
|
|
||||||
|
def on_enter_non_interactive_mode(self):
|
||||||
|
# Reset Task heartbeat to check if the other node crashed
|
||||||
|
self.last_task_heartbeat_time = None
|
||||||
|
|
||||||
|
request = TaskEvaluation.Request()
|
||||||
|
request.user_input = False
|
||||||
|
request.number_of_population = self.cmaes.popsize
|
||||||
|
request.number_of_dimensions = self.number_of_dimensions * 2
|
||||||
|
request.number_of_parameters_per_dimensions = self.number_of_parameters_per_dimensions
|
||||||
|
|
||||||
|
population = self.cmaes.ask()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def on_enter_interactive_mode(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_enter_prepare_user_data_for_evaluation(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_enter_update_optimizer(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_enter_check_termination(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_enter_error_recovery(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region Callback Functions
|
||||||
|
def check_heartbeats(self):
|
||||||
|
if self.last_mr_heartbeat_time:
|
||||||
|
current_time = self.get_clock().now().nanoseconds
|
||||||
|
if (current_time - self.last_mr_heartbeat_time) > (self.heartbeat_timeout * 1e9):
|
||||||
|
self.get_logger().error("MR Interface heartbeat timed out!")
|
||||||
|
self.error_trigger()
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if self.last_mr_heartbeat_time:
|
||||||
|
current_time = self.get_clock().now().nanoseconds
|
||||||
|
if (current_time - self.last_task_heartbeat_time) > (self.heartbeat_timeout * 1e9):
|
||||||
|
self.get_logger().error("Task Node heartbeat timed out!")
|
||||||
|
self.error_trigger()
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def mr_heartbeat_callback(self, _):
|
||||||
|
self.last_mr_heartbeat_time = self.get_clock().now().nanoseconds
|
||||||
|
|
||||||
|
def handle_query_response(self, future):
|
||||||
|
try:
|
||||||
|
response = future.result()
|
||||||
|
|
||||||
|
if response.interaction:
|
||||||
|
self.interaction()
|
||||||
|
else:
|
||||||
|
self.non_interaction()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.get_logger().error(f'Query service call failed: {str(e)}')
|
||||||
|
self.error_trigger()
|
||||||
|
|
||||||
|
# endregion
|
@ -1,4 +1,6 @@
|
|||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
|
import os
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
package_name = 'interaction_optimizers'
|
package_name = 'interaction_optimizers'
|
||||||
|
|
||||||
@ -10,6 +12,7 @@ setup(
|
|||||||
('share/ament_index/resource_index/packages',
|
('share/ament_index/resource_index/packages',
|
||||||
['resource/' + package_name]),
|
['resource/' + package_name]),
|
||||||
('share/' + package_name, ['package.xml']),
|
('share/' + package_name, ['package.xml']),
|
||||||
|
(os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
|
||||||
],
|
],
|
||||||
install_requires=['setuptools'],
|
install_requires=['setuptools'],
|
||||||
zip_safe=True,
|
zip_safe=True,
|
||||||
|
Loading…
Reference in New Issue
Block a user