random metric works first results
This commit is contained in:
parent 55b21d667a
commit 334f64e22d
@@ -1,4 +1,5 @@
string env
uint32 seed
bool final_run
float32[] policy
float32[] weights
@@ -21,7 +21,6 @@ class BayesianOptimization:
        self.episode = 0
        self.counter_array = np.empty((1, 1))
        self.best_reward = np.empty((1, 1))
        self.distance_penalty = 0

        self.nr_policy_weights = nr_weights
        self.nr_steps = nr_steps
@@ -63,8 +62,6 @@ class BayesianOptimization:
                    break

            if not done and i == len(policy):
                distance = -(self.env.goal_position - output[0][0])
                env_reward += distance * self.distance_penalty
            self.counter_array = np.vstack((self.counter_array, step_count))

        self.env.reset(seed=seed)
@@ -157,11 +154,11 @@ class BayesianOptimization:
        self.episode += 1

    def get_best_result(self):
        y_hat = self.GP.predict(self.X)
        idx = np.argmax(y_hat)
        y_max = np.max(self.Y)
        idx = np.argmax(self.Y)
        x_max = self.X[idx, :]

        self.policy_model.weights = x_max
        best_policy = self.policy_model.rollout().reshape(-1,)

        return best_policy, y_hat[idx], x_max
        return best_policy, y_max, x_max
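The get_best_result change replaces the GP-predicted optimum with the best reward actually observed among the queried points. A small self-contained sketch of the two selection rules with invented data (X standing in for the queried weight vectors, Y for the observed rewards):

# Illustration with made-up data; only the selection logic mirrors the diff above.
import numpy as np

X = np.array([[0.1, 0.4], [0.3, 0.2], [0.7, 0.9]])  # queried weight vectors
Y = np.array([1.2, 3.5, 2.8])                       # observed episode rewards

# Old rule: rank the queried points by the surrogate's predicted mean,
# e.g. y_hat = gp.predict(X); idx = np.argmax(y_hat)   (gp = a fitted regressor)

# New rule: take the best observed reward directly.
idx = np.argmax(Y)
y_max = np.max(Y)
x_max = X[idx, :]
print(x_max, y_max)  # -> [0.3 0.2] 3.5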
@@ -189,17 +189,52 @@ class ActiveBOTopic(Node):
        bo_response.reward_std = np.std(self.reward, axis=1).tolist()

        if self.save_result:
            if self.bo_env == "Mountain Car":
                env = 'mc'
            elif self.bo_env == "Cartpole":
                env = 'cp'
            elif self.bo_env == "Acrobot":
                env = 'ab'
            elif self.bo_env == "Pendulum":
                env = 'pd'
            else:
                raise NotImplementedError

            if self.bo_acq_fcn == "Expected Improvement":
                acq = 'ei'
            elif self.bo_acq_fcn == "Probability of Improvement":
                acq = 'pi'
            elif self.bo_acq_fcn == "Upper Confidence Bound":
                acq = 'cb'
            else:
                raise NotImplementedError

            home_dir = os.path.expanduser('~')
            file_path = os.path.join(home_dir, 'Documents/IntRLResults')
            filename = self.bo_metric + '-'\
            filename = env + '-' + acq + '-' + self.bo_metric + '-' \
                       + str(self.bo_metric_parameter) + '-' \
                       + str(self.bo_nr_weights) + '-' + str(time.time())
            filename = filename.replace('.', '_') + '.csv'
            path = os.path.join(file_path, filename)
            self.get_logger().info(path)

            np.savetxt(path, self.reward, delimiter=',')

        active_rl_request = ActiveRL()

        if self.seed is None:
            seed = int(np.random.randint(1, 2147483647, 1)[0])
        else:
            seed = self.seed

        active_rl_request.env = self.bo_env
        active_rl_request.seed = seed
        active_rl_request.policy = self.best_policy[:, best_policy_idx].tolist()
        active_rl_request.weights = self.best_weights[:, best_policy_idx].tolist()
        active_rl_request.final_run = True

        self.get_logger().info('Calling: Active RL')
        self.active_rl_pub.publish(active_rl_request)

        self.get_logger().info('Responding: Active BO')
        self.active_bo_pub.publish(bo_response)
        self.reset_bo_request()
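Plugging hypothetical values into the naming scheme above shows what a saved result file ends up being called (every value here is invented for illustration; only the concatenation and the '.' to '_' replacement mirror the code in the hunk):

# Illustration only: the values are made up, the composition follows the code above.
import os
import time

env, acq, metric, parameter, nr_weights = 'mc', 'ei', 'random', 0.5, 10
filename = (env + '-' + acq + '-' + metric + '-'
            + str(parameter) + '-' + str(nr_weights) + '-' + str(time.time()))
filename = filename.replace('.', '_') + '.csv'
print(os.path.join(os.path.expanduser('~'), 'Documents/IntRLResults', filename))
# e.g. /home/user/Documents/IntRLResults/mc-ei-random-0_5-10-1681379021_37.csv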
@@ -245,6 +280,7 @@ class ActiveBOTopic(Node):
        active_rl_request.seed = seed
        active_rl_request.policy = old_policy.tolist()
        active_rl_request.weights = old_weights.tolist()
        active_rl_request.final_run = False

        self.get_logger().info('Calling: Active RL')
        self.active_rl_pub.publish(active_rl_request)
@@ -282,6 +318,7 @@ class ActiveBOTopic(Node):
        state_msg.last_user_reward = self.rl_reward
        self.state_pub.publish(state_msg)


def main(args=None):
    rclpy.init(args=args)

@@ -69,6 +69,7 @@ class ActiveRLService(Node):
        self.best_pol_shown = False
        self.policy_sent = False
        self.active_rl_pending = False
        self.final_run = False

        # Main loop timer object
        self.mainloop_timer_period = 0.05
@@ -87,6 +88,7 @@ class ActiveRLService(Node):
        self.rl_seed = msg.seed
        self.rl_policy = np.array(msg.policy, dtype=np.float32)
        self.rl_weights = msg.weights
        self.final_run = msg.final_run

        if self.rl_env == "Mountain Car":
            self.env = Continuous_MountainCarEnv(render_mode="rgb_array")
@@ -151,6 +153,7 @@ class ActiveRLService(Node):

    def mainloop_callback(self):
        if self.active_rl_pending:
            if not self.final_run:
                if not self.best_pol_shown:
                    if not self.policy_sent:
                        self.rl_step = 0
@@ -201,6 +204,29 @@ class ActiveRLService(Node):
                    self.best_pol_shown = False
                    self.eval_response_received = False
                    self.active_rl_pending = False
            else:
                if not self.policy_sent:
                    self.rl_step = 0
                    self.rl_reward = 0.0
                    self.env.reset(seed=self.rl_seed)

                    eval_request = ActiveRL()
                    eval_request.policy = self.rl_policy.tolist()
                    eval_request.weights = self.rl_weights

                    self.eval_pub.publish(eval_request)
                    self.get_logger().info('Active RL: Called!')
                    self.get_logger().info('Active RL: Waiting for Eval!')

                    self.policy_sent = True

                done = self.next_image(self.rl_policy)

                if done:
                    self.rl_step = 0
                    self.rl_reward = 0.0
                    self.final_run = False
                    self.active_rl_pending = False


def main(args=None):
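Condensed, the new branch in mainloop_callback reads roughly as below. This is a paraphrase of the two hunks above, not the full method; next_image and the attribute names come from the surrounding code:

# Condensed paraphrase of the final-run control flow added in this commit.
def mainloop_callback(self):
    if not self.active_rl_pending:
        return
    if not self.final_run:
        # existing interactive path: roll out the policy and collect the user's evaluation
        ...
    else:
        # new final-run path: publish the best policy for display, replay it once,
        # and finish as soon as the rollout ends (apparently without waiting for
        # an evaluation response)
        if not self.policy_sent:
            self.rl_step = 0
            self.rl_reward = 0.0
            self.env.reset(seed=self.rl_seed)
            ...  # publish the policy on the eval topic (see hunk above)
            self.policy_sent = True
        done = self.next_image(self.rl_policy)
        if done:
            self.final_run = False
            self.active_rl_pending = False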