71 lines
1.8 KiB
Python
71 lines
1.8 KiB
Python
|
|
import rclpy
|
|
from rclpy.node import Node
|
|
|
|
from geometry_msgs.msg import PoseArray
|
|
from geometry_msgs.msg import Pose
|
|
from active_bo_msgs.msg import DMP
|
|
|
|
import pydmps
|
|
import pydmps.dmp_discrete
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
class DMPNode(Node):
    """ROS 2 node that turns incoming DMP parameters into a task-space trajectory.

    Subscribes to ``franka_iml_experiment/dmps`` (``DMP`` messages), rolls out
    a 2-D discrete DMP with ``pydmps`` and publishes the resulting x/y
    positions as a ``PoseArray`` on ``moveit_interface/task_space_trajectory``.
    """

    def __init__(self):
        super().__init__('dmp_node')

        self.subscription = self.create_subscription(
            DMP,
            'franka_iml_experiment/dmps',
            self.dmp_callback,
            10)
        self.subscription  # prevent unused variable warning

        self.traj_publisher = self.create_publisher(
            PoseArray, 'moveit_interface/task_space_trajectory', 10)

    def dmp_callback(self, msg):
        """Roll out the DMP described by *msg* and publish it as a PoseArray.

        Only the x/y dimensions are rolled out. The forcing-term weights come
        from ``msg.p_x`` / ``msg.p_y``, and the start/goal are the first two
        entries of ``msg.start_point`` / ``msg.end_point``. ``msg.time`` is
        passed to pydmps as ``tau``, and ``msg.nr_bfs`` as the number of basis
        functions.

        Inconsistent messages are logged and dropped instead of raising
        inside the executor. This covers a weight count that does not match
        ``nr_bfs``, fewer than two start/goal coordinates, and a non-positive
        ``time``.

        Args:
            msg: Incoming ``DMP`` message.
        """
        start = msg.start_point
        self.get_logger().info(f"{start}")

        # pydmps expects ndarray-like y0/goal (it copies y0 on reset). ROS
        # sequence fields may arrive as array.array, which has no .copy(),
        # so convert both endpoints explicitly.
        start = np.array(start, dtype=float)
        end = np.array(msg.end_point, dtype=float)
        tau = msg.time
        nr_bfs = msg.nr_bfs
        n_dmps = 2  # x and y only

        # NOTE: a full-pose version would also stack p_z and o_x..o_w here
        # (and raise n_dmps to 7).
        if not (len(msg.p_x) == len(msg.p_y) == nr_bfs):
            self.get_logger().error(
                f"DMP message has {len(msg.p_x)}/{len(msg.p_y)} x/y weights "
                f"but nr_bfs={nr_bfs}; dropping it")
            return
        if len(start) < n_dmps or len(end) < n_dmps:
            self.get_logger().error(
                f"DMP message needs at least {n_dmps} start/end coordinates; "
                f"got {len(start)}/{len(end)}; dropping it")
            return
        # pydmps derives the step count as timesteps / tau, so tau must be > 0.
        # NOTE(review): a larger tau gives fewer steps (a faster motion);
        # confirm msg.time is meant as tau rather than as a duration.
        if tau <= 0:
            self.get_logger().error(
                f"DMP message has non-positive time={tau}; dropping it")
            return

        weights = np.vstack((msg.p_x, msg.p_y))

        dmp = pydmps.dmp_discrete.DMPs_discrete(
            n_dmps=n_dmps, n_bfs=nr_bfs, w=weights,
            y0=start[:n_dmps], goal=end[:n_dmps])
        y_track, _, _ = dmp.rollout(tau=tau)

        pose_msg = PoseArray()
        # NOTE(review): the header (frame_id/stamp) is left unset, and z and
        # orientation keep their Pose() defaults. Confirm that the MoveIt
        # interface accepts this.
        for x, y in y_track[:, :2]:
            pose = Pose()
            pose.position.x = float(x)
            pose.position.y = float(y)
            pose_msg.poses.append(pose)

        self.traj_publisher.publish(pose_msg)
|
|
|
|
|
|
|
|
def main(args=None):
    """Initialise rclpy, spin a DMPNode until it stops, then clean up.

    Args:
        args: Optional command-line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)

    dmp_node = DMPNode()

    try:
        rclpy.spin(dmp_node)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the node; exit without a traceback.
        pass
    finally:
        # This runs on normal return and on any exception, so the node and
        # the rclpy context are always released.
        dmp_node.destroy_node()
        # The rclpy context may already have been shut down (e.g. by a signal
        # handler), and calling shutdown() on it again raises.
        if rclpy.ok():
            rclpy.shutdown()
|
|
|
|
|
|
if __name__ == '__main__':
    # Start the node only when this file is executed directly, not on import.
    main()
|