fixing the parameters now works

Niko Feith 2023-06-20 14:43:17 +02:00
parent d32bf54d78
commit dcdd257ead
9 changed files with 57 additions and 55 deletions

View File

@@ -28,7 +28,9 @@ rosidl_generate_interfaces(${PROJECT_NAME}
"msg/ActiveBORequest.msg"
"msg/ActiveBOResponse.msg"
"msg/ActiveRLResponse.msg"
"msg/ActiveRL.msg"
"msg/ActiveRLRequest.msg"
"msg/ActiveRLEvalRequest.msg"
"msg/ActiveRLEvalResponse.msg"
"msg/ImageFeedback.msg"
"msg/ActiveBOState.msg"

View File

@@ -0,0 +1,2 @@
float64[] policy
float64[] weights

View File

@@ -0,0 +1,3 @@
bool[] overwrite_weight
float64[] policy
float64[] weights
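
These two new definitions split the former shared ActiveRL message into a dedicated eval request/response pair; the response adds a bool[] overwrite_weight mask marking which weights the user pinned. A minimal sketch of filling them from Python (array lengths and values are illustrative, not from the repo):

# Hedged sketch: populate the new eval message pair (values are illustrative)
import numpy as np
from active_bo_msgs.msg import ActiveRLEvalRequest, ActiveRLEvalResponse

eval_request = ActiveRLEvalRequest()
eval_request.policy = np.zeros(50, dtype=np.float64).tolist()   # float64[] policy
eval_request.weights = np.zeros(6, dtype=np.float64).tolist()   # float64[] weights

eval_response = ActiveRLEvalResponse()
eval_response.overwrite_weight = [False] * 6                    # bool[] overwrite_weight
eval_response.policy = eval_request.policy
eval_response.weights = eval_request.weights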

View File

@@ -2,7 +2,5 @@ string env
uint32 seed
bool display_run
uint8 interactive_run
bool overwrite_param
bool[] overwrite_weight
float64[] policy
float64[] weights

View File

@@ -1,3 +1,4 @@
float64[] weights
bool[] overwrite_weight
uint16 final_step
float64 reward
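
ActiveRLRequest drops the overwrite fields (they now live on the response side) and interactive_run becomes a uint8 mode instead of a bool, while ActiveRLResponse gains the overwrite_weight mask next to weights, final_step and reward. A hedged round-trip sketch assuming the field types above; the environment name and array lengths are hypothetical:

# Sketch only: one request/response round trip (values are illustrative)
from active_bo_msgs.msg import ActiveRLRequest, ActiveRLResponse

request = ActiveRLRequest()
request.env = 'CartPole-v1'        # hypothetical environment name
request.seed = 42
request.display_run = False
request.interactive_run = 2        # uint8 mode; the diff below uses 0, 1 and 2
request.policy = [0.0] * 50
request.weights = [0.0] * 6

response = ActiveRLResponse()
response.weights = request.weights
response.overwrite_weight = [False] * len(request.weights)
response.final_step = 100
response.reward = 0.0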

View File

@@ -1,5 +1,7 @@
from active_bo_msgs.msg import ActiveRL
from active_bo_msgs.msg import ActiveRLRequest
from active_bo_msgs.msg import ActiveRLResponse
from active_bo_msgs.msg import ActiveRLEvalRequest
from active_bo_msgs.msg import ActiveRLEvalResponse
from active_bo_msgs.msg import ImageFeedback
@@ -19,7 +21,7 @@ import time
import copy
class ActiveRLService(Node):
class ActiveRL(Node):
def __init__(self):
super().__init__('active_rl_service')
rl_callback_group = ReentrantCallbackGroup()
@@ -30,7 +32,7 @@ class ActiveRLService(Node):
self.active_rl_pub = self.create_publisher(ActiveRLResponse,
'active_rl_response',
1, callback_group=rl_callback_group)
self.active_rl_sub = self.create_subscription(ActiveRL,
self.active_rl_sub = self.create_subscription(ActiveRLRequest,
'active_rl_request',
self.active_rl_callback,
1, callback_group=rl_callback_group)
@@ -48,11 +50,11 @@ class ActiveRLService(Node):
1, callback_group=topic_callback_group)
# Active RL Evaluation Publisher, Subscriber and Message attributes
self.eval_pub = self.create_publisher(ActiveRL,
self.eval_pub = self.create_publisher(ActiveRLEvalRequest,
'active_rl_eval_request',
1,
callback_group=topic_callback_group)
self.eval_sub = self.create_subscription(ActiveRL,
self.eval_sub = self.create_subscription(ActiveRLEvalResponse,
'active_rl_eval_response',
self.active_rl_eval_callback,
1,
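
Each topic now carries one dedicated type per direction instead of the shared ActiveRL message, so a payload's type says which side of the conversation it belongs to. A minimal sketch of the resulting pattern inside a Node's __init__, with the topic names taken from the diff:

# Sketch: typed request/response topic pair, as set up in __init__ above
from active_bo_msgs.msg import ActiveRLRequest, ActiveRLResponse

self.active_rl_pub = self.create_publisher(ActiveRLResponse, 'active_rl_response', 1)
self.active_rl_sub = self.create_subscription(ActiveRLRequest, 'active_rl_request',
                                              self.active_rl_callback, 1)
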
@@ -61,6 +63,7 @@ class ActiveRLService(Node):
self.eval_response_received = False
self.eval_policy = None
self.eval_weights = None
self.overwrite_weight = None
# RL Environments
self.env = None
@@ -69,7 +72,7 @@ class ActiveRLService(Node):
self.best_pol_shown = False
self.policy_sent = False
self.rl_pending = False
self.interactive_run = False
self.interactive_run = 0
self.display_run = False
# Main loop timer object
@@ -84,6 +87,7 @@ class ActiveRLService(Node):
self.rl_policy = None
self.rl_weights = None
self.interactive_run = 0
self.display_run = False
def active_rl_callback(self, msg):
self.rl_env = msg.env
@@ -117,6 +121,7 @@ class ActiveRLService(Node):
def active_rl_eval_callback(self, msg):
self.eval_policy = np.array(msg.policy, dtype=np.float64)
self.eval_weights = msg.weights
self.overwrite_weight = msg.overwrite_weight
self.get_logger().info('Active RL Eval: Responded!')
self.env.reset(seed=self.rl_seed)
@@ -185,7 +190,7 @@ class ActiveRLService(Node):
self.rl_reward = 0.0
self.env.reset(seed=self.rl_seed)
eval_request = ActiveRL()
eval_request = ActiveRLEvalRequest()
eval_request.policy = self.rl_policy.tolist()
eval_request.weights = self.rl_weights
@@ -214,6 +219,7 @@ class ActiveRLService(Node):
rl_response.weights = self.eval_weights
rl_response.reward = self.rl_reward
rl_response.final_step = self.rl_step
rl_response.overwrite_weight = self.overwrite_weight
self.active_rl_pub.publish(rl_response)
@@ -236,7 +242,7 @@ class ActiveRLService(Node):
self.rl_reward = 0.0
self.env.reset(seed=self.rl_seed)
eval_request = ActiveRL()
eval_request = ActiveRLEvalRequest()
eval_request.policy = self.rl_policy.tolist()
eval_request.weights = self.rl_weights
@@ -260,45 +266,23 @@ class ActiveRLService(Node):
rl_response.weights = self.rl_weights
rl_response.reward = env_reward
rl_response.final_step = step_count
if self.overwrite_weight is None:
overwrite_weight = [False] * len(self.rl_weights)
else:
overwrite_weight = self.overwrite_weight
rl_response.overwrite_weight = overwrite_weight
self.active_rl_pub.publish(rl_response)
self.reset_rl_request()
self.rl_pending = False
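
If the run never went through the interactive eval path, self.overwrite_weight is still None, so the response falls back to an all-False mask with one entry per weight; downstream code can then index with the mask unconditionally. The defaulting in isolation, with illustrative values:

# Fallback sketch: no eval response means no weight was pinned
rl_weights = [0.1, -0.4, 0.7]             # illustrative weights
overwrite_weight = None                   # no ActiveRLEvalResponse arrived
if overwrite_weight is None:
    overwrite_weight = [False] * len(rl_weights)
assert overwrite_weight == [False, False, False]
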
# if not self.policy_sent:
# self.rl_step = 0
# self.rl_reward = 0.0
# self.env.reset(seed=self.rl_seed)
# self.policy_sent = True
# done = self.next_image(self.rl_policy, self.display_run)
#
# if done:
# rl_response = ActiveRLResponse()
# rl_response.weights = self.rl_weights
# rl_response.reward = self.rl_reward
# rl_response.final_step = self.rl_step
#
# self.active_rl_pub.publish(rl_response)
# self.end_time = time.time()
# self.get_logger().info(f'RL Time: {self.end_time - self.begin_time}, mode: {self.interactive_run}')
# self.begin_time = None
# self.end_time = None
#
# # reset flags and attributes
# self.reset_eval_request()
# self.reset_rl_request()
#
# self.rl_step = 0
# self.rl_reward = 0.0
#
# self.rl_pending = False
def main(args=None):
rclpy.init(args=args)
active_rl_service = ActiveRLService()
active_rl_service = ActiveRL()
rclpy.spin(active_rl_service)

View File

@@ -1,6 +1,6 @@
from active_bo_msgs.msg import ActiveBORequest
from active_bo_msgs.msg import ActiveBOResponse
from active_bo_msgs.msg import ActiveRL
from active_bo_msgs.msg import ActiveRLRequest
from active_bo_msgs.msg import ActiveRLResponse
from active_bo_msgs.msg import ActiveBOState
@@ -54,6 +54,7 @@ class ActiveBOTopic(Node):
self.bo_runs = 0
self.bo_acq_fcn = None
self.bo_metric_parameter = None
self.bo_metric_parameter_2 = None
self.current_run = 0
self.current_episode = 0
self.seed = None
@@ -61,7 +62,7 @@ class ActiveBOTopic(Node):
self.save_result = False
# Active Reinforcement Learning Publisher, Subscriber and Message attributes
self.active_rl_pub = self.create_publisher(ActiveRL,
self.active_rl_pub = self.create_publisher(ActiveRLRequest,
'active_rl_request',
1, callback_group=rl_callback_group)
self.active_rl_sub = self.create_subscription(ActiveRLResponse,
@@ -73,6 +74,7 @@ class ActiveBOTopic(Node):
self.rl_weights = None
self.rl_final_step = None
self.rl_reward = 0.0
self.overwrite_weight = None
# State Publisher
self.state_pub = self.create_publisher(ActiveBOState, 'active_bo_state', 1)
@@ -112,6 +114,7 @@ class ActiveBOTopic(Node):
self.bo_runs = 0
self.bo_acq_fcn = None
self.bo_metric_parameter = None
self.bo_metric_parameter_2 = None
self.current_run = 0
self.current_episode = 0
self.save_result = False
@@ -172,6 +175,7 @@ class ActiveBOTopic(Node):
if self.rl_pending:
# self.get_logger().info('Active Reinforcement Learning response received!')
self.rl_weights = np.array(msg.weights, dtype=np.float64)
self.overwrite_weight = np.array(msg.overwrite_weight, dtype=bool)
self.rl_final_step = msg.final_step
self.rl_reward = msg.reward
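
ROS 2 delivers the message arrays as plain sequences, so the callback converts them to numpy arrays once (float64 for the weights, bool for the mask); that is what makes the boolean-mask indexing further down possible. A sketch, assuming msg is the incoming ActiveRLResponse:

# Sketch: convert incoming arrays once, then index with the boolean mask later
import numpy as np

weights = np.array(msg.weights, dtype=np.float64)        # float64[] weights
mask = np.array(msg.overwrite_weight, dtype=bool)        # bool[] overwrite_weight
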
@@ -193,7 +197,7 @@ class ActiveBOTopic(Node):
self.user_asked = False
self.rl_pending = False
self.reset_rl_response()
# self.reset_rl_response()
def mainloop_callback(self):
if not self.active_bo_pending:
@@ -223,7 +227,7 @@ class ActiveBOTopic(Node):
else:
seed = int(np.random.randint(1, 2147483647, 1)[0])
rl_msg = ActiveRL()
rl_msg = ActiveRLRequest()
rl_msg.env = self.bo_env
rl_msg.seed = seed
rl_msg.display_run = False
@@ -282,7 +286,7 @@ class ActiveBOTopic(Node):
np.savetxt(path, data, delimiter=',')
active_rl_request = ActiveRL()
active_rl_request = ActiveRLRequest()
if self.bo_fixed_seed:
seed = int(self.seed_array[0, best_policy_idx])
@@ -292,10 +296,10 @@ class ActiveBOTopic(Node):
active_rl_request.env = self.bo_env
active_rl_request.seed = seed
active_rl_request.display_run = True
active_rl_request.interactive_run = 1
active_rl_request.policy = self.best_policy[:, best_policy_idx].tolist()
active_rl_request.weights = self.best_weights[:, best_policy_idx].tolist()
active_rl_request.interactive_run = 1
active_rl_request.display_run = True
self.active_rl_pub.publish(active_rl_request)
@@ -335,7 +339,7 @@ class ActiveBOTopic(Node):
if user_query.query():
self.last_query = self.current_episode
self.user_asked = True
active_rl_request = ActiveRL()
active_rl_request = ActiveRLRequest()
old_policy, y_max, old_weights, _ = self.BO.get_best_result()
# self.get_logger().info(f'Best: {y_max}, w:{old_weights}')
@@ -349,9 +353,9 @@ class ActiveBOTopic(Node):
active_rl_request.env = self.bo_env
active_rl_request.seed = seed
active_rl_request.display_run = True
active_rl_request.interactive_run = 0
active_rl_request.policy = old_policy.tolist()
active_rl_request.weights = old_weights.tolist()
active_rl_request.interactive_run = 0
# self.get_logger().info('Calling: Active RL')
self.active_rl_pub.publish(active_rl_request)
@@ -359,6 +363,16 @@ class ActiveBOTopic(Node):
else:
x_next = self.BO.next_observation()
self.get_logger().info(f'x_next: {x_next}')
self.get_logger().info(f'overwrite: {self.overwrite_weight}')
self.get_logger().info(f'rl_weights: {self.rl_weights}')
if self.overwrite:
if self.overwrite_weight is not None and self.rl_weights is not None:
x_next[self.overwrite_weight] = self.rl_weights[self.overwrite_weight]
self.get_logger().info(f'x_next: {x_next}')
self.get_logger().info(f'overwrite: {self.overwrite_weight}')
self.get_logger().info(f'rl_weights: {self.rl_weights}')
# self.get_logger().info('Next Observation BO!')
self.BO.policy_model.weights = np.around(x_next, decimals=8)
if self.bo_fixed_seed:
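
The overwrite above is plain numpy boolean-mask assignment: wherever the mask is True, the BO proposal x_next is replaced by the user-pinned weight from the RL response. A self-contained sketch with illustrative values:

# Sketch of the boolean-mask overwrite applied to the BO proposal
import numpy as np

x_next = np.array([0.4, -0.1, 0.9, 0.2])                 # BO proposal
rl_weights = np.array([0.0, 0.5, 0.0, -0.3])             # weights from the RL response
overwrite_weight = np.array([False, True, False, True])  # user-pinned entries

x_next[overwrite_weight] = rl_weights[overwrite_weight]
# x_next is now [0.4, 0.5, 0.9, -0.3]
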
@@ -366,13 +380,14 @@ class ActiveBOTopic(Node):
else:
seed = int(np.random.randint(1, 2147483647, 1)[0])
rl_msg = ActiveRL()
rl_msg = ActiveRLRequest()
rl_msg.env = self.bo_env
rl_msg.seed = seed
rl_msg.display_run = False
rl_msg.interactive_run = 2
rl_msg.weights = x_next.tolist()
rl_msg.policy = self.BO.policy_model.rollout().reshape(-1,).tolist()
rl_msg.weights = x_next.tolist()
self.rl_pending = True
self.active_rl_pub.publish(rl_msg)
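
Reading the call sites in this file, the uint8 interactive_run field appears to encode three run modes: 0 for the interactive user query, 1 for displaying the best policy after a run, and 2 for a plain rollout of the next BO proposal. A hypothetical enum naming that convention (the names are not part of the codebase):

# Hypothetical names for the uint8 interactive_run modes, inferred from the diff
from enum import IntEnum

class RunMode(IntEnum):
    USER_QUERY = 0     # evaluate the current best policy with the user in the loop
    DISPLAY_BEST = 1   # show the best policy after the BO run finishes
    BO_ROLLOUT = 2     # plain rollout of the next BO proposal
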
@@ -381,15 +396,12 @@ class ActiveBOTopic(Node):
self.reward[self.current_episode, self.current_run] = np.max(self.BO.Y)
self.get_logger().info(f'Current Episode: {self.current_episode},'
f' best reward: {self.reward[self.current_episode, self.current_run]}')
else:
self.best_policy[:, self.current_run], \
self.best_pol_reward[:, self.current_run], \
self.best_weights[:, self.current_run], idx = self.BO.get_best_result()
# self.get_logger().info(f'best idx: {idx}')
# self.reward[:, self.current_run] = self.BO.best_reward.T
if self.current_run < self.bo_runs - 1:
self.BO = None