bo 2d tested
This commit is contained in:
parent
597579bd98
commit
40132341e5
@ -2,3 +2,5 @@ float64[] best_policy
|
||||
float64[] best_weights
|
||||
float64[] reward_mean
|
||||
float64[] reward_std
|
||||
uint16 nr_weights
|
||||
uint16 nr_steps
|
@ -1,2 +1,4 @@
|
||||
float64[] policy
|
||||
float64[] weights
|
||||
uint16 nr_weights
|
||||
uint16 nr_steps
|
@ -4,4 +4,6 @@ bool display_run
|
||||
uint8 interactive_run
|
||||
float64[] policy
|
||||
float64[] weights
|
||||
uint16 nr_weights
|
||||
uint16 nr_steps
|
||||
uint16 nr_dims
|
@ -31,6 +31,7 @@ class GaussianRBF:
|
||||
def random_weights(self):
|
||||
for dim in range(self.nr_dims):
|
||||
self.weights[:, dim] = self.rng.uniform(self.lowerb, self.upperb, self.nr_weights)
|
||||
return self.weights
|
||||
|
||||
def rollout(self):
|
||||
self.trajectory = np.zeros((self.nr_steps, self.nr_dims))
|
||||
@ -43,7 +44,7 @@ class GaussianRBF:
|
||||
return self.trajectory
|
||||
|
||||
def set_weights(self, x):
|
||||
self.weights = x.reshape(self.nr_weights, self.nr_dims)
|
||||
self.weights = x.reshape((self.nr_weights, self.nr_dims), order='F')
|
||||
|
||||
def get_x(self):
|
||||
return self.weights.reshape(self.nr_weights * self.nr_dims, 1)
|
||||
return self.weights.flatten('F')
|
||||
|
@ -141,8 +141,8 @@ class ActiveBOTopic(Node):
|
||||
self.overwrite = msg.overwrite
|
||||
|
||||
# initialize
|
||||
self.reward = np.zeros((self.bo_runs, self.bo_episodes + self.nr_init - 1))
|
||||
self.best_pol_reward = np.zeros((self.bo_runs, 1))
|
||||
self.reward = np.ones((self.bo_runs, self.bo_episodes + self.nr_init - 1)) * -self.bo_steps
|
||||
self.best_pol_reward = np.ones((self.bo_runs, 1)) * -self.bo_steps
|
||||
self.best_policy = np.zeros((self.bo_runs, self.bo_steps, self.bo_nr_dims))
|
||||
self.best_weights = np.zeros((self.bo_runs, self.bo_nr_weights, self.bo_nr_dims))
|
||||
|
||||
@ -171,7 +171,7 @@ class ActiveBOTopic(Node):
|
||||
self.rl_reward = msg.reward
|
||||
|
||||
try:
|
||||
self.BO.add_new_observation(self.rl_reward, self.rl_weights)
|
||||
self.BO.add_observation(self.rl_reward, self.rl_weights)
|
||||
# self.get_logger().info('Active Reinforcement Learning added new observation!')
|
||||
except Exception as e:
|
||||
self.get_logger().error(f'Active Reinforcement Learning failed to add new observation: {e}')
|
||||
@ -224,8 +224,10 @@ class ActiveBOTopic(Node):
|
||||
rl_msg.seed = seed
|
||||
rl_msg.display_run = False
|
||||
rl_msg.interactive_run = 2
|
||||
rl_msg.weights = self.BO.policy_model.random_policy().flatten().tolist()
|
||||
rl_msg.policy = self.BO.policy_model.rollout().flatten().tolist()
|
||||
rl_msg.weights = self.BO.policy_model.random_weights().flatten('F').tolist()
|
||||
rl_msg.policy = self.BO.policy_model.rollout().flatten('F').tolist()
|
||||
rl_msg.nr_weights = self.bo_nr_weights
|
||||
rl_msg.nr_steps = self.bo_steps
|
||||
rl_msg.nr_dims = self.bo_nr_dims
|
||||
|
||||
self.active_rl_pub.publish(rl_msg)
|
||||
@ -236,8 +238,10 @@ class ActiveBOTopic(Node):
|
||||
bo_response = ActiveBOResponse()
|
||||
|
||||
best_policy_idx = np.argmax(self.best_pol_reward)
|
||||
bo_response.best_policy = self.best_policy[best_policy_idx, :, :].flatten().tolist()
|
||||
bo_response.best_weights = self.best_weights[best_policy_idx, :, :].flatten().tolist()
|
||||
bo_response.best_policy = self.best_policy[best_policy_idx, :, :].flatten('F').tolist()
|
||||
bo_response.best_weights = self.best_weights[best_policy_idx, :, :].flatten('F').tolist()
|
||||
bo_response.nr_weights = self.bo_nr_weights
|
||||
bo_response.nr_steps = self.bo_steps
|
||||
|
||||
bo_response.reward_mean = np.mean(self.reward, axis=1).tolist()
|
||||
bo_response.reward_std = np.std(self.reward, axis=1).tolist()
|
||||
@ -282,8 +286,10 @@ class ActiveBOTopic(Node):
|
||||
active_rl_request.seed = seed
|
||||
active_rl_request.display_run = True
|
||||
active_rl_request.interactive_run = 1
|
||||
active_rl_request.policy = self.best_policy[best_policy_idx, :, :].flatten().tolist()
|
||||
active_rl_request.weights = self.best_weights[best_policy_idx, :, :].flatten().tolist()
|
||||
active_rl_request.policy = self.best_policy[best_policy_idx, :, :].flatten('F').tolist()
|
||||
active_rl_request.weights = self.best_weights[best_policy_idx, :, :].flatten('F').tolist()
|
||||
active_rl_request.nr_weights = self.bo_nr_weights
|
||||
active_rl_request.nr_steps = self.bo_steps
|
||||
active_rl_request.nr_dims = self.bo_nr_dims
|
||||
|
||||
self.active_rl_pub.publish(active_rl_request)
|
||||
@ -325,10 +331,9 @@ class ActiveBOTopic(Node):
|
||||
self.last_query = self.current_episode
|
||||
self.user_asked = True
|
||||
active_rl_request = ActiveRLRequest()
|
||||
old_policy, y_max, old_weights, _ = self.BO.get_best_result()
|
||||
y_max, old_weights, _ = self.BO.get_best_result()
|
||||
|
||||
# self.get_logger().info(f'Best: {y_max}, w:{old_weights}')
|
||||
# self.get_logger().info(f'Size of Y: {self.BO.Y.shape}, Size of X: {self.BO.X.shape}')
|
||||
self.BO.policy_model.set_weights(old_weights)
|
||||
|
||||
if self.bo_fixed_seed:
|
||||
seed = self.seed
|
||||
@ -339,8 +344,10 @@ class ActiveBOTopic(Node):
|
||||
active_rl_request.seed = seed
|
||||
active_rl_request.display_run = True
|
||||
active_rl_request.interactive_run = 0
|
||||
active_rl_request.policy = old_policy.flatten().tolist()
|
||||
active_rl_request.weights = old_weights.flatten().tolist()
|
||||
active_rl_request.policy = self.BO.policy_model.rollout().flatten('F').tolist()
|
||||
active_rl_request.weights = old_weights.flatten('F').tolist()
|
||||
active_rl_request.nr_weights = self.bo_nr_weights
|
||||
active_rl_request.nr_steps = self.bo_steps
|
||||
active_rl_request.nr_dims = self.bo_nr_dims
|
||||
|
||||
# self.get_logger().info('Calling: Active RL')
|
||||
@ -360,8 +367,10 @@ class ActiveBOTopic(Node):
|
||||
rl_msg.seed = seed
|
||||
rl_msg.display_run = False
|
||||
rl_msg.interactive_run = 2
|
||||
rl_msg.policy = self.BO.policy_model.rollout().flatten().tolist()
|
||||
rl_msg.weights = x_next.flatten().tolist()
|
||||
rl_msg.policy = self.BO.policy_model.rollout().flatten('F').tolist()
|
||||
rl_msg.weights = x_next.flatten('F').tolist()
|
||||
rl_msg.nr_weights = self.bo_nr_weights
|
||||
rl_msg.nr_steps = self.bo_steps
|
||||
rl_msg.nr_dims = self.bo_nr_dims
|
||||
|
||||
self.rl_pending = True
|
||||
@ -374,7 +383,6 @@ class ActiveBOTopic(Node):
|
||||
self.current_episode += 1
|
||||
|
||||
else:
|
||||
self.best_policy[self.current_run, :, :], \
|
||||
self.best_pol_reward[self.current_run, :], \
|
||||
self.best_weights[self.current_run, :, :], idx = self.BO.get_best_result()
|
||||
|
||||
|
@ -64,6 +64,8 @@ class ActiveRL(Node):
|
||||
# RL Environments
|
||||
self.env = None
|
||||
self.rl_spec = None
|
||||
self.nr_weights = None
|
||||
self.nr_steps = None
|
||||
self.rl_dims = None
|
||||
self.pol_dims = None
|
||||
|
||||
@ -89,6 +91,8 @@ class ActiveRL(Node):
|
||||
self.interactive_run = 0
|
||||
self.display_run = False
|
||||
self.rl_dims = None
|
||||
self.nr_weights = None
|
||||
self.nr_steps = None
|
||||
self.pol_dims = None
|
||||
|
||||
def rl_callback(self, msg):
|
||||
@ -97,19 +101,12 @@ class ActiveRL(Node):
|
||||
self.display_run = msg.display_run
|
||||
self.rl_dims = msg.nr_dims
|
||||
self.rl_weights = msg.weights
|
||||
pol = msg.policy
|
||||
self.pol_dims = (len(pol)/self.rl_dims, self.rl_dims)
|
||||
self.rl_policy = np.array(pol, dtype=np.float64).reshape(self.pol_dims)
|
||||
self.nr_weights = msg.nr_weights
|
||||
self.nr_steps = msg.nr_steps
|
||||
self.pol_dims = (self.nr_steps, self.rl_dims)
|
||||
self.rl_policy = np.array(msg.policy, dtype=np.float64).reshape(self.pol_dims, order='F')
|
||||
self.interactive_run = msg.interactive_run
|
||||
|
||||
if self.rl_env == "Reacher":
|
||||
random_state = np.random.RandomState(seed=self.rl_seed)
|
||||
self.env = suite.load('reacher', 'hard', task_kwargs={'random': random_state})
|
||||
self.rl_spec = self.env.action_spec()
|
||||
self.env.reset()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
self.rl_pending = True
|
||||
self.policy_sent = False
|
||||
|
||||
@ -118,12 +115,11 @@ class ActiveRL(Node):
|
||||
self.eval_weights = None
|
||||
|
||||
def rl_eval_callback(self, msg):
|
||||
self.eval_policy = np.array(msg.policy, dtype=np.float64).reshape(self.pol_dims)
|
||||
self.eval_policy = np.array(msg.policy, dtype=np.float64).reshape(self.pol_dims, order='F')
|
||||
self.eval_weights = msg.weights
|
||||
self.weight_preference = msg.weight_preference
|
||||
|
||||
self.get_logger().info('Active RL Eval: Responded!')
|
||||
self.env.reset()
|
||||
self.eval_response_received = True
|
||||
|
||||
def step(self, policy, display_run):
|
||||
@ -133,14 +129,16 @@ class ActiveRL(Node):
|
||||
action_clipped = action.clip(min=-1.0, max=1.0)
|
||||
output = self.env.step(action_clipped.astype(np.float64))
|
||||
|
||||
self.rl_step += 1
|
||||
|
||||
if output.reward != 0.0:
|
||||
self.rl_reward += output.reward * 10
|
||||
done = True
|
||||
else:
|
||||
self.rl_step -= 1.0
|
||||
self.rl_reward -= 1.0
|
||||
|
||||
if display_run:
|
||||
rgb_array = self.env.physics.render(camera_id=0, height=400, width=600)
|
||||
rgb_array = self.env.physics.render(camera_id=0, height=320, width=480)
|
||||
rgb_shape = rgb_array.shape
|
||||
|
||||
red = rgb_array[:, :, 0].flatten().tolist()
|
||||
@ -168,8 +166,6 @@ class ActiveRL(Node):
|
||||
step_count = 0
|
||||
done = False
|
||||
|
||||
self.env.reset()
|
||||
|
||||
for i in range(policy.shape[0]):
|
||||
action = policy[i, :]
|
||||
action_clipped = action.clip(min=-1.0, max=1.0)
|
||||
@ -180,7 +176,7 @@ class ActiveRL(Node):
|
||||
self.rl_reward += output.reward * 10
|
||||
done = True
|
||||
else:
|
||||
self.rl_step -= 1.0
|
||||
self.rl_reward -= 1.0
|
||||
|
||||
if done:
|
||||
break
|
||||
@ -193,11 +189,22 @@ class ActiveRL(Node):
|
||||
if not self.policy_sent:
|
||||
self.rl_step = 0
|
||||
self.rl_reward = 0.0
|
||||
if self.rl_env == "Reacher":
|
||||
np.random.seed(self.rl_seed)
|
||||
random_state = np.random.RandomState(seed=self.rl_seed)
|
||||
self.env = suite.load('reacher',
|
||||
'hard',
|
||||
task_kwargs={'random': random_state})
|
||||
self.rl_spec = self.env.action_spec()
|
||||
self.env.reset()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
eval_request = ActiveRLEvalRequest()
|
||||
eval_request.policy = self.rl_policy.flatten().tolist()
|
||||
eval_request.policy = self.rl_policy.flatten('F').tolist()
|
||||
eval_request.weights = self.rl_weights
|
||||
eval_request.nr_steps = self.nr_steps
|
||||
eval_request.nr_weights = self.nr_weights
|
||||
|
||||
self.eval_pub.publish(eval_request)
|
||||
self.get_logger().info('Active RL: Called!')
|
||||
@ -211,6 +218,16 @@ class ActiveRL(Node):
|
||||
self.best_pol_shown = True
|
||||
self.rl_step = 0
|
||||
self.rl_reward = 0.0
|
||||
if self.rl_env == "Reacher":
|
||||
np.random.seed(self.rl_seed)
|
||||
random_state = np.random.RandomState(seed=self.rl_seed)
|
||||
self.env = suite.load('reacher',
|
||||
'hard',
|
||||
task_kwargs={'random': random_state})
|
||||
self.rl_spec = self.env.action_spec()
|
||||
self.env.reset()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
elif self.best_pol_shown:
|
||||
if not self.eval_response_received:
|
||||
@ -228,8 +245,6 @@ class ActiveRL(Node):
|
||||
|
||||
self.active_rl_pub.publish(rl_response)
|
||||
|
||||
self.env.reset()
|
||||
|
||||
# reset flags and attributes
|
||||
self.reset_eval_request()
|
||||
self.reset_rl_request()
|
||||
@ -245,10 +260,18 @@ class ActiveRL(Node):
|
||||
if not self.policy_sent:
|
||||
self.rl_step = 0
|
||||
self.rl_reward = 0.0
|
||||
|
||||
if self.rl_env == "Reacher":
|
||||
np.random.seed(self.rl_seed)
|
||||
random_state = np.random.RandomState(seed=self.rl_seed)
|
||||
self.env = suite.load('reacher', 'hard', task_kwargs={'random': random_state})
|
||||
self.rl_spec = self.env.action_spec()
|
||||
self.env.reset()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
eval_request = ActiveRLEvalRequest()
|
||||
eval_request.policy = self.rl_policy.flatten().tolist()
|
||||
eval_request.policy = self.rl_policy.flatten('F').tolist()
|
||||
eval_request.weights = self.rl_weights
|
||||
|
||||
self.eval_pub.publish(eval_request)
|
||||
@ -265,6 +288,15 @@ class ActiveRL(Node):
|
||||
self.rl_pending = False
|
||||
|
||||
elif self.interactive_run == 2:
|
||||
if self.rl_env == "Reacher":
|
||||
np.random.seed(self.rl_seed)
|
||||
random_state = np.random.RandomState(seed=self.rl_seed)
|
||||
self.env = suite.load('reacher', 'hard', task_kwargs={'random': random_state})
|
||||
self.rl_spec = self.env.action_spec()
|
||||
self.env.reset()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
env_reward, step_count = self.runner(self.rl_policy)
|
||||
|
||||
rl_response = ActiveRLResponse()
|
||||
|
30
src/active_bo_ros/launch/interactive_bo_2d.launch.py
Executable file
30
src/active_bo_ros/launch/interactive_bo_2d.launch.py
Executable file
@ -0,0 +1,30 @@
|
||||
from launch import LaunchDescription
|
||||
from launch_ros.actions import Node
|
||||
from launch.actions import IncludeLaunchDescription
|
||||
from launch.launch_description_sources import PythonLaunchDescriptionSource
|
||||
|
||||
from ament_index_python import get_package_share_directory
|
||||
import os
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
return LaunchDescription([
|
||||
IncludeLaunchDescription(
|
||||
PythonLaunchDescriptionSource(
|
||||
os.path.join(
|
||||
get_package_share_directory('active_bo_ros'),
|
||||
'rosbridge_server.launch.py'
|
||||
)
|
||||
)
|
||||
),
|
||||
Node(
|
||||
package='active_bo_ros',
|
||||
executable='interactive_bo_2d',
|
||||
name='interactive_bo_2d'
|
||||
),
|
||||
Node(
|
||||
package='active_bo_ros',
|
||||
executable='interactive_rl_2d',
|
||||
name='interactive_rl_2d'
|
||||
),
|
||||
])
|
@ -35,7 +35,9 @@ setup(
|
||||
'bo_torch_srv = active_bo_ros.bo_torch_service:main',
|
||||
'active_bo_topic = active_bo_ros.active_bo_topic:main',
|
||||
'active_rl_topic = active_bo_ros.active_rl_topic:main',
|
||||
'interactive_bo = active_bo_ros.interactive_bo:main'
|
||||
'interactive_bo = active_bo_ros.interactive_bo:main',
|
||||
'interactive_bo_2d = active_bo_ros.interactive_bo_2d:main',
|
||||
'interactive_rl_2d = active_bo_ros.interactive_rl_2d:main'
|
||||
],
|
||||
},
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user