From f875ee34c1c13678a0f14702b84bf45452831c24 Mon Sep 17 00:00:00 2001 From: Niko Date: Mon, 18 Mar 2024 18:00:12 +0100 Subject: [PATCH] Changed TaskEvaluation.srv to service added Action server to task node --- .../action/TaskEvaluation.action | 4 +- .../interaction_tasks/task_node.py | 73 ++++++++++++++++--- 2 files changed, 63 insertions(+), 14 deletions(-) diff --git a/src/interaction_msgs/action/TaskEvaluation.action b/src/interaction_msgs/action/TaskEvaluation.action index 6a48f09..1be6510 100644 --- a/src/interaction_msgs/action/TaskEvaluation.action +++ b/src/interaction_msgs/action/TaskEvaluation.action @@ -21,5 +21,5 @@ string current_state uint16 processed_trajectories --- # Result -float32[] parameter_array # this is needed because in case of user input the parameters arent known yet -float32[] score \ No newline at end of file +float32[] new_means # Length: number_of_population * number_of_dimensions * number_of_parameters_per_dimension, this is needed because in case of user input the parameters arent known yet +float32[] score # Length: number_of_population \ No newline at end of file diff --git a/src/interaction_tasks/interaction_tasks/task_node.py b/src/interaction_tasks/interaction_tasks/task_node.py index 8c3a518..0773d1e 100644 --- a/src/interaction_tasks/interaction_tasks/task_node.py +++ b/src/interaction_tasks/interaction_tasks/task_node.py @@ -1,6 +1,8 @@ import rclpy from rclpy.node import Node from rclpy.parameter import Parameter +from rclpy.action import ActionServer +from rclpy.action import GoalResponse, CancelResponse from transitions import Machine import yaml @@ -8,8 +10,8 @@ import numpy as np from movement_primitives.promp import ProMP from src.interaction_utils.serialization import flatten_population, unflatten_population -from interaction_msgs.srv import TaskEvaluation -from std_msgs.msg import Bool +from interaction_msgs.action import TaskEvaluation + class TaskNode(Node): @@ -17,20 +19,25 @@ class TaskNode(Node): super().__init__('task_node') # Task Attributes - + self.number_of_processed_trajectories = 0 + self.goal_dict = {} # ROS2 Interfaces - # Heartbeat - self.heartbeat_pub = self.create_publisher(Bool, 'interaction/task_heartbeat', 10) - self.heartbeat_timer = self.create_timer(5.0, self.send_heartbeat) # Topic # Service # Action - self._goal_handle = None + self._goal = None + self._task_action = ActionServer(self, + TaskEvaluation, + 'interaction/task_action', + goal_callback=self._task_goal_callback, + cancel_callback=self._task_cancel_callback, + execute_callback=self._task_execute_callback) # State Machine + self.state = None # States self.states = [ 'waiting_for_task_specs', @@ -38,6 +45,7 @@ class TaskNode(Node): 'processing_interactive_inpute', 'waiting_for_robot_response', 'waiting_for_objective_function_response', + 'sending_request' 'error_recovery' ] @@ -50,14 +58,55 @@ class TaskNode(Node): self.machine.add_transition(trigger='non_interactive_to_obj_fun', source='processing_non_interactive_input', dest='waiting_for_objective_function_response') self.machine.add_transition(trigger='interactive_to_robot', source='processing_interactive_input', dest='waiting_for_robot_response') self.machine.add_transition(trigger='interactive_to_obj_fun', source='processing_interactive_input', dest='waiting_for_objective_function_response') - self.machine.add_transition(trigger='sending_robot_results', source='waiting_for_robot_response', dest='waiting_for_task_specs') - self.machine.add_transition(trigger='sending_obj_fun_results', source='waiting_for_obj_fun_response', dest='waiting_for_task_specs') + self.machine.add_transition(trigger='sending_robot_results', source='waiting_for_robot_response', dest='sending_request') + self.machine.add_transition(trigger='sending_obj_fun_results', source='waiting_for_obj_fun_response', dest='sending_request') + self.machine.add_transition(trigger='sending_back', source='sending_request', dest='waiting_for_task_specs') self.machine.add_transition(trigger='error_trigger', source='*', dest='error_recovery') self.machine.add_transition(trigger='recovery_complete', source='error_recovery', dest='waiting_for_task_specs') + def destroy(self): + self._task_action.destroy() + super().destroy_node() + # State Methods # Callback functions - def send_heartbeat(self): - msg = Bool() - self.heartbeat_pub.publish(msg) + def _task_goal_callback(self, goal): + self._goal = goal + + if goal.user_input: + self.interactive_specs_received() + else: + self.non_interactive_specs_received() + + return GoalResponse.ACCEPT + + def _task_cancel_callback(self, _): + self.error_trigger() + return CancelResponse.ACCEPT + + async def _task_execute_callback(self, goal_handle): + feedback_msg = TaskEvaluation.Feedback() + result_msg = TaskEvaluation.Result() + + # Feedback Loop + while not goal_handle.is_cancel_requested(): + # Send Feedback msg + feedback_msg.current_state = self.state + feedback_msg.processed_trajectories = self.number_of_processed_trajectories + goal_handle.publish_feedback(feedback_msg) + + if self.state == 'sending_request': + result_msg.score = self.goal_dict['score'] + result_msg.new_means = self.goal_dict['new_means'] + + break + + if goal_handle.is_cancel_requested(): + goal_handle.canceled() + result_msg.score = -1 + return result_msg + + self.sending_back() + goal_handle.succeed() + return result_msg