test for active rl
This commit is contained in:
parent
25317ec7fb
commit
37dcf957f4
@ -26,6 +26,7 @@ rosidl_generate_interfaces(${PROJECT_NAME}
|
|||||||
"srv/ActiveBO.srv"
|
"srv/ActiveBO.srv"
|
||||||
"srv/ActiveRL.srv"
|
"srv/ActiveRL.srv"
|
||||||
"srv/ActiveRLEval.srv"
|
"srv/ActiveRLEval.srv"
|
||||||
|
"msg/ActiveRLEval.msg"
|
||||||
"msg/ImageFeedback.msg"
|
"msg/ImageFeedback.msg"
|
||||||
|
|
||||||
)
|
)
|
||||||
|
@ -67,6 +67,7 @@ class ActiveBOService(Node):
|
|||||||
arl_request.old_weights = old_weights.tolist()
|
arl_request.old_weights = old_weights.tolist()
|
||||||
self.get_logger().info('Calling: Active RL')
|
self.get_logger().info('Calling: Active RL')
|
||||||
future_rl = self.active_rl_client.call_async(arl_request)
|
future_rl = self.active_rl_client.call_async(arl_request)
|
||||||
|
self.get_logger().info(str(future_rl))
|
||||||
|
|
||||||
while not future_rl.done():
|
while not future_rl.done():
|
||||||
rclpy.spin_once(self)
|
rclpy.spin_once(self)
|
||||||
|
@ -13,7 +13,6 @@ import numpy as np
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ActiveRLService(Node):
|
class ActiveRLService(Node):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__('active_rl_service')
|
super().__init__('active_rl_service')
|
||||||
@ -25,14 +24,17 @@ class ActiveRLService(Node):
|
|||||||
self.active_rl_callback,
|
self.active_rl_callback,
|
||||||
callback_group=srv_callback_group)
|
callback_group=srv_callback_group)
|
||||||
|
|
||||||
self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1)
|
self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1, callback_group=srv_callback_group)
|
||||||
|
|
||||||
self.eval_pub = self.create_publisher(ActiveRLEval, 'active_rl_eval_request', 1)
|
self.eval_pub = self.create_publisher(ActiveRLEval,
|
||||||
|
'active_rl_eval_request',
|
||||||
|
1,
|
||||||
|
callback_group=srv_callback_group)
|
||||||
self.eval_sub = self.create_subscription(ActiveRLEval,
|
self.eval_sub = self.create_subscription(ActiveRLEval,
|
||||||
'active_rl_eval_response',
|
'active_rl_eval_response',
|
||||||
self.active_rl_eval_callback,
|
self.active_rl_eval_callback,
|
||||||
1,
|
1,
|
||||||
callback_group=sub_callback_group)
|
callback_group=srv_callback_group)
|
||||||
self.eval_response_received = False
|
self.eval_response_received = False
|
||||||
self.eval_response = None
|
self.eval_response = None
|
||||||
self.eval_response_received_first = False
|
self.eval_response_received_first = False
|
||||||
@ -50,6 +52,8 @@ class ActiveRLService(Node):
|
|||||||
self.eval_response = response
|
self.eval_response = response
|
||||||
self.eval_response_received = True
|
self.eval_response_received = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def active_rl_callback(self, request, response):
|
def active_rl_callback(self, request, response):
|
||||||
|
|
||||||
self.get_logger().info('Active RL: Called')
|
self.get_logger().info('Active RL: Called')
|
||||||
|
45
src/active_bo_ros/active_bo_ros/active_rl_test_node.py
Normal file
45
src/active_bo_ros/active_bo_ros/active_rl_test_node.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
from active_bo_msgs.srv import ActiveRL
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class ActiveRLTest(Node):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__('active_rl_test')
|
||||||
|
|
||||||
|
self.client = self.create_client(ActiveRL, 'active_rl_srv')
|
||||||
|
|
||||||
|
self.main_loop = self.create_timer(20.0, self.main_callback)
|
||||||
|
|
||||||
|
def main_callback(self):
|
||||||
|
|
||||||
|
random_policy = np.random.uniform(-1.0, 1.0, 100).tolist()
|
||||||
|
random_weights = np.random.uniform(-1.0, 1.0, 5).tolist()
|
||||||
|
|
||||||
|
self.get_logger().info(str(random_policy))
|
||||||
|
|
||||||
|
rl_request = ActiveRL.Request()
|
||||||
|
rl_request.old_policy = random_policy
|
||||||
|
rl_request.old_weights = random_weights
|
||||||
|
|
||||||
|
future = self.client.call_async(rl_request)
|
||||||
|
|
||||||
|
self.get_logger().info(str(future))
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None):
|
||||||
|
rclpy.init(args=args)
|
||||||
|
|
||||||
|
active_rl_test = ActiveRLTest()
|
||||||
|
|
||||||
|
rclpy.spin(active_rl_test)
|
||||||
|
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
@ -32,6 +32,7 @@ setup(
|
|||||||
'bo_srv = active_bo_ros.bo_service:main',
|
'bo_srv = active_bo_ros.bo_service:main',
|
||||||
'active_bo_srv = active_bo_ros.active_bo_service:main',
|
'active_bo_srv = active_bo_ros.active_bo_service:main',
|
||||||
'active_rl_srv = active_bo_ros.active_rl_service:main',
|
'active_rl_srv = active_bo_ros.active_rl_service:main',
|
||||||
|
'active_rl_test = active_bo_ros.active_rl_test_node:main',
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user