random metric works first results
This commit is contained in:
parent
55b21d667a
commit
334f64e22d
@ -1,4 +1,5 @@
|
|||||||
string env
|
string env
|
||||||
uint32 seed
|
uint32 seed
|
||||||
|
bool final_run
|
||||||
float32[] policy
|
float32[] policy
|
||||||
float32[] weights
|
float32[] weights
|
@ -21,7 +21,6 @@ class BayesianOptimization:
|
|||||||
self.episode = 0
|
self.episode = 0
|
||||||
self.counter_array = np.empty((1, 1))
|
self.counter_array = np.empty((1, 1))
|
||||||
self.best_reward = np.empty((1, 1))
|
self.best_reward = np.empty((1, 1))
|
||||||
self.distance_penalty = 0
|
|
||||||
|
|
||||||
self.nr_policy_weights = nr_weights
|
self.nr_policy_weights = nr_weights
|
||||||
self.nr_steps = nr_steps
|
self.nr_steps = nr_steps
|
||||||
@ -63,8 +62,6 @@ class BayesianOptimization:
|
|||||||
break
|
break
|
||||||
|
|
||||||
if not done and i == len(policy):
|
if not done and i == len(policy):
|
||||||
distance = -(self.env.goal_position - output[0][0])
|
|
||||||
env_reward += distance * self.distance_penalty
|
|
||||||
self.counter_array = np.vstack((self.counter_array, step_count))
|
self.counter_array = np.vstack((self.counter_array, step_count))
|
||||||
|
|
||||||
self.env.reset(seed=seed)
|
self.env.reset(seed=seed)
|
||||||
@ -157,11 +154,11 @@ class BayesianOptimization:
|
|||||||
self.episode += 1
|
self.episode += 1
|
||||||
|
|
||||||
def get_best_result(self):
|
def get_best_result(self):
|
||||||
y_hat = self.GP.predict(self.X)
|
y_max = np.max(self.Y)
|
||||||
idx = np.argmax(y_hat)
|
idx = np.argmax(self.Y)
|
||||||
x_max = self.X[idx, :]
|
x_max = self.X[idx, :]
|
||||||
|
|
||||||
self.policy_model.weights = x_max
|
self.policy_model.weights = x_max
|
||||||
best_policy = self.policy_model.rollout().reshape(-1,)
|
best_policy = self.policy_model.rollout().reshape(-1,)
|
||||||
|
|
||||||
return best_policy, y_hat[idx], x_max
|
return best_policy, y_max, x_max
|
||||||
|
@ -189,17 +189,52 @@ class ActiveBOTopic(Node):
|
|||||||
bo_response.reward_std = np.std(self.reward, axis=1).tolist()
|
bo_response.reward_std = np.std(self.reward, axis=1).tolist()
|
||||||
|
|
||||||
if self.save_result:
|
if self.save_result:
|
||||||
|
if self.bo_env == "Mountain Car":
|
||||||
|
env = 'mc'
|
||||||
|
elif self.bo_env == "Cartpole":
|
||||||
|
env = 'cp'
|
||||||
|
elif self.bo_env == "Acrobot":
|
||||||
|
env = 'ab'
|
||||||
|
elif self.bo_env == "Pendulum":
|
||||||
|
env = 'pd'
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
if self.bo_acq_fcn == "Expected Improvement":
|
||||||
|
acq = 'ei'
|
||||||
|
elif self.bo_acq_fcn == "Probability of Improvement":
|
||||||
|
acq = 'pi'
|
||||||
|
elif self.bo_acq_fcn == "Upper Confidence Bound":
|
||||||
|
acq = 'cb'
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
home_dir = os.path.expanduser('~')
|
home_dir = os.path.expanduser('~')
|
||||||
file_path = os.path.join(home_dir, 'Documents/IntRLResults')
|
file_path = os.path.join(home_dir, 'Documents/IntRLResults')
|
||||||
filename = self.bo_metric + '-'\
|
filename = env + '-' + acq + '-' + self.bo_metric + '-' \
|
||||||
+ str(self.bo_metric_parameter) + '-' \
|
+ str(self.bo_metric_parameter) + '-' \
|
||||||
+ str(self.bo_nr_weights) + '-' + str(time.time())
|
+ str(self.bo_nr_weights) + '-' + str(time.time())
|
||||||
filename = filename.replace('.', '_') + '.csv'
|
filename = filename.replace('.', '_') + '.csv'
|
||||||
path = os.path.join(file_path, filename)
|
path = os.path.join(file_path, filename)
|
||||||
self.get_logger().info(path)
|
|
||||||
|
|
||||||
np.savetxt(path, self.reward, delimiter=',')
|
np.savetxt(path, self.reward, delimiter=',')
|
||||||
|
|
||||||
|
active_rl_request = ActiveRL()
|
||||||
|
|
||||||
|
if self.seed is None:
|
||||||
|
seed = int(np.random.randint(1, 2147483647, 1)[0])
|
||||||
|
else:
|
||||||
|
seed = self.seed
|
||||||
|
|
||||||
|
active_rl_request.env = self.bo_env
|
||||||
|
active_rl_request.seed = seed
|
||||||
|
active_rl_request.policy = self.best_policy[:, best_policy_idx].tolist()
|
||||||
|
active_rl_request.weights = self.best_weights[:, best_policy_idx].tolist()
|
||||||
|
active_rl_request.final_run = True
|
||||||
|
|
||||||
|
self.get_logger().info('Calling: Active RL')
|
||||||
|
self.active_rl_pub.publish(active_rl_request)
|
||||||
|
|
||||||
self.get_logger().info('Responding: Active BO')
|
self.get_logger().info('Responding: Active BO')
|
||||||
self.active_bo_pub.publish(bo_response)
|
self.active_bo_pub.publish(bo_response)
|
||||||
self.reset_bo_request()
|
self.reset_bo_request()
|
||||||
@ -245,6 +280,7 @@ class ActiveBOTopic(Node):
|
|||||||
active_rl_request.seed = seed
|
active_rl_request.seed = seed
|
||||||
active_rl_request.policy = old_policy.tolist()
|
active_rl_request.policy = old_policy.tolist()
|
||||||
active_rl_request.weights = old_weights.tolist()
|
active_rl_request.weights = old_weights.tolist()
|
||||||
|
active_rl_request.final_run = False
|
||||||
|
|
||||||
self.get_logger().info('Calling: Active RL')
|
self.get_logger().info('Calling: Active RL')
|
||||||
self.active_rl_pub.publish(active_rl_request)
|
self.active_rl_pub.publish(active_rl_request)
|
||||||
@ -282,6 +318,7 @@ class ActiveBOTopic(Node):
|
|||||||
state_msg.last_user_reward = self.rl_reward
|
state_msg.last_user_reward = self.rl_reward
|
||||||
self.state_pub.publish(state_msg)
|
self.state_pub.publish(state_msg)
|
||||||
|
|
||||||
|
|
||||||
def main(args=None):
|
def main(args=None):
|
||||||
rclpy.init(args=args)
|
rclpy.init(args=args)
|
||||||
|
|
||||||
|
@ -69,6 +69,7 @@ class ActiveRLService(Node):
|
|||||||
self.best_pol_shown = False
|
self.best_pol_shown = False
|
||||||
self.policy_sent = False
|
self.policy_sent = False
|
||||||
self.active_rl_pending = False
|
self.active_rl_pending = False
|
||||||
|
self.final_run = False
|
||||||
|
|
||||||
# Main loop timer object
|
# Main loop timer object
|
||||||
self.mainloop_timer_period = 0.05
|
self.mainloop_timer_period = 0.05
|
||||||
@ -87,6 +88,7 @@ class ActiveRLService(Node):
|
|||||||
self.rl_seed = msg.seed
|
self.rl_seed = msg.seed
|
||||||
self.rl_policy = np.array(msg.policy, dtype=np.float32)
|
self.rl_policy = np.array(msg.policy, dtype=np.float32)
|
||||||
self.rl_weights = msg.weights
|
self.rl_weights = msg.weights
|
||||||
|
self.final_run = msg.final_run
|
||||||
|
|
||||||
if self.rl_env == "Mountain Car":
|
if self.rl_env == "Mountain Car":
|
||||||
self.env = Continuous_MountainCarEnv(render_mode="rgb_array")
|
self.env = Continuous_MountainCarEnv(render_mode="rgb_array")
|
||||||
@ -151,6 +153,7 @@ class ActiveRLService(Node):
|
|||||||
|
|
||||||
def mainloop_callback(self):
|
def mainloop_callback(self):
|
||||||
if self.active_rl_pending:
|
if self.active_rl_pending:
|
||||||
|
if not self.final_run:
|
||||||
if not self.best_pol_shown:
|
if not self.best_pol_shown:
|
||||||
if not self.policy_sent:
|
if not self.policy_sent:
|
||||||
self.rl_step = 0
|
self.rl_step = 0
|
||||||
@ -201,6 +204,29 @@ class ActiveRLService(Node):
|
|||||||
self.best_pol_shown = False
|
self.best_pol_shown = False
|
||||||
self.eval_response_received = False
|
self.eval_response_received = False
|
||||||
self.active_rl_pending = False
|
self.active_rl_pending = False
|
||||||
|
else:
|
||||||
|
if not self.policy_sent:
|
||||||
|
self.rl_step = 0
|
||||||
|
self.rl_reward = 0.0
|
||||||
|
self.env.reset(seed=self.rl_seed)
|
||||||
|
|
||||||
|
eval_request = ActiveRL()
|
||||||
|
eval_request.policy = self.rl_policy.tolist()
|
||||||
|
eval_request.weights = self.rl_weights
|
||||||
|
|
||||||
|
self.eval_pub.publish(eval_request)
|
||||||
|
self.get_logger().info('Active RL: Called!')
|
||||||
|
self.get_logger().info('Active RL: Waiting for Eval!')
|
||||||
|
|
||||||
|
self.policy_sent = True
|
||||||
|
|
||||||
|
done = self.next_image(self.rl_policy)
|
||||||
|
|
||||||
|
if done:
|
||||||
|
self.rl_step = 0
|
||||||
|
self.rl_reward = 0.0
|
||||||
|
self.final_run = False
|
||||||
|
self.active_rl_pending = False
|
||||||
|
|
||||||
|
|
||||||
def main(args=None):
|
def main(args=None):
|
||||||
|
Loading…
Reference in New Issue
Block a user