fixing the parameters is working
This commit is contained in:
parent
d32bf54d78
commit
dcdd257ead
@ -28,7 +28,9 @@ rosidl_generate_interfaces(${PROJECT_NAME}
|
||||
"msg/ActiveBORequest.msg"
|
||||
"msg/ActiveBOResponse.msg"
|
||||
"msg/ActiveRLResponse.msg"
|
||||
"msg/ActiveRL.msg"
|
||||
"msg/ActiveRLRequest.msg"
|
||||
"msg/ActiveRLEvalRequest.msg"
|
||||
"msg/ActiveRLEvalResponse.msg"
|
||||
"msg/ImageFeedback.msg"
|
||||
"msg/ActiveBOState.msg"
|
||||
|
||||
|
2
src/active_bo_msgs/msg/ActiveRLEvalRequest.msg
Normal file
2
src/active_bo_msgs/msg/ActiveRLEvalRequest.msg
Normal file
@ -0,0 +1,2 @@
|
||||
float64[] policy
|
||||
float64[] weights
|
3
src/active_bo_msgs/msg/ActiveRLEvalResponse.msg
Normal file
3
src/active_bo_msgs/msg/ActiveRLEvalResponse.msg
Normal file
@ -0,0 +1,3 @@
|
||||
bool[] overwrite_weight
|
||||
float64[] policy
|
||||
float64[] weights
|
@ -2,7 +2,5 @@ string env
|
||||
uint32 seed
|
||||
bool display_run
|
||||
uint8 interactive_run
|
||||
bool overwrite_param
|
||||
bool[] overwrite_weight
|
||||
float64[] policy
|
||||
float64[] weights
|
@ -1,3 +1,4 @@
|
||||
float64[] weights
|
||||
bool[] overwrite_weight
|
||||
uint16 final_step
|
||||
float64 reward
|
@ -1,5 +1,7 @@
|
||||
from active_bo_msgs.msg import ActiveRL
|
||||
from active_bo_msgs.msg import ActiveRLRequest
|
||||
from active_bo_msgs.msg import ActiveRLResponse
|
||||
from active_bo_msgs.msg import ActiveRLEvalRequest
|
||||
from active_bo_msgs.msg import ActiveRLEvalResponse
|
||||
|
||||
from active_bo_msgs.msg import ImageFeedback
|
||||
|
||||
@ -19,7 +21,7 @@ import time
|
||||
import copy
|
||||
|
||||
|
||||
class ActiveRLService(Node):
|
||||
class ActiveRL(Node):
|
||||
def __init__(self):
|
||||
super().__init__('active_rl_service')
|
||||
rl_callback_group = ReentrantCallbackGroup()
|
||||
@ -30,7 +32,7 @@ class ActiveRLService(Node):
|
||||
self.active_rl_pub = self.create_publisher(ActiveRLResponse,
|
||||
'active_rl_response',
|
||||
1, callback_group=rl_callback_group)
|
||||
self.active_rl_sub = self.create_subscription(ActiveRL,
|
||||
self.active_rl_sub = self.create_subscription(ActiveRLRequest,
|
||||
'active_rl_request',
|
||||
self.active_rl_callback,
|
||||
1, callback_group=rl_callback_group)
|
||||
@ -48,11 +50,11 @@ class ActiveRLService(Node):
|
||||
1, callback_group=topic_callback_group)
|
||||
|
||||
# Active RL Evaluation Publisher, Subscriber and Message attributes
|
||||
self.eval_pub = self.create_publisher(ActiveRL,
|
||||
self.eval_pub = self.create_publisher(ActiveRLEvalRequest,
|
||||
'active_rl_eval_request',
|
||||
1,
|
||||
callback_group=topic_callback_group)
|
||||
self.eval_sub = self.create_subscription(ActiveRL,
|
||||
self.eval_sub = self.create_subscription(ActiveRLEvalResponse,
|
||||
'active_rl_eval_response',
|
||||
self.active_rl_eval_callback,
|
||||
1,
|
||||
@ -61,6 +63,7 @@ class ActiveRLService(Node):
|
||||
self.eval_response_received = False
|
||||
self.eval_policy = None
|
||||
self.eval_weights = None
|
||||
self.overwrite_weight = None
|
||||
|
||||
# RL Environments
|
||||
self.env = None
|
||||
@ -69,7 +72,7 @@ class ActiveRLService(Node):
|
||||
self.best_pol_shown = False
|
||||
self.policy_sent = False
|
||||
self.rl_pending = False
|
||||
self.interactive_run = False
|
||||
self.interactive_run = 0
|
||||
self.display_run = False
|
||||
|
||||
# Main loop timer object
|
||||
@ -84,6 +87,7 @@ class ActiveRLService(Node):
|
||||
self.rl_policy = None
|
||||
self.rl_weights = None
|
||||
self.interactive_run = 0
|
||||
self.display_run = False
|
||||
|
||||
def active_rl_callback(self, msg):
|
||||
self.rl_env = msg.env
|
||||
@ -117,6 +121,7 @@ class ActiveRLService(Node):
|
||||
def active_rl_eval_callback(self, msg):
|
||||
self.eval_policy = np.array(msg.policy, dtype=np.float64)
|
||||
self.eval_weights = msg.weights
|
||||
self.overwrite_weight = msg.overwrite_weight
|
||||
|
||||
self.get_logger().info('Active RL Eval: Responded!')
|
||||
self.env.reset(seed=self.rl_seed)
|
||||
@ -185,7 +190,7 @@ class ActiveRLService(Node):
|
||||
self.rl_reward = 0.0
|
||||
self.env.reset(seed=self.rl_seed)
|
||||
|
||||
eval_request = ActiveRL()
|
||||
eval_request = ActiveRLEvalRequest()
|
||||
eval_request.policy = self.rl_policy.tolist()
|
||||
eval_request.weights = self.rl_weights
|
||||
|
||||
@ -214,6 +219,7 @@ class ActiveRLService(Node):
|
||||
rl_response.weights = self.eval_weights
|
||||
rl_response.reward = self.rl_reward
|
||||
rl_response.final_step = self.rl_step
|
||||
rl_response.overwrite_weight = self.overwrite_weight
|
||||
|
||||
self.active_rl_pub.publish(rl_response)
|
||||
|
||||
@ -236,7 +242,7 @@ class ActiveRLService(Node):
|
||||
self.rl_reward = 0.0
|
||||
self.env.reset(seed=self.rl_seed)
|
||||
|
||||
eval_request = ActiveRL()
|
||||
eval_request = ActiveRLEvalRequest()
|
||||
eval_request.policy = self.rl_policy.tolist()
|
||||
eval_request.weights = self.rl_weights
|
||||
|
||||
@ -260,45 +266,23 @@ class ActiveRLService(Node):
|
||||
rl_response.weights = self.rl_weights
|
||||
rl_response.reward = env_reward
|
||||
rl_response.final_step = step_count
|
||||
if self.overwrite_weight is None:
|
||||
overwrite_weight = [False] * len(self.rl_weights)
|
||||
else:
|
||||
overwrite_weight = self.overwrite_weight
|
||||
|
||||
rl_response.overwrite_weight = overwrite_weight
|
||||
|
||||
self.active_rl_pub.publish(rl_response)
|
||||
|
||||
self.reset_rl_request()
|
||||
self.rl_pending = False
|
||||
|
||||
# if not self.policy_sent:
|
||||
# self.rl_step = 0
|
||||
# self.rl_reward = 0.0
|
||||
# self.env.reset(seed=self.rl_seed)
|
||||
# self.policy_sent = True
|
||||
# done = self.next_image(self.rl_policy, self.display_run)
|
||||
#
|
||||
# if done:
|
||||
# rl_response = ActiveRLResponse()
|
||||
# rl_response.weights = self.rl_weights
|
||||
# rl_response.reward = self.rl_reward
|
||||
# rl_response.final_step = self.rl_step
|
||||
#
|
||||
# self.active_rl_pub.publish(rl_response)
|
||||
# self.end_time = time.time()
|
||||
# self.get_logger().info(f'RL Time: {self.end_time - self.begin_time}, mode: {self.interactive_run}')
|
||||
# self.begin_time = None
|
||||
# self.end_time = None
|
||||
#
|
||||
# # reset flags and attributes
|
||||
# self.reset_eval_request()
|
||||
# self.reset_rl_request()
|
||||
#
|
||||
# self.rl_step = 0
|
||||
# self.rl_reward = 0.0
|
||||
#
|
||||
# self.rl_pending = False
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
|
||||
active_rl_service = ActiveRLService()
|
||||
active_rl_service = ActiveRL()
|
||||
|
||||
rclpy.spin(active_rl_service)
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from active_bo_msgs.msg import ActiveBORequest
|
||||
from active_bo_msgs.msg import ActiveBOResponse
|
||||
from active_bo_msgs.msg import ActiveRL
|
||||
from active_bo_msgs.msg import ActiveRLRequest
|
||||
from active_bo_msgs.msg import ActiveRLResponse
|
||||
from active_bo_msgs.msg import ActiveBOState
|
||||
|
||||
@ -54,6 +54,7 @@ class ActiveBOTopic(Node):
|
||||
self.bo_runs = 0
|
||||
self.bo_acq_fcn = None
|
||||
self.bo_metric_parameter = None
|
||||
self.bo_metric_parameter_2 = None
|
||||
self.current_run = 0
|
||||
self.current_episode = 0
|
||||
self.seed = None
|
||||
@ -61,7 +62,7 @@ class ActiveBOTopic(Node):
|
||||
self.save_result = False
|
||||
|
||||
# Active Reinforcement Learning Publisher, Subscriber and Message attributes
|
||||
self.active_rl_pub = self.create_publisher(ActiveRL,
|
||||
self.active_rl_pub = self.create_publisher(ActiveRLRequest,
|
||||
'active_rl_request',
|
||||
1, callback_group=rl_callback_group)
|
||||
self.active_rl_sub = self.create_subscription(ActiveRLResponse,
|
||||
@ -73,6 +74,7 @@ class ActiveBOTopic(Node):
|
||||
self.rl_weights = None
|
||||
self.rl_final_step = None
|
||||
self.rl_reward = 0.0
|
||||
self.overwrite_weight = None
|
||||
|
||||
# State Publisher
|
||||
self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1)
|
||||
@ -112,6 +114,7 @@ class ActiveBOTopic(Node):
|
||||
self.bo_runs = 0
|
||||
self.bo_acq_fcn = None
|
||||
self.bo_metric_parameter = None
|
||||
self.bo_metric_parameter_2 = None
|
||||
self.current_run = 0
|
||||
self.current_episode = 0
|
||||
self.save_result = False
|
||||
@ -172,6 +175,7 @@ class ActiveBOTopic(Node):
|
||||
if self.rl_pending:
|
||||
# self.get_logger().info('Active Reinforcement Learning response received!')
|
||||
self.rl_weights = np.array(msg.weights, dtype=np.float64)
|
||||
self.overwrite_weight = np.array(msg.overwrite_weight, dtype=bool)
|
||||
self.rl_final_step = msg.final_step
|
||||
self.rl_reward = msg.reward
|
||||
|
||||
@ -193,7 +197,7 @@ class ActiveBOTopic(Node):
|
||||
self.user_asked = False
|
||||
|
||||
self.rl_pending = False
|
||||
self.reset_rl_response()
|
||||
# self.reset_rl_response()
|
||||
|
||||
def mainloop_callback(self):
|
||||
if not self.active_bo_pending:
|
||||
@ -223,7 +227,7 @@ class ActiveBOTopic(Node):
|
||||
else:
|
||||
seed = int(np.random.randint(1, 2147483647, 1)[0])
|
||||
|
||||
rl_msg = ActiveRL()
|
||||
rl_msg = ActiveRLRequest()
|
||||
rl_msg.env = self.bo_env
|
||||
rl_msg.seed = seed
|
||||
rl_msg.display_run = False
|
||||
@ -282,7 +286,7 @@ class ActiveBOTopic(Node):
|
||||
|
||||
np.savetxt(path, data, delimiter=',')
|
||||
|
||||
active_rl_request = ActiveRL()
|
||||
active_rl_request = ActiveRLRequest()
|
||||
|
||||
if self.bo_fixed_seed:
|
||||
seed = int(self.seed_array[0, best_policy_idx])
|
||||
@ -292,10 +296,10 @@ class ActiveBOTopic(Node):
|
||||
|
||||
active_rl_request.env = self.bo_env
|
||||
active_rl_request.seed = seed
|
||||
active_rl_request.display_run = True
|
||||
active_rl_request.interactive_run = 1
|
||||
active_rl_request.policy = self.best_policy[:, best_policy_idx].tolist()
|
||||
active_rl_request.weights = self.best_weights[:, best_policy_idx].tolist()
|
||||
active_rl_request.interactive_run = 1
|
||||
active_rl_request.display_run = True
|
||||
|
||||
self.active_rl_pub.publish(active_rl_request)
|
||||
|
||||
@ -335,7 +339,7 @@ class ActiveBOTopic(Node):
|
||||
if user_query.query():
|
||||
self.last_query = self.current_episode
|
||||
self.user_asked = True
|
||||
active_rl_request = ActiveRL()
|
||||
active_rl_request = ActiveRLRequest()
|
||||
old_policy, y_max, old_weights, _ = self.BO.get_best_result()
|
||||
|
||||
# self.get_logger().info(f'Best: {y_max}, w:{old_weights}')
|
||||
@ -349,9 +353,9 @@ class ActiveBOTopic(Node):
|
||||
active_rl_request.env = self.bo_env
|
||||
active_rl_request.seed = seed
|
||||
active_rl_request.display_run = True
|
||||
active_rl_request.interactive_run = 0
|
||||
active_rl_request.policy = old_policy.tolist()
|
||||
active_rl_request.weights = old_weights.tolist()
|
||||
active_rl_request.interactive_run = 0
|
||||
|
||||
# self.get_logger().info('Calling: Active RL')
|
||||
self.active_rl_pub.publish(active_rl_request)
|
||||
@ -359,6 +363,16 @@ class ActiveBOTopic(Node):
|
||||
|
||||
else:
|
||||
x_next = self.BO.next_observation()
|
||||
self.get_logger().info(f'x_next: {x_next}')
|
||||
self.get_logger().info(f'overwrite: {self.overwrite_weight}')
|
||||
self.get_logger().info(f'rl_weights: {self.rl_weights}')
|
||||
|
||||
if self.overwrite:
|
||||
if self.overwrite_weight is not None and self.rl_weights is not None:
|
||||
x_next[self.overwrite_weight] = self.rl_weights[self.overwrite_weight]
|
||||
self.get_logger().info(f'x_next: {x_next}')
|
||||
self.get_logger().info(f'overwrite: {self.overwrite_weight}')
|
||||
self.get_logger().info(f'rl_weights: {self.rl_weights}')
|
||||
# self.get_logger().info('Next Observation BO!')
|
||||
self.BO.policy_model.weights = np.around(x_next, decimals=8)
|
||||
if self.bo_fixed_seed:
|
||||
@ -366,13 +380,14 @@ class ActiveBOTopic(Node):
|
||||
else:
|
||||
seed = int(np.random.randint(1, 2147483647, 1)[0])
|
||||
|
||||
rl_msg = ActiveRL()
|
||||
rl_msg = ActiveRLRequest()
|
||||
rl_msg.env = self.bo_env
|
||||
rl_msg.seed = seed
|
||||
rl_msg.display_run = False
|
||||
rl_msg.interactive_run = 2
|
||||
rl_msg.weights = x_next.tolist()
|
||||
rl_msg.policy = self.BO.policy_model.rollout().reshape(-1,).tolist()
|
||||
rl_msg.weights = x_next.tolist()
|
||||
|
||||
self.rl_pending = True
|
||||
|
||||
self.active_rl_pub.publish(rl_msg)
|
||||
@ -381,15 +396,12 @@ class ActiveBOTopic(Node):
|
||||
self.reward[self.current_episode, self.current_run] = np.max(self.BO.Y)
|
||||
self.get_logger().info(f'Current Episode: {self.current_episode},'
|
||||
f' best reward: {self.reward[self.current_episode, self.current_run]}')
|
||||
|
||||
else:
|
||||
self.best_policy[:, self.current_run], \
|
||||
self.best_pol_reward[:, self.current_run], \
|
||||
self.best_weights[:, self.current_run], idx = self.BO.get_best_result()
|
||||
|
||||
# self.get_logger().info(f'best idx: {idx}')
|
||||
|
||||
# self.reward[:, self.current_run] = self.BO.best_reward.T
|
||||
|
||||
if self.current_run < self.bo_runs - 1:
|
||||
self.BO = None
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user