Changed TaskEvaluation.srv to service
Adapted cmaes_optimization_node.py for the action
This commit is contained in:
parent
bffe826a74
commit
2422693b42
@ -9,11 +9,12 @@ find_package(ament_cmake REQUIRED)
|
|||||||
find_package(rosidl_default_generators REQUIRED)
|
find_package(rosidl_default_generators REQUIRED)
|
||||||
|
|
||||||
rosidl_generate_interfaces(${PROJECT_NAME}
|
rosidl_generate_interfaces(${PROJECT_NAME}
|
||||||
|
"action/TaskEvaluation.action"
|
||||||
"srv/Query.srv"
|
"srv/Query.srv"
|
||||||
"srv/Task.srv"
|
"srv/Task.srv"
|
||||||
"srv/TaskEvaluation.srv"
|
|
||||||
"srv/UserInterface.srv"
|
"srv/UserInterface.srv"
|
||||||
"srv/ParameterChange.srv"
|
"srv/ParameterChange.srv"
|
||||||
|
# "srv/TaskEvaluation.srv"
|
||||||
"msg/OptimizerState.msg"
|
"msg/OptimizerState.msg"
|
||||||
"msg/Opt2UI.msg"
|
"msg/Opt2UI.msg"
|
||||||
"msg/UI2Opt.msg"
|
"msg/UI2Opt.msg"
|
||||||
|
25
src/interaction_msgs/action/TaskEvaluation.action
Normal file
25
src/interaction_msgs/action/TaskEvaluation.action
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# Goal
|
||||||
|
bool user_input
|
||||||
|
uint16 number_of_population
|
||||||
|
float32 duration
|
||||||
|
uint16 number_of_time_steps
|
||||||
|
|
||||||
|
# case if user_input is true
|
||||||
|
float32[] user_parameters # Length: number_of_dimensions * number_of_parameters_per_dimension
|
||||||
|
float32[] user_covariance_diag # Length: number_of_dimensions * number_of_parameters_per_dimension
|
||||||
|
float32[] current_cma_mean # Length: number_of_dimensions * number_of_parameters_per_dimension
|
||||||
|
float32[] conditional_points # Length: (number_of_dimensions + time_stamp[0,1]) * number_of_conditional_points
|
||||||
|
float32[] weight_parameter # this parameter sets the weighted average 0 dont trust user 1 completly trust user (it is set by the user or it is decays over time i have to do some experiments on that)
|
||||||
|
|
||||||
|
# case if user_input is false
|
||||||
|
uint16 number_of_dimensions # this is the number of ProMPs * 2 (Position and Velocity)
|
||||||
|
uint16 number_of_parameters_per_dimensions
|
||||||
|
float32[] parameter_array # Length: number_of_population * number_of_dimensions * number_of_parameters_per_dimension
|
||||||
|
---
|
||||||
|
# Feedback
|
||||||
|
string current_state
|
||||||
|
uint16 processed_trajectories
|
||||||
|
---
|
||||||
|
# Result
|
||||||
|
float32[] parameter_array # this is needed because in case of user input the parameters arent known yet
|
||||||
|
float32[] score
|
@ -3,6 +3,7 @@ import os
|
|||||||
import rclpy
|
import rclpy
|
||||||
from rclpy.node import Node
|
from rclpy.node import Node
|
||||||
from rclpy.parameter import Parameter
|
from rclpy.parameter import Parameter
|
||||||
|
from rclpy.action import ActionClient
|
||||||
|
|
||||||
from transitions import Machine
|
from transitions import Machine
|
||||||
import cma
|
import cma
|
||||||
@ -12,7 +13,7 @@ from src.interaction_utils.serialization import flatten_population, unflatten_po
|
|||||||
|
|
||||||
# Msg/Srv/Action
|
# Msg/Srv/Action
|
||||||
from interaction_msgs.srv import Query
|
from interaction_msgs.srv import Query
|
||||||
from interaction_msgs.srv import TaskEvaluation
|
from interaction_msgs.action import TaskEvaluation
|
||||||
from interaction_msgs.srv import ParameterChange
|
from interaction_msgs.srv import ParameterChange
|
||||||
from interaction_msgs.srv import UserInterface
|
from interaction_msgs.srv import UserInterface
|
||||||
from std_msgs.msg import Bool
|
from std_msgs.msg import Bool
|
||||||
@ -42,6 +43,8 @@ class CMAESOptimizationNode(Node):
|
|||||||
self.best_reward_per_iteration = []
|
self.best_reward_per_iteration = []
|
||||||
# ROS2 Interfaces
|
# ROS2 Interfaces
|
||||||
self.__future = None
|
self.__future = None
|
||||||
|
self.__send_future = None
|
||||||
|
self.__response_future = None
|
||||||
self.__task_response = None
|
self.__task_response = None
|
||||||
self.__user_response = None
|
self.__user_response = None
|
||||||
# Heartbeat Topics - to make sure that there is no deadlock
|
# Heartbeat Topics - to make sure that there is no deadlock
|
||||||
@ -51,17 +54,16 @@ class CMAESOptimizationNode(Node):
|
|||||||
# Heartbeat
|
# Heartbeat
|
||||||
self.last_mr_heartbeat_time = None
|
self.last_mr_heartbeat_time = None
|
||||||
self.mr_heartbeat_sub = self.create_subscription(Bool, 'interaction/mr_heartbeat', self.mr_heatbeat_callback)
|
self.mr_heartbeat_sub = self.create_subscription(Bool, 'interaction/mr_heartbeat', self.mr_heatbeat_callback)
|
||||||
self.last_task_heartbeat_time = None
|
|
||||||
self.task_heartbeat_sub = self.create_subscription(Bool, 'interaction/task_heartbeat', self.task_heartbeat_callback)
|
|
||||||
|
|
||||||
# Topic
|
# Topic
|
||||||
|
|
||||||
# Service
|
# Service
|
||||||
self.parameter_srv = self.create_service(ParameterChange, 'interaction/cmaes_parameter_srv', self.parameter_callback)
|
self.parameter_srv = self.create_service(ParameterChange, 'interaction/cmaes_parameter_srv', self.parameter_callback)
|
||||||
self.query_srv = self.create_client(Query, 'interaction/query_srv')
|
self.query_srv = self.create_client(Query, 'interaction/query_srv')
|
||||||
self.task_srv = self.create_client(TaskEvaluation, 'interaction/task_srv')
|
|
||||||
self.user_interface_srv = self.create_client(UserInterface, 'interaction/user_interface_srv')
|
self.user_interface_srv = self.create_client(UserInterface, 'interaction/user_interface_srv')
|
||||||
|
|
||||||
|
# Action
|
||||||
|
self._task_action = ActionClient(self, TaskEvaluation, 'interaction/task_action')
|
||||||
|
|
||||||
# State Machine
|
# State Machine
|
||||||
# States
|
# States
|
||||||
self.states = [
|
self.states = [
|
||||||
@ -142,22 +144,20 @@ class CMAESOptimizationNode(Node):
|
|||||||
self.__future.add_done_callback(self.handle_query_response)
|
self.__future.add_done_callback(self.handle_query_response)
|
||||||
|
|
||||||
def on_enter_non_interactive_mode(self):
|
def on_enter_non_interactive_mode(self):
|
||||||
# Reset Task heartbeat to check if the other node crashed
|
goal = TaskEvaluation.Goal()
|
||||||
self.last_task_heartbeat_time = None
|
goal.user_input = False
|
||||||
|
goal.number_of_population = self.cmaes.popsize
|
||||||
request = TaskEvaluation.Request()
|
goal.number_of_dimensions = self.number_of_dimensions * 2
|
||||||
request.user_input = False
|
goal.number_of_parameters_per_dimensions = self.number_of_parameters_per_dimensions
|
||||||
request.number_of_population = self.cmaes.popsize
|
|
||||||
request.number_of_dimensions = self.number_of_dimensions * 2
|
|
||||||
request.number_of_parameters_per_dimensions = self.number_of_parameters_per_dimensions
|
|
||||||
|
|
||||||
population = self.cmaes.ask()
|
population = self.cmaes.ask()
|
||||||
flat_population = flatten_population(population)
|
flat_population = flatten_population(population)
|
||||||
|
|
||||||
request.parameter_array = flat_population
|
goal.parameter_array = flat_population
|
||||||
|
|
||||||
self.__future = self.task_srv.call_async(request)
|
self._task_action.wait_for_server(timeout_sec=10.0)
|
||||||
self.__future.add_done_callback(self.handle_task_response)
|
self.__send_future = self._task_action.send_goal_async(goal)
|
||||||
|
self.__send_future.add_done_callback(self._task_response_callback)
|
||||||
|
|
||||||
self.data_send_for_evaluation()
|
self.data_send_for_evaluation()
|
||||||
|
|
||||||
@ -176,9 +176,6 @@ class CMAESOptimizationNode(Node):
|
|||||||
self.data_send_to_user()
|
self.data_send_to_user()
|
||||||
|
|
||||||
def on_enter_prepare_user_data_for_evaluation(self):
|
def on_enter_prepare_user_data_for_evaluation(self):
|
||||||
# Reset Task heartbeat to check if the other node crashed
|
|
||||||
self.last_task_heartbeat_time = None
|
|
||||||
|
|
||||||
# Update the user_covariance
|
# Update the user_covariance
|
||||||
self.user_covariance = self.__user_response.user_covariance_diag
|
self.user_covariance = self.__user_response.user_covariance_diag
|
||||||
|
|
||||||
@ -294,14 +291,6 @@ class CMAESOptimizationNode(Node):
|
|||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if self.last_mr_heartbeat_time:
|
|
||||||
current_time = self.get_clock().now().nanoseconds
|
|
||||||
if (current_time - self.last_task_heartbeat_time) > (self.heartbeat_timeout * 1e9):
|
|
||||||
self.get_logger().error("Task Node heartbeat timed out!")
|
|
||||||
self.error_trigger()
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def mr_heartbeat_callback(self, _):
|
def mr_heartbeat_callback(self, _):
|
||||||
self.last_mr_heartbeat_time = self.get_clock().now().nanoseconds
|
self.last_mr_heartbeat_time = self.get_clock().now().nanoseconds
|
||||||
|
|
||||||
@ -321,15 +310,22 @@ class CMAESOptimizationNode(Node):
|
|||||||
self.get_logger().error(f'Query service call failed: {e}')
|
self.get_logger().error(f'Query service call failed: {e}')
|
||||||
self.error_trigger()
|
self.error_trigger()
|
||||||
|
|
||||||
def handle_task_response(self, future):
|
def _task_goal_callback(self, future):
|
||||||
try:
|
goal_handle = future.result()
|
||||||
self.__task_response = future.result()
|
|
||||||
self.evaluation_response_received()
|
|
||||||
|
|
||||||
except Exception as e:
|
if not goal_handle.accepted:
|
||||||
self.get_logger().error(f'Task service call failed: {e}')
|
self.get_logger().error(f'Task Goal rejected: {goal_handle}')
|
||||||
self.error_trigger()
|
self.error_trigger()
|
||||||
|
|
||||||
|
self.__response_future = goal_handle.get_result_asyn()
|
||||||
|
self.__response_future.add_done_callback(self._task_result_callback)
|
||||||
|
|
||||||
|
def _task_feedback_callback(self, msg):
|
||||||
|
self.get_logger().info(f'Received Feedback: state={msg.current_state}, processed={msg.processed_trajectories}')
|
||||||
|
def _task_result_callback(self, future):
|
||||||
|
self.__task_response = future.result()
|
||||||
|
self.evaluation_response_received()
|
||||||
|
|
||||||
def handle_user_response(self, future):
|
def handle_user_response(self, future):
|
||||||
try:
|
try:
|
||||||
self.__user_response = future.result()
|
self.__user_response = future.result()
|
||||||
|
Loading…
Reference in New Issue
Block a user