Works for one use input
This commit is contained in:
parent
76b8f5e2d2
commit
10f938cc13
@ -4,6 +4,8 @@ from active_bo_msgs.srv import ActiveRL
|
|||||||
import rclpy
|
import rclpy
|
||||||
from rclpy.node import Node
|
from rclpy.node import Node
|
||||||
|
|
||||||
|
from rclpy.callback_groups import ReentrantCallbackGroup
|
||||||
|
|
||||||
from active_bo_ros.BayesianOptimization.BayesianOptimization import BayesianOptimization
|
from active_bo_ros.BayesianOptimization.BayesianOptimization import BayesianOptimization
|
||||||
from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous_MountainCarEnv
|
from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous_MountainCarEnv
|
||||||
|
|
||||||
@ -13,11 +15,18 @@ import numpy as np
|
|||||||
class ActiveBOService(Node):
|
class ActiveBOService(Node):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__('active_bo_service')
|
super().__init__('active_bo_service')
|
||||||
self.active_bo_srv = self.create_service(ActiveBO, 'active_bo_srv', self.active_bo_callback)
|
|
||||||
|
|
||||||
self.active_rl_client = self.create_client(ActiveRL, 'active_rl_srv')
|
bo_callback_group = ReentrantCallbackGroup()
|
||||||
|
rl_callback_group = ReentrantCallbackGroup()
|
||||||
|
|
||||||
self.rl_trigger_ = False
|
self.srv = self.create_service(ActiveBO,
|
||||||
|
'active_bo_srv',
|
||||||
|
self.active_bo_callback,
|
||||||
|
callback_group=bo_callback_group)
|
||||||
|
|
||||||
|
self.active_rl_client = self.create_client(ActiveRL,
|
||||||
|
'active_rl_srv',
|
||||||
|
callback_group=rl_callback_group)
|
||||||
|
|
||||||
self.env = Continuous_MountainCarEnv()
|
self.env = Continuous_MountainCarEnv()
|
||||||
self.distance_penalty = 0
|
self.distance_penalty = 0
|
||||||
@ -56,18 +65,22 @@ class ActiveBOService(Node):
|
|||||||
|
|
||||||
arl_request.old_policy = old_policy.tolist()
|
arl_request.old_policy = old_policy.tolist()
|
||||||
arl_request.old_weights = old_weights.tolist()
|
arl_request.old_weights = old_weights.tolist()
|
||||||
|
self.get_logger().info('Calling: Active RL')
|
||||||
future_rl = self.active_rl_client.call_async(arl_request)
|
future_rl = self.active_rl_client.call_async(arl_request)
|
||||||
|
|
||||||
while rclpy.ok():
|
while not future_rl.done():
|
||||||
rclpy.spin_once(self)
|
rclpy.spin_once(self)
|
||||||
if future_rl.done():
|
self.get_logger().info('waiting for response!')
|
||||||
try:
|
|
||||||
arl_response = future_rl.result()
|
self.get_logger().info('Received: Active RL')
|
||||||
self.get_logger().info('active RL Response: %s' % arl_response)
|
|
||||||
BO.add_new_observation(arl_response.reward, arl_response.new_weights)
|
try:
|
||||||
except Exception as e:
|
arl_response = future_rl.result()
|
||||||
self.get_logger().error('active RL Service failed %r' % (e,))
|
BO.add_new_observation(arl_response.reward, arl_response.new_weights)
|
||||||
break
|
except Exception as e:
|
||||||
|
self.get_logger().error('active RL Service failed %r' % (e,))
|
||||||
|
|
||||||
|
future_rl = None
|
||||||
|
|
||||||
# BO part
|
# BO part
|
||||||
else:
|
else:
|
||||||
@ -96,6 +109,8 @@ def main(args=None):
|
|||||||
|
|
||||||
rclpy.spin(active_bo_service)
|
rclpy.spin(active_bo_service)
|
||||||
|
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from active_bo_msgs.srv import ActiveRL
|
from active_bo_msgs.srv import ActiveRL
|
||||||
from active_bo_msgs.msg import ImageFeedback
|
from active_bo_msgs.msg import ImageFeedback
|
||||||
from active_bo_msgs.srv import ActiveRLEval
|
from active_bo_msgs.msg import ActiveRLEval
|
||||||
|
|
||||||
import rclpy
|
import rclpy
|
||||||
from rclpy.node import Node
|
from rclpy.node import Node
|
||||||
@ -17,44 +17,42 @@ class ActiveRLService(Node):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__('active_rl_service')
|
super().__init__('active_rl_service')
|
||||||
srv_callback_group = ReentrantCallbackGroup()
|
srv_callback_group = ReentrantCallbackGroup()
|
||||||
eval_callback_group = ReentrantCallbackGroup()
|
sub_callback_group = ReentrantCallbackGroup()
|
||||||
|
|
||||||
self.rl_srv = self.create_service(ActiveRL,
|
self.srv = self.create_service(ActiveRL,
|
||||||
'active_rl_srv',
|
'active_rl_srv',
|
||||||
self.active_rl_callback,
|
self.active_rl_callback,
|
||||||
callback_group=srv_callback_group)
|
callback_group=srv_callback_group)
|
||||||
|
|
||||||
self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1)
|
self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1)
|
||||||
|
|
||||||
self.active_rl_eval_client = self.create_client(ActiveRLEval,
|
self.eval_pub = self.create_publisher(ActiveRLEval, 'active_rl_eval_request', 1)
|
||||||
'active_rl_eval_srv',
|
self.eval_sub = self.create_subscription(ActiveRLEval,
|
||||||
callback_group=eval_callback_group)
|
'active_rl_eval_response',
|
||||||
|
self.active_rl_eval_callback,
|
||||||
# self.eval_pub = self.create_publisher(ActiveRLEval, 'active_rl_eval_request', 1)
|
1,
|
||||||
# self.eval_sub = self.create_subscription(ActiveRLEval,
|
callback_group=sub_callback_group)
|
||||||
# 'active_rl_eval_response',
|
self.eval_response_received = False
|
||||||
# self.active_rl_eval_callback,
|
self.eval_response = None
|
||||||
# 10,
|
self.eval_response_received_first = False
|
||||||
# callback_group=sub_callback_group)
|
|
||||||
# self.eval_response_received = False
|
|
||||||
# self.eval_response = None
|
|
||||||
# self.eval_response_received_first = False
|
|
||||||
|
|
||||||
self.env = Continuous_MountainCarEnv(render_mode='rgb_array')
|
self.env = Continuous_MountainCarEnv(render_mode='rgb_array')
|
||||||
self.distance_penalty = 0
|
self.distance_penalty = 0
|
||||||
|
|
||||||
# def active_rl_eval_callback(self, response):
|
def active_rl_eval_callback(self, response):
|
||||||
# # if not self.eval_response_received_first:
|
# if not self.eval_response_received_first:
|
||||||
# # self.eval_response_received_first = True
|
# self.eval_response_received_first = True
|
||||||
# # self.get_logger().info('/active_rl_eval_response connected!')
|
# self.get_logger().info('/active_rl_eval_response connected!')
|
||||||
# # else:
|
# else:
|
||||||
# # self.eval_response = response
|
# self.eval_response = response
|
||||||
# # self.eval_response_received = True
|
# self.eval_response_received = True
|
||||||
# self.eval_response = response
|
self.eval_response = response
|
||||||
# self.eval_response_received = True
|
self.eval_response_received = True
|
||||||
|
|
||||||
def active_rl_callback(self, request, response):
|
def active_rl_callback(self, request, response):
|
||||||
|
|
||||||
|
self.get_logger().info('Active RL: Called')
|
||||||
|
|
||||||
feedback_msg = ImageFeedback()
|
feedback_msg = ImageFeedback()
|
||||||
|
|
||||||
reward = 0
|
reward = 0
|
||||||
@ -62,9 +60,9 @@ class ActiveRLService(Node):
|
|||||||
old_policy = request.old_policy
|
old_policy = request.old_policy
|
||||||
old_weights = request.old_weights
|
old_weights = request.old_weights
|
||||||
|
|
||||||
eval_request = ActiveRLEval.Request()
|
eval_request = ActiveRLEval()
|
||||||
eval_request.old_policy = old_policy
|
eval_request.policy = old_policy
|
||||||
eval_request.old_weights = old_weights
|
eval_request.weights = old_weights
|
||||||
|
|
||||||
self.env.reset()
|
self.env.reset()
|
||||||
|
|
||||||
@ -95,17 +93,16 @@ class ActiveRLService(Node):
|
|||||||
break
|
break
|
||||||
|
|
||||||
self.get_logger().info('Enter new solution!')
|
self.get_logger().info('Enter new solution!')
|
||||||
# self.eval_pub.publish(eval_request)
|
self.eval_pub.publish(eval_request)
|
||||||
#
|
|
||||||
# while not self.eval_response_received:
|
|
||||||
# rclpy.spin_once(self)
|
|
||||||
|
|
||||||
eval_response = self.active_rl_eval_client.call(eval_request)
|
while not self.eval_response_received:
|
||||||
self.get_logger().info('Active RL Eval Srv started!')
|
rclpy.spin_once(self)
|
||||||
|
|
||||||
new_policy = eval_response.new_policy
|
|
||||||
new_weights = eval_response.new_weights
|
|
||||||
|
|
||||||
|
self.get_logger().info('Topic responded!')
|
||||||
|
new_policy = self.eval_response.policy
|
||||||
|
new_weights = self.eval_response.weights
|
||||||
|
self.eval_response_received = False
|
||||||
|
self.eval_response = None
|
||||||
|
|
||||||
reward = 0
|
reward = 0
|
||||||
step_count = 0
|
step_count = 0
|
||||||
@ -157,6 +154,8 @@ def main(args=None):
|
|||||||
|
|
||||||
rclpy.spin(active_rl_service)
|
rclpy.spin(active_rl_service)
|
||||||
|
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
@ -71,5 +71,7 @@ def main(args=None):
|
|||||||
|
|
||||||
rclpy.spin(rl_service)
|
rclpy.spin(rl_service)
|
||||||
|
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
Loading…
Reference in New Issue
Block a user