dual quaternion policy model added

This commit is contained in:
Niko Feith 2024-02-21 14:02:21 +01:00
parent b83c490464
commit 5e42128f06
3 changed files with 72 additions and 1 deletions

View File

@ -0,0 +1,43 @@
from dual_quaternions import DualQuaternion
import numpy as np
class DualQuaternionSplines:
def __init__(self, pose_vector, nr_interpolation_steps=10):
self.pose_vector = pose_vector
self.nr_interpolation_steps = nr_interpolation_steps
self.dual_quaternions = self.parse_input_vector()
def generate_trajectory(self):
"""Generate an interpolated trajectory from the list of dual quaternions."""
interpolated_trajectory = []
for i in range(len(self.dual_quaternions) - 1):
interpolated_trajectory.extend(self.interpolate_dual_quaternions(self.dual_quaternions[i],
self.dual_quaternions[i+1],
self.nr_interpolation_steps))
return interpolated_trajectory
def parse_input_vector(self):
"""Parse the input vector into dual quaternions."""
dual_quats = []
for i in range(0, len(self.pose_vector), 7):
pose = self.pose_vector[i:i+7]
dq = self.quaternion_to_dual_quaternion(pose)
dual_quats.append(dq)
return dual_quats
@staticmethod
def quaternion_to_dual_quaternion(pose):
"""Convert position and quaternion to a dual quaternion.
:param pose: [q_rw, q_rx, q_ry, q_rz, q_tx, q_ty, q_tz]"""
return DualQuaternion.from_quat_pose_array(pose)
@staticmethod
def interpolate_dual_quaternions(dq1, dq2, steps=10):
"""Interpolate between two dual quaternions."""
return [DualQuaternion.sclerp(dq1, dq2, t) for t in np.linspace(0, 1, steps)]

View File

@ -0,0 +1,29 @@
import unittest
from dual_quaternions import DualQuaternion
from src.interaction_policy_representation.interaction_policy_representation.models.dual_quaternion_splines import DualQuaternionSplines
class TestDualQuaternionSplines(unittest.TestCase):
def setUp(self):
# Example input vector (position + quaternion for 2 poses)
self.input_vector = [1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0.]
# Expected dual quaternion for the first pose, adapted based on actual class functionality
self.expected_first_dq = DualQuaternion.from_quat_pose_array(self.input_vector[:7])
def test_parse_input_vector(self):
trajectory = DualQuaternionSplines(self.input_vector)
# Check if the first dual quaternion matches the expected one
# This assumes the implementation details about how dual quaternions are stored
self.assertAlmostEqual(trajectory.dual_quaternions[0].q_r.w, self.expected_first_dq.q_r.w, places=5)
self.assertAlmostEqual(trajectory.dual_quaternions[0].q_r.x, self.expected_first_dq.q_r.x, places=5)
# Add more assertions as needed for y, z, dual part
def test_generate_trajectory(self):
trajectory = DualQuaternionSplines(self.input_vector)
interpolated_trajectory = trajectory.generate_trajectory()
# Ensure the interpolated trajectory is not empty
self.assertTrue(len(interpolated_trajectory) > 0)
# Add more specific tests, e.g., comparing specific interpolated values
if __name__ == '__main__':
unittest.main()