prepared for vertical line plot in policy plot
parent c6a99c6c3b
commit 08eac34c2a
@@ -1,5 +1,6 @@
 int16 height
 int16 width
+uint16 current_time
 uint16[] red
 uint16[] green
 uint16[] blue
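
The new current_time field carries the current simulation step next to the image data; this is what the commit title refers to, since it lets the policy plot in the GUI draw a vertical line at the step currently being shown. The GUI side is not part of this diff; below is a minimal sketch, with placeholder data, of how such a marker could be drawn, assuming matplotlib is used for the policy plot.

```python
import matplotlib.pyplot as plt
import numpy as np

# placeholder policy trajectory and the step reported in the feedback message
policy = np.sin(np.linspace(0, 2 * np.pi, 200))
current_time = 75                                # would come from feedback_msg.current_time

fig, ax = plt.subplots()
ax.plot(policy, label="policy")
ax.axvline(current_time, color="red", linestyle="--", label="current step")
ax.set_xlabel("step")
ax.legend()
plt.show()
```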
@@ -43,11 +43,11 @@ class BayesianOptimization:
         self.episode = 0
         self.best_reward = np.empty((1, 1))

-    def runner(self, policy):
+    def runner(self, policy, seed=None):
         env_reward = 0.0
         step_count = 0

-        self.env.reset()
+        self.env.reset(seed=seed)

         for i in range(len(policy)):
             action = policy[i]
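
Passing seed through runner() follows the Gym/Gymnasium reset API: env.reset(seed=...) re-seeds the environment RNG so every rollout starts from the same initial state, while seed=None leaves the RNG untouched. A minimal sketch of that behaviour with a stock Gymnasium CartPole (not the project's own CartPoleEnv):

```python
import gymnasium as gym

env = gym.make("CartPole-v1")

obs_a, _ = env.reset(seed=42)
obs_b, _ = env.reset(seed=42)
assert (obs_a == obs_b).all()      # same seed -> identical initial observation

obs_c, _ = env.reset()             # seed=None keeps the current RNG state instead
```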
@@ -67,11 +67,11 @@ class BayesianOptimization:
             env_reward += distance * self.distance_penalty
             self.counter_array = np.vstack((self.counter_array, step_count))

-        self.env.reset()
+        self.env.reset(seed=seed)
         return env_reward, step_count

-    def initialize(self):
-        self.env.reset()
+    def initialize(self, seed=None):
+        self.env.reset(seed=seed)
         self.reset_bo()

         self.X = np.zeros((self.nr_init, self.nr_policy_weights))
@@ -124,11 +124,11 @@ class BayesianOptimization:

         return x_next

-    def eval_new_observation(self, x_next):
+    def eval_new_observation(self, x_next, seed=None):
         self.policy_model.weights = x_next
         policy = self.policy_model.rollout()

-        reward, step_count = self.runner(policy)
+        reward, step_count = self.runner(policy, seed=seed)

         self.X = np.vstack((self.X, x_next))
         self.Y = np.vstack((self.Y, reward))
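
With initialize(), runner() and eval_new_observation() all accepting seed=None, a caller can evaluate every candidate policy of one BO run against the same initial state. A hedged sketch of such a driver loop follows; the acquisition-step name next_observation is an assumption, only get_best_result() is confirmed elsewhere in this diff.

```python
def run_bo(bo, nr_episodes, seed=42):
    """Drive one BO run against a single fixed seed (sketch)."""
    bo.initialize(seed=seed)                    # seeds env.reset() and resets the BO state
    for _ in range(nr_episodes):
        x_next = bo.next_observation()          # assumed name of the step that returns x_next
        bo.eval_new_observation(x_next, seed=seed)
    return bo.get_best_result()                 # unpacked later as old_policy, _, old_weights
```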
@@ -153,7 +153,10 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
         theta_dot = theta_dot + self.tau * thetaacc
         theta = theta + self.tau * theta_dot

-        self.state = (x, x_dot, theta, theta_dot)
+        try:
+            self.state = (x, x_dot[0], theta, theta_dot[0])
+        except:
+            self.state = (x, x_dot, theta, theta_dot)

         terminated = bool(
             x < -self.x_threshold
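
The try/except presumably handles the case where x_dot and theta_dot come back as one-element NumPy arrays instead of scalars, indexing them down to floats and falling back to the original assignment otherwise. A bare except also swallows unrelated errors; a roughly equivalent, explicit coercion could look like the sketch below (not part of the commit).

```python
import numpy as np

def as_scalar(value):
    """Return a plain float whether value is a scalar or a one-element array."""
    return float(np.asarray(value).reshape(-1)[0])

# e.g. inside step(), instead of the try/except:
x, theta = 0.01, 0.03
x_dot, theta_dot = np.array([0.02]), np.array([0.04])   # array-valued derivatives
state = (x, as_scalar(x_dot), theta, as_scalar(theta_dot))
assert all(isinstance(v, float) for v in state)
```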
@@ -181,6 +184,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):

         if self.render_mode == "human":
             self.render()

         return np.array(self.state, dtype=np.float32), reward, terminated, False, {}

     def reset(
@@ -39,6 +39,7 @@ class ActiveBOTopic(Node):

         self.active_bo_pending = False
         self.bo_env = None
+        self.bo_fixed_seed = False
         self.bo_nr_weights = None
         self.bo_steps = None
         self.bo_episodes = None
@@ -47,6 +48,7 @@ class ActiveBOTopic(Node):
         self.bo_metric_parameter = None
         self.current_run = 0
         self.current_episode = 0
+        self.seed = None

         # Active Reinforcement Learning Publisher, Subscriber and Message attributes
         self.active_rl_pub = self.create_publisher(ActiveRL,
@@ -80,6 +82,7 @@ class ActiveBOTopic(Node):

     def reset_bo_request(self):
         self.bo_env = None
+        self.bo_fixed_seed = False
         self.bo_nr_weights = None
         self.bo_steps = None
         self.bo_episodes = None
@@ -94,6 +97,7 @@ class ActiveBOTopic(Node):
         self.get_logger().info('Active Bayesian Optimization request pending!')
         self.active_bo_pending = True
         self.bo_env = msg.env
+        self.bo_fixed_seed = msg.fixed_seed
         self.bo_nr_weights = msg.nr_weights
         self.bo_steps = msg.max_steps
         self.bo_episodes = msg.nr_episodes
@@ -107,6 +111,12 @@ class ActiveBOTopic(Node):
         self.best_policy = np.zeros((self.bo_steps, self.bo_runs))
         self.best_weights = np.zeros((self.bo_nr_weights, self.bo_runs))
+
+        # set the seed
+        if self.bo_fixed_seed:
+            self.seed = int(np.random.randint(1, 2147483647, 1)[0])
+        else:
+            self.seed = None

     def reset_rl_response(self):
         self.rl_weights = None
         self.rl_final_step = None
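
When a fixed seed is requested, the node draws one random seed per BO run and reuses it for every rollout. np.random.randint returns a NumPy integer, and the int(...) cast presumably keeps the value assignable to the integer seed field of the ROS message later on. A standalone sketch of the same selection logic:

```python
import numpy as np

def draw_seed(fixed_seed: bool):
    """Mirror of the node's seed selection: one random 31-bit seed, or None."""
    if fixed_seed:
        # randint returns a NumPy integer; cast so it can be assigned to a ROS message field
        return int(np.random.randint(1, 2147483647, 1)[0])
    return None

seed = draw_seed(True)
assert seed is None or 1 <= seed < 2147483647
```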
@@ -122,6 +132,7 @@ class ActiveBOTopic(Node):

     def mainloop_callback(self):
         if self.active_bo_pending:

             # set rl environment
             if self.bo_env == "Mountain Car":
                 self.env = Continuous_MountainCarEnv()
@@ -175,6 +186,7 @@ class ActiveBOTopic(Node):
             old_policy, _, old_weights = self.BO.get_best_result()

             active_rl_request.env = self.bo_env
+            active_rl_request.seed = self.seed
             active_rl_request.policy = old_policy.tolist()
             active_rl_request.weights = old_weights.tolist()
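
The drawn seed is forwarded to the RL node inside the ActiveRL request, so the rollout shown to the user starts from the same initial state the optimizer used. A hedged sketch of assembling that request; the message import path is an assumption, the field names are taken from this diff, and the sketch assumes a fixed seed is in use so self.seed is an int rather than None.

```python
import numpy as np
from active_bo_msgs.msg import ActiveRL   # assumed package path for the ActiveRL message

def build_active_rl_request(env_name, seed, policy, weights):
    """Assemble the request the BO node publishes to the RL node (sketch)."""
    msg = ActiveRL()
    msg.env = env_name
    msg.seed = seed                              # int drawn once per run (see seed selection above)
    msg.policy = np.asarray(policy).tolist()     # message fields expect plain lists
    msg.weights = np.asarray(weights).tolist()
    return msg
```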
@@ -37,6 +37,7 @@ class ActiveRLService(Node):

         self.active_rl_pending = False
         self.rl_env = None
+        self.rl_seed = None
         self.rl_policy = None
         self.rl_weights = None
         self.rl_reward = 0.0
@@ -75,11 +76,13 @@ class ActiveRLService(Node):

     def reset_rl_request(self):
         self.rl_env = None
+        self.rl_seed = None
         self.rl_policy = None
         self.rl_weights = None

     def active_rl_callback(self, msg):
         self.rl_env = msg.env
+        self.rl_seed = msg.seed
         self.rl_policy = np.array(msg.policy, dtype=np.float32)
         self.rl_weights = msg.weights
@@ -95,7 +98,7 @@ class ActiveRLService(Node):
             raise NotImplementedError

         self.get_logger().info('Active RL: Called!')
-        self.env.reset()
+        self.env.reset(seed=self.rl_seed)
         self.active_rl_pending = True

     def reset_eval_request(self):
@@ -107,7 +110,7 @@ class ActiveRLService(Node):
         self.eval_weights = msg.weights

         self.get_logger().info('Active RL Eval: Responded!')
-        self.env.reset()
+        self.env.reset(seed=self.rl_seed)
         self.eval_response_received = True

     def next_image(self, policy):
@@ -130,6 +133,7 @@ class ActiveRLService(Node):

         feedback_msg.height = rgb_shape[0]
         feedback_msg.width = rgb_shape[1]
+        feedback_msg.current_time = self.rl_step
         feedback_msg.red = red
         feedback_msg.green = green
         feedback_msg.blue = blue
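
The feedback message ships the rendered frame as flat per-channel arrays plus the new current_time step index. How red, green and blue are produced is outside this hunk; a hedged sketch of packing an (H, W, 3) render into those fields:

```python
import numpy as np

def pack_frame(feedback_msg, rgb_array, step):
    """Fill the feedback message from an (H, W, 3) render of the environment (sketch)."""
    rgb = np.asarray(rgb_array)
    feedback_msg.height = int(rgb.shape[0])
    feedback_msg.width = int(rgb.shape[1])
    feedback_msg.current_time = int(step)     # new field: step index for the vertical line
    feedback_msg.red = rgb[:, :, 0].astype(np.uint16).flatten().tolist()
    feedback_msg.green = rgb[:, :, 1].astype(np.uint16).flatten().tolist()
    feedback_msg.blue = rgb[:, :, 2].astype(np.uint16).flatten().tolist()
    return feedback_msg
```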
@@ -149,6 +153,7 @@ class ActiveRLService(Node):
                 if done:
                     self.rl_step = 0
                     self.rl_reward = 0.0
+                    self.env.reset(seed=self.rl_seed)

                     eval_request = ActiveRL()
                     eval_request.policy = self.rl_policy.tolist()
@@ -158,8 +163,6 @@ class ActiveRLService(Node):
                     self.get_logger().info('Active RL: Called!')
                     self.get_logger().info('Active RL: Waiting for Eval!')

-                    self.env.reset()
-
                     self.best_pol_shown = True

             elif self.best_pol_shown:
@@ -177,7 +180,7 @@ class ActiveRLService(Node):

                 self.active_rl_pub.publish(rl_response)

-                self.env.reset()
+                self.env.reset(seed=self.rl_seed)

                 # reset flags and attributes
                 self.reset_eval_request()