prepared for vertical line plot in policy plot

Niko Feith 2023-05-25 17:52:38 +02:00
parent c6a99c6c3b
commit 08eac34c2a
5 changed files with 33 additions and 13 deletions

View File

@@ -1,5 +1,6 @@
 int16 height
 int16 width
+uint16 current_time
 uint16[] red
 uint16[] green
 uint16[] blue
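The new `current_time` field reports which rollout step the transmitted frame belongs to, which is exactly what a vertical marker in the policy plot needs. A minimal sketch of how a subscriber could use it, assuming `policy` is the rolled-out policy array and matplotlib is used for plotting (illustrative only, not the repository's GUI code):

```python
import matplotlib.pyplot as plt
import numpy as np

def plot_policy_with_current_step(policy, current_time):
    """Plot the policy and mark the step reported by ImageFeedback.current_time."""
    fig, ax = plt.subplots()
    ax.plot(np.arange(len(policy)), policy, label="policy")
    # vertical line at the step of the currently shown frame
    ax.axvline(current_time, color="red", linestyle="--", label="current step")
    ax.set_xlabel("step")
    ax.set_ylabel("action")
    ax.legend()
    return fig
```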

View File

@@ -43,11 +43,11 @@ class BayesianOptimization:
         self.episode = 0
         self.best_reward = np.empty((1, 1))

-    def runner(self, policy):
+    def runner(self, policy, seed=None):
         env_reward = 0.0
         step_count = 0

-        self.env.reset()
+        self.env.reset(seed=seed)
         for i in range(len(policy)):
             action = policy[i]
@@ -67,11 +67,11 @@ class BayesianOptimization:
             env_reward += distance * self.distance_penalty
         self.counter_array = np.vstack((self.counter_array, step_count))
-        self.env.reset()
+        self.env.reset(seed=seed)
         return env_reward, step_count

-    def initialize(self):
+    def initialize(self, seed=None):
-        self.env.reset()
+        self.env.reset(seed=seed)
         self.reset_bo()
         self.X = np.zeros((self.nr_init, self.nr_policy_weights))
@@ -124,11 +124,11 @@ class BayesianOptimization:
         return x_next

-    def eval_new_observation(self, x_next):
+    def eval_new_observation(self, x_next, seed=None):
         self.policy_model.weights = x_next
         policy = self.policy_model.rollout()
-        reward, step_count = self.runner(policy)
+        reward, step_count = self.runner(policy, seed=seed)
         self.X = np.vstack((self.X, x_next))
         self.Y = np.vstack((self.Y, reward))
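The optional `seed` argument threads one seed through every `env.reset()` inside a rollout, so repeated evaluations of the same policy start from the same initial state. A usage sketch, assuming `bo` is an already configured `BayesianOptimization` instance and `policy` is a rolled-out policy array (both names are placeholders):

```python
# With a fixed seed, two rollouts of the same policy start identically and,
# for a deterministic environment, return the same reward and step count.
reward_a, steps_a = bo.runner(policy, seed=42)
reward_b, steps_b = bo.runner(policy, seed=42)
assert (reward_a, steps_a) == (reward_b, steps_b)

# seed=None keeps the previous behaviour: each reset samples a fresh initial state.
reward_c, steps_c = bo.runner(policy, seed=None)
```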

View File

@@ -153,6 +153,9 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
         theta_dot = theta_dot + self.tau * thetaacc
         theta = theta + self.tau * theta_dot
+        try:
+            self.state = (x, x_dot[0], theta, theta_dot[0])
+        except:
             self.state = (x, x_dot, theta, theta_dot)

         terminated = bool(
@@ -181,6 +184,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
         if self.render_mode == "human":
             self.render()
         return np.array(self.state, dtype=np.float32), reward, terminated, False, {}

     def reset(
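The added `try`/`except` covers the case where `x_dot` and `theta_dot` come back as one-element NumPy arrays rather than scalars, storing their first element instead. A bare `except:` also swallows unrelated errors, so a narrower alternative is sketched below; this is only a suggestion under the assumption that the values are either floats or size-one arrays, not the committed code:

```python
import numpy as np

def to_state_tuple(x, x_dot, theta, theta_dot):
    # Coerce scalars and one-element arrays alike to plain Python floats.
    return tuple(np.asarray(v).item() for v in (x, x_dot, theta, theta_dot))
```

Inside `step()` this would replace the `try`/`except` with `self.state = to_state_tuple(x, x_dot, theta, theta_dot)`.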

View File

@@ -39,6 +39,7 @@ class ActiveBOTopic(Node):
         self.active_bo_pending = False
         self.bo_env = None
+        self.bo_fixed_seed = False
         self.bo_nr_weights = None
         self.bo_steps = None
         self.bo_episodes = None
@@ -47,6 +48,7 @@ class ActiveBOTopic(Node):
         self.bo_metric_parameter = None
         self.current_run = 0
         self.current_episode = 0
+        self.seed = None

         # Active Reinforcement Learning Publisher, Subscriber and Message attributes
         self.active_rl_pub = self.create_publisher(ActiveRL,
@@ -80,6 +82,7 @@ class ActiveBOTopic(Node):
     def reset_bo_request(self):
         self.bo_env = None
+        self.bo_fixed_seed = False
         self.bo_nr_weights = None
         self.bo_steps = None
         self.bo_episodes = None
@@ -94,6 +97,7 @@ class ActiveBOTopic(Node):
            self.get_logger().info('Active Bayesian Optimization request pending!')
            self.active_bo_pending = True
            self.bo_env = msg.env
+           self.bo_fixed_seed = msg.fixed_seed
            self.bo_nr_weights = msg.nr_weights
            self.bo_steps = msg.max_steps
            self.bo_episodes = msg.nr_episodes
@@ -107,6 +111,12 @@ class ActiveBOTopic(Node):
            self.best_policy = np.zeros((self.bo_steps, self.bo_runs))
            self.best_weights = np.zeros((self.bo_nr_weights, self.bo_runs))

+           # set the seed
+           if self.bo_fixed_seed:
+               self.seed = int(np.random.randint(1, 2147483647, 1)[0])
+           else:
+               self.seed = None

     def reset_rl_response(self):
         self.rl_weights = None
         self.rl_final_step = None
@@ -122,6 +132,7 @@ class ActiveBOTopic(Node):
     def mainloop_callback(self):
         if self.active_bo_pending:
             # set rl environment
             if self.bo_env == "Mountain Car":
                 self.env = Continuous_MountainCarEnv()
@@ -175,6 +186,7 @@ class ActiveBOTopic(Node):
                old_policy, _, old_weights = self.BO.get_best_result()
                active_rl_request.env = self.bo_env
+               active_rl_request.seed = self.seed
                active_rl_request.policy = old_policy.tolist()
                active_rl_request.weights = old_weights.tolist()
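With `fixed_seed` enabled, the node draws one random seed per BO request and forwards it with every ActiveRL request, so all rollouts of that run share their initial conditions; otherwise `seed` stays `None` and every reset samples a fresh initial state. The bound 2147483647 is 2^31 - 1 and `np.random.randint` excludes its upper bound, so the drawn value always fits a signed 32-bit field (assuming the ActiveRL `seed` field is an int32). An equivalent draw with the newer NumPy generator API, shown only as a sketch:

```python
import numpy as np

rng = np.random.default_rng()
# One value in [1, 2**31 - 2], reused for every episode of this BO run.
seed = int(rng.integers(1, 2**31 - 1))
```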

View File

@@ -37,6 +37,7 @@ class ActiveRLService(Node):
         self.active_rl_pending = False
         self.rl_env = None
+        self.rl_seed = None
         self.rl_policy = None
         self.rl_weights = None
         self.rl_reward = 0.0
@@ -75,11 +76,13 @@ class ActiveRLService(Node):
     def reset_rl_request(self):
         self.rl_env = None
+        self.rl_seed = None
         self.rl_policy = None
         self.rl_weights = None

     def active_rl_callback(self, msg):
         self.rl_env = msg.env
+        self.rl_seed = msg.seed
         self.rl_policy = np.array(msg.policy, dtype=np.float32)
         self.rl_weights = msg.weights
@@ -95,7 +98,7 @@ class ActiveRLService(Node):
             raise NotImplementedError

         self.get_logger().info('Active RL: Called!')
-        self.env.reset()
+        self.env.reset(seed=self.rl_seed)
         self.active_rl_pending = True

     def reset_eval_request(self):
@@ -107,7 +110,7 @@ class ActiveRLService(Node):
         self.eval_weights = msg.weights
         self.get_logger().info('Active RL Eval: Responded!')
-        self.env.reset()
+        self.env.reset(seed=self.rl_seed)
         self.eval_response_received = True

     def next_image(self, policy):
@@ -130,6 +133,7 @@ class ActiveRLService(Node):
         feedback_msg.height = rgb_shape[0]
         feedback_msg.width = rgb_shape[1]
+        feedback_msg.current_time = self.rl_step
         feedback_msg.red = red
         feedback_msg.green = green
         feedback_msg.blue = blue
@@ -149,6 +153,7 @@ class ActiveRLService(Node):
                if done:
                    self.rl_step = 0
                    self.rl_reward = 0.0
+                   self.env.reset(seed=self.rl_seed)

                    eval_request = ActiveRL()
                    eval_request.policy = self.rl_policy.tolist()
@@ -158,8 +163,6 @@ class ActiveRLService(Node):
                    self.get_logger().info('Active RL: Called!')
                    self.get_logger().info('Active RL: Waiting for Eval!')

-                   self.env.reset()

                    self.best_pol_shown = True

            elif self.best_pol_shown:
@@ -177,7 +180,7 @@ class ActiveRLService(Node):
                    self.active_rl_pub.publish(rl_response)

-                   self.env.reset()
+                   self.env.reset(seed=self.rl_seed)

                    # reset flags and attributes
                    self.reset_eval_request()