cma-es node in dev
This commit is contained in:
parent
acf51f01c6
commit
a23e5be64c
@ -3,12 +3,13 @@ FROM osrf/ros:humble-desktop-full-jammy
|
||||
|
||||
# Update and install dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
python3-colcon-common-extensions python3-pip \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
python3-colcon-common-extensions python3-pip
|
||||
|
||||
COPY requirements.txt ./
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
RUN rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create a workspace
|
||||
WORKDIR /ros2_ws
|
||||
|
||||
@ -21,6 +22,7 @@ RUN . /opt/ros/humble/setup.sh && \
|
||||
|
||||
RUN echo "source /opt/ros/humble/setup.bash" >> ~/.bashrc
|
||||
|
||||
ENV PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/ros/humble/bin"
|
||||
|
||||
# Source the workspace
|
||||
CMD ["/bin/bash"]
|
||||
|
19
README.md
19
README.md
@ -27,6 +27,25 @@ docker build -t interactive-robot-learning-framework .
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
## Setting up Pycharm Environment
|
||||
Add a new Remote Interpreter based on the docker-compose file.
|
||||
Then add the paths listed below to the interpreter paths. You can reach the interpreter settings by clicking on the interpreter in the lower-right corner of the PyCharm window.
|
||||
Then select "Interpreter Settings"; within the settings, go to "Show All..." (it is an entry in the list of Python interpreters).
|
||||
Subsequently, add a new interpreter with the "+" button, choose "Docker Compose", and select the "docker-compose.yaml" file from this repository.
|
||||
The last step is to add the ROS 2 paths inside your Docker container to the interpreter paths (it is one of the buttons next to the "+").
|
||||
|
||||
The following paths are necessary to develop in the Docker container; otherwise Python cannot find rclpy and the custom msg packages:
|
||||
|
||||
```bash
|
||||
/opt/ros/humble/local/lib/python3.10/dist-packages
|
||||
|
||||
/opt/ros/humble/lib/python3.10/site-packages
|
||||
|
||||
/ros2_ws/install/interaction_msgs/local/lib/python3.10/dist-packages
|
||||
```
|
||||
If there is a
|
||||
|
||||
|
||||
## Framework Structure
|
||||
|
||||
The Interactive Robot Learning Framework consists of several key ROS2 packages, each responsible for different aspects of robot learning and interaction. Below is an overview of each package:
|
||||
|
@ -7,4 +7,9 @@ lark~=1.1.1
|
||||
scipy~=1.12.0
|
||||
scikit-learn~=1.4.0
|
||||
transitions==0.9.0
|
||||
movement-primitives[all]~=0.7.0
|
||||
movement-primitives[all]~=0.7.0
|
||||
cma~=3.3.0
|
||||
rclpy~=3.3.11
|
||||
PyYAML~=5.4.1
|
||||
rpyutils~=0.2.1
|
||||
turtlesim~=1.4.2
|
@ -11,6 +11,7 @@ find_package(rosidl_default_generators REQUIRED)
|
||||
rosidl_generate_interfaces(${PROJECT_NAME}
|
||||
"srv/Query.srv"
|
||||
"srv/Task.srv"
|
||||
"srv/TaskEvaluation.srv"
|
||||
"msg/OptimizerState.msg"
|
||||
"msg/Opt2UI.msg"
|
||||
"msg/UI2Opt.msg"
|
||||
|
0
src/interaction_msgs/msg/Order.msg
Normal file
0
src/interaction_msgs/msg/Order.msg
Normal file
19
src/interaction_msgs/srv/TaskEvaluation.srv
Normal file
19
src/interaction_msgs/srv/TaskEvaluation.srv
Normal file
@ -0,0 +1,19 @@
|
||||
bool user_input
|
||||
# case if user_input is true
|
||||
uint16 nr_of_population
|
||||
float32[] user_mean # Length: number_of_dimensions * number_of_parameters_per_dimension
|
||||
float32[] user_covariance_diag # same as user_mean
|
||||
float32[] current_cma_mean # Length: number_of_dimensions * number_of_parameters_per_dimension
|
||||
float32[] conditional_points # Length: (number_of_dimensions + time_stamp) * number_of_conditional_points
|
||||
float32 weight_parameter # weight of the weighted average: 0 = do not trust the user, 1 = completely trust the user (either set by the user or decayed over time; still to be decided by experiments)
|
||||
|
||||
# case if user_input is false
|
||||
uint16 number_of_population
|
||||
uint16 number_of_dimensions # this is the number of ProMPs * 2 (Position and Velocity)
|
||||
uint16 number_of_parameters_per_dimensions
|
||||
float32[] parameter_array # Length: number_of_population * number_of_dimensions * number_of_parameters_per_dimension
|
||||
|
||||
---
|
||||
# response
|
||||
float32[] parameter_array # this is needed because in case of user input the parameters aren't known yet
|
||||
float32[] score
|
@ -0,0 +1,4 @@
|
||||
initial_mean_centre: 0.0
|
||||
initial_mean_std_dev: 0.2
|
||||
initial_variance: 0.3
|
||||
random_seed: ''
|
@ -0,0 +1,193 @@
|
||||
import os
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
|
||||
from transitions import Machine
|
||||
import cma
|
||||
import yaml
|
||||
import numpy as np
|
||||
|
||||
# Msg/Srv/Action
|
||||
from interaction_msgs.srv import Query
|
||||
from interaction_msgs.srv import TaskEvaluation
|
||||
from std_msgs.msg import Bool
|
||||
|
||||
|
||||
class CMAESOptimizationNode(Node):
    """ROS 2 node that steps a CMA-ES optimizer through a state machine.

    The node owns a ``transitions.Machine`` whose model is the node itself,
    so every trigger registered in ``__init__`` (``order_received``,
    ``no_interaction``, ``error_trigger``, ...) becomes a method on this
    object, and every ``on_enter_<state>`` method defined below is run when
    the corresponding state is entered.

    Several states are still placeholders (their handlers only ``pass``);
    they are kept so the state graph is complete.
    """

    # Keys of the YAML config that this node consumes itself. They are not
    # CMA-ES option names, so they are stripped before the remaining entries
    # are handed to ``cma.CMAEvolutionStrategy`` as ``inopts``.
    _NODE_CONFIG_KEYS = frozenset({
        'initial_mean_centre',
        'initial_mean_std_dev',
        'initial_variance',
        'random_seed',
    })

    def __init__(self):
        """Create the ROS interfaces and build the state machine."""
        super().__init__('cmaes_optimization_node')

        # The config path is read in ``on_enter_initialization``; a parameter
        # has to be declared before ``get_parameter`` may be called on it.
        self.declare_parameter('cmaes_config_file_path', '')

        # CMA-ES Attributes
        self.cmaes = None
        self.number_of_dimensions = 3
        self.number_of_parameters_per_dimensions = 10
        # the number of weights is double the number of dims * params per dim
        # since its Position and Velocity
        self.number_of_weights = 2 * self.number_of_dimensions * self.number_of_parameters_per_dimensions

        self.random_seed = None

        # Query Attributes
        self.query_metric = 'random'
        self.query_parameters = {}
        self.best_rewards_per_iteration = []

        # ROS2 Interfaces
        self.__future = None

        # Heartbeat Topics - to make sure that there is no deadlock.
        # Timestamps are initialised before the timer and the subscriptions
        # are created so no callback can observe a missing attribute.
        self.heartbeat_timeout = 30  # secs
        self.last_mr_heartbeat_time = None
        self.last_task_heartbeat_time = None
        self.heartbeat_timer = self.create_timer(1.0, self.check_heartbeats)

        # ``create_subscription`` requires a QoS profile; an integer is
        # interpreted as the history depth.
        self.mr_heartbeat_sub = self.create_subscription(
            Bool, 'interaction/mr_heartbeat', self.mr_heartbeat_callback, 10)
        self.task_heartbeat_sub = self.create_subscription(
            Bool, 'interaction/task_heartbeat', self.task_heartbeat_callback, 10)

        # Topics

        # Service
        self.query_srv = self.create_client(Query, 'query_srv')

        # Define states
        self.states = [
            'idle',
            'initialization',
            'query_decision',
            'non_interactive_mode',
            'wait_for_evaluation',
            'interactive_mode',
            'wait_for_user_response',
            'prepare_user_data_for_evaluation',
            'wait_for_user_evaluation',
            'update_optimizer',
            'check_termination',
            'complete',
            'error_recovery',
        ]

        # Initialize state machine
        self.machine = Machine(self, states=self.states, initial='idle')

        # region Transitions
        self.machine.add_transition(trigger='order_received', source='idle', dest='initialization')
        self.machine.add_transition(trigger='initialization_complete', source='initialization', dest='query_decision')
        self.machine.add_transition(trigger='no_interaction', source='query_decision', dest='non_interactive_mode')
        self.machine.add_transition(trigger='data_send_for_evaluation', source='non_interactive_mode', dest='wait_for_evaluation')
        self.machine.add_transition(trigger='evaluation_response_received', source='wait_for_evaluation', dest='update_optimizer')
        self.machine.add_transition(trigger='interaction', source='query_decision', dest='interactive_mode')
        self.machine.add_transition(trigger='data_send_to_user', source='interactive_mode', dest='wait_for_user_response')
        self.machine.add_transition(trigger='user_response_received', source='wait_for_user_response', dest='prepare_user_data_for_evaluation')
        self.machine.add_transition(trigger='send_user_data_to_evaluation', source='prepare_user_data_for_evaluation', dest='wait_for_user_evaluation')
        self.machine.add_transition(trigger='received_user_data_for_evaluation', source='wait_for_user_evaluation', dest='update_optimizer')
        self.machine.add_transition(trigger='optimizer_updated', source='update_optimizer', dest='check_termination')
        self.machine.add_transition(trigger='next_optimizer_step', source='check_termination', dest='query_decision')
        self.machine.add_transition(trigger='finished', source='check_termination', dest='complete')
        self.machine.add_transition(trigger='error_trigger', source='*', dest='error_recovery')
        self.machine.add_transition(trigger='recovery_complete', source='error_recovery', dest='idle')
        # endregion

    # region State Functions
    def on_enter_initialization(self):
        """Load the YAML config, create the CMA-ES instance and move on.

        Reads the file named by the ``cmaes_config_file_path`` parameter,
        draws a random initial mean and constructs
        ``cma.CMAEvolutionStrategy``. Errors from opening or parsing the
        file are not caught here and propagate to the caller of the trigger.
        """
        config_file_path = self.get_parameter('cmaes_config_file_path').get_parameter_value().string_value

        # Load YAML
        with open(config_file_path, 'r') as file:
            config = yaml.safe_load(file)

        # An empty string or a missing/null entry means "no fixed seed".
        seed_entry = config.get('random_seed')
        self.random_seed = None if seed_entry in ('', None) else int(seed_entry)

        # Only forward entries that are not consumed by this node; the
        # node-specific keys are not CMA-ES options.
        cma_options = {
            key: value for key, value in config.items()
            if key not in self._NODE_CONFIG_KEYS
        }
        if self.random_seed is not None:
            cma_options['seed'] = self.random_seed

        mean_centre = config['initial_mean_centre']
        mean_std_dev = config['initial_mean_std_dev']

        random_gen = np.random.default_rng(seed=self.random_seed)

        # The start point is drawn as a flat vector with one entry per weight.
        initial_mean = random_gen.normal(mean_centre, mean_std_dev, size=self.number_of_weights)
        # NOTE(review): this value is passed as CMA-ES' initial step size
        # (sigma0), which is a standard deviation rather than a variance --
        # confirm the config key means what the optimizer expects.
        initial_variance = config['initial_variance']
        self.cmaes = cma.CMAEvolutionStrategy(initial_mean, initial_variance, inopts=cma_options)

        # Trigger transition
        self.initialization_complete()

    def on_enter_query_decision(self):
        """Ask the query service whether the next step should involve the user.

        The request is sent asynchronously; ``handle_query_response`` is
        invoked once the future completes.
        """
        request = Query.Request()

        # TODO: the metric-specific request setup is not implemented yet;
        # every metric currently sends an unmodified request.
        if self.query_metric == 'random':
            pass
        elif self.query_metric == 'regular':
            pass
        elif self.query_metric == 'improvement':
            pass

        self.__future = self.query_srv.call_async(request)
        self.__future.add_done_callback(self.handle_query_response)

    def on_enter_non_interactive_mode(self):
        """Prepare a ``TaskEvaluation`` request for a freshly sampled population.

        TODO: the request is only filled in so far; the sampled population
        is not yet copied into it and the request is not sent.
        """
        # Reset Task heartbeat to check if the other node crashed
        self.last_task_heartbeat_time = None

        request = TaskEvaluation.Request()
        request.user_input = False
        request.number_of_population = self.cmaes.popsize
        request.number_of_dimensions = self.number_of_dimensions * 2
        request.number_of_parameters_per_dimensions = self.number_of_parameters_per_dimensions

        population = self.cmaes.ask()

    def on_enter_interactive_mode(self):
        """Placeholder: nothing happens yet when entering ``interactive_mode``."""
        pass

    def on_enter_prepare_user_data_for_evaluation(self):
        """Placeholder: nothing happens yet when entering this state."""
        pass

    def on_enter_update_optimizer(self):
        """Placeholder: nothing happens yet when entering ``update_optimizer``."""
        pass

    def on_enter_check_termination(self):
        """Placeholder: nothing happens yet when entering ``check_termination``."""
        pass

    def on_enter_error_recovery(self):
        """Placeholder: nothing happens yet when entering ``error_recovery``."""
        pass

    # endregion

    # region Callback Functions
    def _heartbeat_timed_out(self, last_heartbeat_time):
        """Return True if a heartbeat was seen and is older than the timeout.

        ``last_heartbeat_time`` is a clock reading in nanoseconds, or None
        when no heartbeat has been received (which never counts as timed out).
        """
        if last_heartbeat_time is None:
            return False
        elapsed = self.get_clock().now().nanoseconds - last_heartbeat_time
        return elapsed > (self.heartbeat_timeout * 1e9)

    def check_heartbeats(self):
        """Timer callback: fire ``error_trigger`` for every stale heartbeat."""
        if self._heartbeat_timed_out(self.last_mr_heartbeat_time):
            self.get_logger().error("MR Interface heartbeat timed out!")
            self.error_trigger()

        # Each heartbeat is judged against its own timestamp.
        if self._heartbeat_timed_out(self.last_task_heartbeat_time):
            self.get_logger().error("Task Node heartbeat timed out!")
            self.error_trigger()

    def mr_heartbeat_callback(self, _):
        """Record the arrival time of the latest MR interface heartbeat."""
        self.last_mr_heartbeat_time = self.get_clock().now().nanoseconds

    def task_heartbeat_callback(self, _):
        """Record the arrival time of the latest task node heartbeat."""
        self.last_task_heartbeat_time = self.get_clock().now().nanoseconds

    def handle_query_response(self, future):
        """Done-callback of the query request: pick the (non-)interactive branch.

        Any exception raised while reading the result or firing the
        transition is logged and routed to ``error_trigger``.
        """
        try:
            response = future.result()

            if response.interaction:
                self.interaction()
            else:
                # The trigger registered in ``__init__`` is ``no_interaction``.
                self.no_interaction()

        except Exception as e:
            self.get_logger().error(f'Query service call failed: {str(e)}')
            self.error_trigger()

    # endregion
|
@ -1,4 +1,6 @@
|
||||
from setuptools import find_packages, setup
|
||||
import os
|
||||
from glob import glob
|
||||
|
||||
package_name = 'interaction_optimizers'
|
||||
|
||||
@ -10,6 +12,7 @@ setup(
|
||||
('share/ament_index/resource_index/packages',
|
||||
['resource/' + package_name]),
|
||||
('share/' + package_name, ['package.xml']),
|
||||
(os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
|
||||
],
|
||||
install_requires=['setuptools'],
|
||||
zip_safe=True,
|
||||
|
Loading…
Reference in New Issue
Block a user