test for active rl
This commit is contained in:
parent
25317ec7fb
commit
37dcf957f4
@ -26,6 +26,7 @@ rosidl_generate_interfaces(${PROJECT_NAME}
|
||||
"srv/ActiveBO.srv"
|
||||
"srv/ActiveRL.srv"
|
||||
"srv/ActiveRLEval.srv"
|
||||
"msg/ActiveRLEval.msg"
|
||||
"msg/ImageFeedback.msg"
|
||||
|
||||
)
|
||||
|
@ -67,6 +67,7 @@ class ActiveBOService(Node):
|
||||
arl_request.old_weights = old_weights.tolist()
|
||||
self.get_logger().info('Calling: Active RL')
|
||||
future_rl = self.active_rl_client.call_async(arl_request)
|
||||
self.get_logger().info(str(future_rl))
|
||||
|
||||
while not future_rl.done():
|
||||
rclpy.spin_once(self)
|
||||
|
@ -13,7 +13,6 @@ import numpy as np
|
||||
import time
|
||||
|
||||
|
||||
|
||||
class ActiveRLService(Node):
|
||||
def __init__(self):
|
||||
super().__init__('active_rl_service')
|
||||
@ -25,14 +24,17 @@ class ActiveRLService(Node):
|
||||
self.active_rl_callback,
|
||||
callback_group=srv_callback_group)
|
||||
|
||||
self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1)
|
||||
self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1, callback_group=srv_callback_group)
|
||||
|
||||
self.eval_pub = self.create_publisher(ActiveRLEval, 'active_rl_eval_request', 1)
|
||||
self.eval_pub = self.create_publisher(ActiveRLEval,
|
||||
'active_rl_eval_request',
|
||||
1,
|
||||
callback_group=srv_callback_group)
|
||||
self.eval_sub = self.create_subscription(ActiveRLEval,
|
||||
'active_rl_eval_response',
|
||||
self.active_rl_eval_callback,
|
||||
1,
|
||||
callback_group=sub_callback_group)
|
||||
callback_group=srv_callback_group)
|
||||
self.eval_response_received = False
|
||||
self.eval_response = None
|
||||
self.eval_response_received_first = False
|
||||
@ -50,6 +52,8 @@ class ActiveRLService(Node):
|
||||
self.eval_response = response
|
||||
self.eval_response_received = True
|
||||
|
||||
|
||||
|
||||
def active_rl_callback(self, request, response):
|
||||
|
||||
self.get_logger().info('Active RL: Called')
|
||||
|
45
src/active_bo_ros/active_bo_ros/active_rl_test_node.py
Normal file
45
src/active_bo_ros/active_bo_ros/active_rl_test_node.py
Normal file
@ -0,0 +1,45 @@
|
||||
from active_bo_msgs.srv import ActiveRL
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
|
||||
class ActiveRLTest(Node):
|
||||
def __init__(self):
|
||||
super().__init__('active_rl_test')
|
||||
|
||||
self.client = self.create_client(ActiveRL, 'active_rl_srv')
|
||||
|
||||
self.main_loop = self.create_timer(20.0, self.main_callback)
|
||||
|
||||
def main_callback(self):
|
||||
|
||||
random_policy = np.random.uniform(-1.0, 1.0, 100).tolist()
|
||||
random_weights = np.random.uniform(-1.0, 1.0, 5).tolist()
|
||||
|
||||
self.get_logger().info(str(random_policy))
|
||||
|
||||
rl_request = ActiveRL.Request()
|
||||
rl_request.old_policy = random_policy
|
||||
rl_request.old_weights = random_weights
|
||||
|
||||
future = self.client.call_async(rl_request)
|
||||
|
||||
self.get_logger().info(str(future))
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
|
||||
active_rl_test = ActiveRLTest()
|
||||
|
||||
rclpy.spin(active_rl_test)
|
||||
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -32,6 +32,7 @@ setup(
|
||||
'bo_srv = active_bo_ros.bo_service:main',
|
||||
'active_bo_srv = active_bo_ros.active_bo_service:main',
|
||||
'active_rl_srv = active_bo_ros.active_rl_service:main',
|
||||
'active_rl_test = active_bo_ros.active_rl_test_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user