test for active rl

This commit is contained in:
Niko Feith 2023-03-29 17:46:29 +02:00
parent 25317ec7fb
commit 37dcf957f4
5 changed files with 56 additions and 4 deletions

View File

@ -26,6 +26,7 @@ rosidl_generate_interfaces(${PROJECT_NAME}
"srv/ActiveBO.srv" "srv/ActiveBO.srv"
"srv/ActiveRL.srv" "srv/ActiveRL.srv"
"srv/ActiveRLEval.srv" "srv/ActiveRLEval.srv"
"msg/ActiveRLEval.msg"
"msg/ImageFeedback.msg" "msg/ImageFeedback.msg"
) )

View File

@ -67,6 +67,7 @@ class ActiveBOService(Node):
arl_request.old_weights = old_weights.tolist() arl_request.old_weights = old_weights.tolist()
self.get_logger().info('Calling: Active RL') self.get_logger().info('Calling: Active RL')
future_rl = self.active_rl_client.call_async(arl_request) future_rl = self.active_rl_client.call_async(arl_request)
self.get_logger().info(str(future_rl))
while not future_rl.done(): while not future_rl.done():
rclpy.spin_once(self) rclpy.spin_once(self)

View File

@ -13,7 +13,6 @@ import numpy as np
import time import time
class ActiveRLService(Node): class ActiveRLService(Node):
def __init__(self): def __init__(self):
super().__init__('active_rl_service') super().__init__('active_rl_service')
@ -25,14 +24,17 @@ class ActiveRLService(Node):
self.active_rl_callback, self.active_rl_callback,
callback_group=srv_callback_group) callback_group=srv_callback_group)
self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1) self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1, callback_group=srv_callback_group)
self.eval_pub = self.create_publisher(ActiveRLEval, 'active_rl_eval_request', 1) self.eval_pub = self.create_publisher(ActiveRLEval,
'active_rl_eval_request',
1,
callback_group=srv_callback_group)
self.eval_sub = self.create_subscription(ActiveRLEval, self.eval_sub = self.create_subscription(ActiveRLEval,
'active_rl_eval_response', 'active_rl_eval_response',
self.active_rl_eval_callback, self.active_rl_eval_callback,
1, 1,
callback_group=sub_callback_group) callback_group=srv_callback_group)
self.eval_response_received = False self.eval_response_received = False
self.eval_response = None self.eval_response = None
self.eval_response_received_first = False self.eval_response_received_first = False
@ -50,6 +52,8 @@ class ActiveRLService(Node):
self.eval_response = response self.eval_response = response
self.eval_response_received = True self.eval_response_received = True
def active_rl_callback(self, request, response): def active_rl_callback(self, request, response):
self.get_logger().info('Active RL: Called') self.get_logger().info('Active RL: Called')

View File

@ -0,0 +1,45 @@
from active_bo_msgs.srv import ActiveRL
import rclpy
from rclpy.node import Node
import numpy as np
import time
class ActiveRLTest(Node):
def __init__(self):
super().__init__('active_rl_test')
self.client = self.create_client(ActiveRL, 'active_rl_srv')
self.main_loop = self.create_timer(20.0, self.main_callback)
def main_callback(self):
random_policy = np.random.uniform(-1.0, 1.0, 100).tolist()
random_weights = np.random.uniform(-1.0, 1.0, 5).tolist()
self.get_logger().info(str(random_policy))
rl_request = ActiveRL.Request()
rl_request.old_policy = random_policy
rl_request.old_weights = random_weights
future = self.client.call_async(rl_request)
self.get_logger().info(str(future))
def main(args=None):
rclpy.init(args=args)
active_rl_test = ActiveRLTest()
rclpy.spin(active_rl_test)
rclpy.shutdown()
if __name__ == '__main__':
main()

View File

@ -32,6 +32,7 @@ setup(
'bo_srv = active_bo_ros.bo_service:main', 'bo_srv = active_bo_ros.bo_service:main',
'active_bo_srv = active_bo_ros.active_bo_service:main', 'active_bo_srv = active_bo_ros.active_bo_service:main',
'active_rl_srv = active_bo_ros.active_rl_service:main', 'active_rl_srv = active_bo_ros.active_rl_service:main',
'active_rl_test = active_bo_ros.active_rl_test_node:main',
], ],
}, },
) )