publisher doesn't work
parent c27675d156
commit f25d1ff559
.gitignore (vendored, 1 change)

@@ -3,3 +3,4 @@
 /build/
 /install/
 /log/
+/src/active_bo_ros/active_bo_ros/dump/
src/active_bo_msgs/CMakeLists.txt

@@ -25,8 +25,8 @@ rosidl_generate_interfaces(${PROJECT_NAME}
   "srv/BO.srv"
   "srv/ActiveBO.srv"
   "srv/ActiveRL.srv"
-  "srv/ActiveRLEval.srv"
   "msg/ImageFeedback.msg"
+  "msg/ActiveRLEval.msg"
 )
 
 if(BUILD_TESTING)
src/active_bo_msgs/msg/ActiveRLEval.msg (new file, 2 lines)

@@ -0,0 +1,2 @@
+float32[] policy
+float32[] weights
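With ActiveRLEval now declared as a message instead of a service, the eval handshake travels over topics. For orientation, a minimal sketch (not part of the commit) of using the generated Python class, assuming active_bo_msgs has been built and sourced:

    # Minimal sketch: the rosidl-generated Python class for the message above.
    from active_bo_msgs.msg import ActiveRLEval

    msg = ActiveRLEval()
    msg.policy = [0.0, 0.5]     # float32[] fields accept Python float sequences
    msg.weights = [1.0, 0.25]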
src/active_bo_ros/active_bo_ros/active_rl_service.py

@@ -1,10 +1,12 @@
 from active_bo_msgs.srv import ActiveRL
-from active_bo_msgs.srv import ActiveRLEval
 from active_bo_msgs.msg import ImageFeedback
+from active_bo_msgs.msg import ActiveRLEval
 
 import rclpy
 from rclpy.node import Node
 
+from rclpy.callback_groups import ReentrantCallbackGroup
+
 from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous_MountainCarEnv
 
 import numpy as np
@@ -13,14 +15,32 @@ import numpy as np
 class ActiveRLService(Node):
     def __init__(self):
         super().__init__('active_rl_service')
-        self.srv = self.create_service(ActiveRL, 'active_rl_srv', self.active_rl_callback)
-        self.eval_srv = self.create_client(ActiveRLEval, 'active_rl_eval_srv')
+        srv_callback_group = ReentrantCallbackGroup()
+        sub_callback_group = ReentrantCallbackGroup()
+
+        self.srv = self.create_service(ActiveRL,
+                                       'active_rl_srv',
+                                       self.active_rl_callback,
+                                       callback_group=srv_callback_group)
 
         self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1)
 
+        self.eval_pub = self.create_publisher(ActiveRLEval, 'active_rl_eval_request', 1)
+        self.eval_sub = self.create_subscription(ActiveRLEval,
+                                                 'active_rl_eval_response',
+                                                 self.active_rl_eval_callback,
+                                                 10,
+                                                 callback_group=sub_callback_group)
+        self.eval_response_received = False
+        self.eval_response = None
+
         self.env = Continuous_MountainCarEnv(render_mode='rgb_array')
         self.distance_penalty = 0
 
+    def active_rl_eval_callback(self, response):
+        self.eval_response = response
+        self.eval_response_received = True
+
     def active_rl_callback(self, request, response):
 
         feedback_msg = ImageFeedback()
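A note on the callback groups introduced here: ReentrantCallbackGroup only marks its callbacks as safe to run concurrently; they are actually dispatched in parallel only by a MultiThreadedExecutor. The main() preserved in the dump below spins single-threaded via rclpy.spin(node), so the wait loop added in a later hunk can starve the eval subscription, which is consistent with the commit title. A hedged sketch of a main() that would let both callbacks run (illustrative, not part of this commit):

    import rclpy
    from rclpy.executors import MultiThreadedExecutor

    def main(args=None):
        rclpy.init(args=args)
        active_rl_service = ActiveRLService()
        # Two threads let the eval subscription fire while
        # active_rl_callback blocks in its wait loop.
        executor = MultiThreadedExecutor(num_threads=2)
        rclpy.spin(active_rl_service, executor=executor)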
@@ -30,9 +50,9 @@ class ActiveRLService(Node):
         old_policy = request.old_policy
         old_weights = request.old_weights
 
-        eval_request = ActiveRLEval.Request()
-        eval_request.old_policy = old_policy
-        eval_request.old_weights = old_weights
+        eval_request = ActiveRLEval()
+        eval_request.policy = old_policy
+        eval_request.weights = old_weights
 
         self.env.reset()
 
@@ -63,9 +83,20 @@ class ActiveRLService(Node):
                 break
 
         self.get_logger().info('Enter new solution!')
-        eval_response = self.eval_srv.call(eval_request)
+        self.eval_pub.publish(eval_request)
 
-        new_policy = eval_response.new_policy.tolist()
+        while not self.eval_response_received:
+            rclpy.spin_once(self)
+
+        self.get_logger().info('Topic responded!')
+        new_policy = self.eval_response.policy
+        new_weights = self.eval_response.weights
+        self.eval_response_received = False
+        self.eval_response = None
+
+        reward = 0
+        step_count = 0
+        done = False
 
         for i in range(len(new_policy)):
             action = new_policy[i]
@@ -97,7 +128,8 @@ class ActiveRLService(Node):
             distance = -(self.env.goal_position - output[0][0])
             reward += distance * self.distance_penalty
 
-        response.new_weights = eval_response.Response.new_weights
+        self.get_logger().info(str(reward))
+        response.new_weights = new_weights
         response.reward = reward
         response.final_step = step_count
 
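The publish-then-poll handshake above is the fragile part: rclpy.spin_once(self) is called from inside a callback that an executor is already running, and nested spinning of the same node is not supported by rclpy, so the response may never be delivered (which would match the commit title). A common alternative is to block on a threading.Event that the subscription callback sets, while a multi-threaded executor keeps servicing callbacks. A self-contained toy sketch; the message type, topic names, and class are illustrative stand-ins, not from this repository:

    import threading

    import rclpy
    from rclpy.node import Node
    from rclpy.callback_groups import ReentrantCallbackGroup
    from rclpy.executors import MultiThreadedExecutor

    from std_msgs.msg import Float32MultiArray


    class EvalHandshakeDemo(Node):
        """Toy node: publish a request, then block until the response topic fires."""

        def __init__(self):
            super().__init__('eval_handshake_demo')
            self.response = None
            self.response_event = threading.Event()
            self.request_pub = self.create_publisher(Float32MultiArray, 'eval_request', 1)
            self.response_sub = self.create_subscription(
                Float32MultiArray, 'eval_response', self.response_callback, 10,
                callback_group=ReentrantCallbackGroup())

        def response_callback(self, msg):
            self.response = msg
            self.response_event.set()    # wakes request_and_wait()

        def request_and_wait(self, request):
            self.request_pub.publish(request)
            self.response_event.wait()   # executor keeps other callbacks running
            self.response_event.clear()
            return self.response


    def main(args=None):
        rclpy.init(args=args)
        node = EvalHandshakeDemo()
        rclpy.spin(node, executor=MultiThreadedExecutor(num_threads=2))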
src/active_bo_ros/active_bo_ros/dump/active_rl_service_dump.py (new file, 117 lines)

@@ -0,0 +1,117 @@
+from active_bo_msgs.srv import ActiveRL
+from active_bo_msgs.srv import ActiveRLEval
+from active_bo_msgs.msg import ImageFeedback
+
+import rclpy
+from rclpy.node import Node
+
+from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous_MountainCarEnv
+
+import numpy as np
+
+
+class ActiveRLService(Node):
+    def __init__(self):
+        super().__init__('active_rl_service')
+        self.srv = self.create_service(ActiveRL, 'active_rl_srv', self.active_rl_callback)
+        self.eval_srv = self.create_client(ActiveRLEval, 'active_rl_eval_srv')
+
+        self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1)
+
+        self.env = Continuous_MountainCarEnv(render_mode='rgb_array')
+        self.distance_penalty = 0
+
+    def active_rl_callback(self, request, response):
+
+        feedback_msg = ImageFeedback()
+
+        reward = 0
+        step_count = 0
+        old_policy = request.old_policy
+        old_weights = request.old_weights
+
+        eval_request = ActiveRLEval.Request()
+        eval_request.old_policy = old_policy
+        eval_request.old_weights = old_weights
+
+        self.env.reset()
+
+        self.get_logger().info('Best policy so far!')
+
+        for i in range(len(old_policy)):
+            action = old_policy[i]
+            output = self.env.step(action)
+
+            done = output[2]
+
+            rgb_array = self.env.render()
+            rgb_shape = rgb_array.shape
+
+            red = rgb_array[:, :, 0].flatten().tolist()
+            green = rgb_array[:, :, 1].flatten().tolist()
+            blue = rgb_array[:, :, 2].flatten().tolist()
+
+            feedback_msg.height = rgb_shape[0]
+            feedback_msg.width = rgb_shape[1]
+            feedback_msg.red = red
+            feedback_msg.green = green
+            feedback_msg.blue = blue
+
+            self.publisher.publish(feedback_msg)
+
+            if done:
+                break
+
+        self.get_logger().info('Enter new solution!')
+        eval_response = self.eval_srv.call(eval_request)
+        self.get_logger().info('Service responded!')
+
+        new_policy = eval_response.new_policy
+
+        for i in range(len(new_policy)):
+            action = new_policy[i]
+            output = self.env.step(action)
+
+            reward += output[1]
+            done = output[2]
+            step_count += 1
+
+            rgb_array = self.env.render()
+            rgb_shape = rgb_array.shape
+
+            red = rgb_array[:, :, 0].flatten().tolist()
+            green = rgb_array[:, :, 1].flatten().tolist()
+            blue = rgb_array[:, :, 2].flatten().tolist()
+
+            feedback_msg.height = rgb_shape[0]
+            feedback_msg.width = rgb_shape[1]
+            feedback_msg.red = red
+            feedback_msg.green = green
+            feedback_msg.blue = blue
+
+            self.publisher.publish(feedback_msg)
+
+            if done:
+                break
+
+        if not done and i == len(new_policy):
+            distance = -(self.env.goal_position - output[0][0])
+            reward += distance * self.distance_penalty
+
+        response.new_weights = eval_response.Response.new_weights
+        response.reward = reward
+        response.final_step = step_count
+
+        return response
+
+
+def main(args=None):
+    rclpy.init(args=args)
+
+    active_rl_service = ActiveRLService()
+
+    rclpy.spin(active_rl_service)
+
+
+if __name__ == '__main__':
+    main()
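For context on why this version was moved to dump/: the synchronous self.eval_srv.call(eval_request) issued from inside the active_rl_srv callback is a documented rclpy hazard, since call() blocks until the response is delivered, and a single-threaded executor stuck inside the callback can never deliver it. The usual remedy is call_async() with a done-callback. A hedged fragment, written as if added to the dumped ActiveRLService class (illustrative only):

    # Sketch: non-blocking variant of the blocking eval_srv.call() above.
    def send_eval_request(self, eval_request):
        future = self.eval_srv.call_async(eval_request)
        future.add_done_callback(self.on_eval_response)

    def on_eval_response(self, future):
        eval_response = future.result()
        self.get_logger().info('Service responded!')
        # processing of eval_response.new_policy would continue here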
src/active_bo_ros/active_bo_ros/rl_service.py

@@ -8,6 +8,7 @@ from active_bo_ros.ReinforcementLearning.ContinuousMountainCar import Continuous_MountainCarEnv
 
 import numpy as np
 
+
 class RLService(Node):
     def __init__(self):
         super().__init__('rl_service')
@@ -43,16 +44,6 @@ class RLService(Node):
         green = rgb_array[:, :, 1].flatten().tolist()
         blue = rgb_array[:, :, 2].flatten().tolist()
 
-        # red = [255] * 28800 + [0] * 28800 + [0] * 28800
-        # green = [0] * 28800 + [255] * 28800 + [0] * 28800
-        # blue = [0] * 28800 + [0] * 28800 + [255] * 28800
-
-
-        # random int data
-        # red = np.random.randint(0, 255, 240000).tolist()
-        # green = np.random.randint(0, 255, 240000).tolist()
-        # blue = np.random.randint(0, 255, 240000).tolist()
-
         feedback_msg.height = rgb_shape[0]
         feedback_msg.width = rgb_shape[1]
         feedback_msg.red = red