Compare commits

..

No commits in common. "25317ec7fb069a0920a63cddce29bacc41179ba6" and "76b8f5e2d26a83939f572b3d1c5ac349e78389a0" have entirely different histories.

3 changed files with 52 additions and 69 deletions

View File

@ -4,8 +4,6 @@ from active_bo_msgs.srv import ActiveRL
import rclpy import rclpy
from rclpy.node import Node from rclpy.node import Node
from rclpy.callback_groups import ReentrantCallbackGroup
from active_bo_ros.BayesianOptimization.BayesianOptimization import BayesianOptimization from active_bo_ros.BayesianOptimization.BayesianOptimization import BayesianOptimization
from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous_MountainCarEnv from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous_MountainCarEnv
@ -15,18 +13,11 @@ import numpy as np
class ActiveBOService(Node): class ActiveBOService(Node):
def __init__(self): def __init__(self):
super().__init__('active_bo_service') super().__init__('active_bo_service')
self.active_bo_srv = self.create_service(ActiveBO, 'active_bo_srv', self.active_bo_callback)
bo_callback_group = ReentrantCallbackGroup() self.active_rl_client = self.create_client(ActiveRL, 'active_rl_srv')
rl_callback_group = ReentrantCallbackGroup()
self.srv = self.create_service(ActiveBO, self.rl_trigger_ = False
'active_bo_srv',
self.active_bo_callback,
callback_group=bo_callback_group)
self.active_rl_client = self.create_client(ActiveRL,
'active_rl_srv',
callback_group=rl_callback_group)
self.env = Continuous_MountainCarEnv() self.env = Continuous_MountainCarEnv()
self.distance_penalty = 0 self.distance_penalty = 0
@ -65,22 +56,18 @@ class ActiveBOService(Node):
arl_request.old_policy = old_policy.tolist() arl_request.old_policy = old_policy.tolist()
arl_request.old_weights = old_weights.tolist() arl_request.old_weights = old_weights.tolist()
self.get_logger().info('Calling: Active RL')
future_rl = self.active_rl_client.call_async(arl_request) future_rl = self.active_rl_client.call_async(arl_request)
while not future_rl.done(): while rclpy.ok():
rclpy.spin_once(self) rclpy.spin_once(self)
self.get_logger().info('waiting for response!') if future_rl.done():
try:
self.get_logger().info('Received: Active RL') arl_response = future_rl.result()
self.get_logger().info('active RL Response: %s' % arl_response)
try: BO.add_new_observation(arl_response.reward, arl_response.new_weights)
arl_response = future_rl.result() except Exception as e:
BO.add_new_observation(arl_response.reward, arl_response.new_weights) self.get_logger().error('active RL Service failed %r' % (e,))
except Exception as e: break
self.get_logger().error('active RL Service failed %r' % (e,))
future_rl = None
# BO part # BO part
else: else:
@ -109,8 +96,6 @@ def main(args=None):
rclpy.spin(active_bo_service) rclpy.spin(active_bo_service)
rclpy.shutdown()
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -1,6 +1,6 @@
from active_bo_msgs.srv import ActiveRL from active_bo_msgs.srv import ActiveRL
from active_bo_msgs.msg import ImageFeedback from active_bo_msgs.msg import ImageFeedback
from active_bo_msgs.msg import ActiveRLEval from active_bo_msgs.srv import ActiveRLEval
import rclpy import rclpy
from rclpy.node import Node from rclpy.node import Node
@ -13,47 +13,48 @@ import numpy as np
import time import time
class ActiveRLService(Node): class ActiveRLService(Node):
def __init__(self): def __init__(self):
super().__init__('active_rl_service') super().__init__('active_rl_service')
srv_callback_group = ReentrantCallbackGroup() srv_callback_group = ReentrantCallbackGroup()
sub_callback_group = ReentrantCallbackGroup() eval_callback_group = ReentrantCallbackGroup()
self.srv = self.create_service(ActiveRL, self.rl_srv = self.create_service(ActiveRL,
'active_rl_srv', 'active_rl_srv',
self.active_rl_callback, self.active_rl_callback,
callback_group=srv_callback_group) callback_group=srv_callback_group)
self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1) self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1)
self.eval_pub = self.create_publisher(ActiveRLEval, 'active_rl_eval_request', 1) self.active_rl_eval_client = self.create_client(ActiveRLEval,
self.eval_sub = self.create_subscription(ActiveRLEval, 'active_rl_eval_srv',
'active_rl_eval_response', callback_group=eval_callback_group)
self.active_rl_eval_callback,
1, # self.eval_pub = self.create_publisher(ActiveRLEval, 'active_rl_eval_request', 1)
callback_group=sub_callback_group) # self.eval_sub = self.create_subscription(ActiveRLEval,
self.eval_response_received = False # 'active_rl_eval_response',
self.eval_response = None # self.active_rl_eval_callback,
self.eval_response_received_first = False # 10,
# callback_group=sub_callback_group)
# self.eval_response_received = False
# self.eval_response = None
# self.eval_response_received_first = False
self.env = Continuous_MountainCarEnv(render_mode='rgb_array') self.env = Continuous_MountainCarEnv(render_mode='rgb_array')
self.distance_penalty = 0 self.distance_penalty = 0
def active_rl_eval_callback(self, response): # def active_rl_eval_callback(self, response):
# if not self.eval_response_received_first: # # if not self.eval_response_received_first:
# self.eval_response_received_first = True # # self.eval_response_received_first = True
# self.get_logger().info('/active_rl_eval_response connected!') # # self.get_logger().info('/active_rl_eval_response connected!')
# else: # # else:
# self.eval_response = response # # self.eval_response = response
# self.eval_response_received = True # # self.eval_response_received = True
self.eval_response = response # self.eval_response = response
self.eval_response_received = True # self.eval_response_received = True
def active_rl_callback(self, request, response): def active_rl_callback(self, request, response):
self.get_logger().info('Active RL: Called')
feedback_msg = ImageFeedback() feedback_msg = ImageFeedback()
reward = 0 reward = 0
@ -61,9 +62,9 @@ class ActiveRLService(Node):
old_policy = request.old_policy old_policy = request.old_policy
old_weights = request.old_weights old_weights = request.old_weights
eval_request = ActiveRLEval() eval_request = ActiveRLEval.Request()
eval_request.policy = old_policy eval_request.old_policy = old_policy
eval_request.weights = old_weights eval_request.old_weights = old_weights
self.env.reset() self.env.reset()
@ -94,16 +95,17 @@ class ActiveRLService(Node):
break break
self.get_logger().info('Enter new solution!') self.get_logger().info('Enter new solution!')
self.eval_pub.publish(eval_request) # self.eval_pub.publish(eval_request)
#
# while not self.eval_response_received:
# rclpy.spin_once(self)
while not self.eval_response_received: eval_response = self.active_rl_eval_client.call(eval_request)
rclpy.spin_once(self) self.get_logger().info('Active RL Eval Srv started!')
new_policy = eval_response.new_policy
new_weights = eval_response.new_weights
self.get_logger().info('Topic responded!')
new_policy = self.eval_response.policy
new_weights = self.eval_response.weights
self.eval_response_received = False
self.eval_response = None
reward = 0 reward = 0
step_count = 0 step_count = 0
@ -155,8 +157,6 @@ def main(args=None):
rclpy.spin(active_rl_service) rclpy.spin(active_rl_service)
rclpy.shutdown()
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -71,7 +71,5 @@ def main(args=None):
rclpy.spin(rl_service) rclpy.spin(rl_service)
rclpy.shutdown()
if __name__ == '__main__': if __name__ == '__main__':
main() main()