changed back to gaussian basis from dmp
commit 54c196715e
parent 1dbad3fa0d
@@ -4,7 +4,6 @@ from active_bo_msgs.msg import ActiveBOResponse
 from active_bo_msgs.msg import ActiveRLRequest
 from active_bo_msgs.msg import ActiveRLResponse
 from active_bo_msgs.msg import ActiveBOState
-from active_bo_msgs.msg import DMP
 
 import rclpy
 from rclpy.node import Node
@@ -60,7 +59,7 @@ class ActiveBOTopic(Node):
         self.save_result = False
 
         # Active Reinforcement Learning Publisher, Subscriber and Message attributes
-        self.active_rl_pub = self.create_publisher(DMP,
+        self.active_rl_pub = self.create_publisher(ActiveRLRequest,
                                                    'active_rl_request',
                                                    1, callback_group=rl_callback_group)
         self.active_rl_sub = self.create_subscription(ActiveRLResponse,
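Note on the swap above: the publisher's message type changes from DMP back to ActiveRLRequest. Neither .msg definition is part of this diff, so the following plain-Python stand-ins are hypothetical, with field names inferred from the assignments in the hunks further down and all types and defaults assumed:

from dataclasses import dataclass, field
from typing import List

# Hypothetical stand-ins, NOT the real active_bo_msgs definitions;
# field names inferred from this diff, types and defaults assumed.

@dataclass
class DmpStandIn:                 # old message, dropped by this commit
    interactive_run: int = 0
    p_x: List[float] = field(default_factory=list)  # weight column for the x dimension (w[:, 0])
    p_y: List[float] = field(default_factory=list)  # weight column for the y dimension (w[:, 1])
    nr_bfs: int = 0               # number of basis functions
    nr_steps: int = 0
    nr_dims: int = 0

@dataclass
class ActiveRLRequestStandIn:     # new message, used from this commit on
    interactive_run: int = 0
    weights: List[float] = field(default_factory=list)  # (nr_weights, nr_dims), column-major flattened
    policy: List[float] = field(default_factory=list)   # (nr_steps, nr_dims), column-major flattened
    nr_weights: int = 0
    nr_steps: int = 0
    nr_dims: int = 0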
@@ -143,7 +142,7 @@ class ActiveBOTopic(Node):
         self.overwrite = msg.overwrite
 
         # initialize
-        self.reward = np.ones((self.bo_runs, self.bo_episodes + self.nr_init - 1)) * -self.bo_steps
+        self.reward = np.ones((self.bo_runs, self.bo_episodes + self.nr_init - 1)) * -200
         self.best_pol_reward = np.ones((self.bo_runs, 1)) * -self.bo_steps
         self.best_policy = np.zeros((self.bo_runs, self.bo_steps, self.bo_nr_dims))
         self.best_weights = np.zeros((self.bo_runs, self.bo_nr_weights, self.bo_nr_dims))
@@ -204,25 +203,18 @@ class ActiveBOTopic(Node):
 
             self.BO.reset_bo()
 
             # self.BO.initialize()
             self.init_pending = True
             self.get_logger().info('BO Initialization is starting!')
             # self.get_logger().info(f'{self.rl_pending}')
 
         if self.init_pending:
             if self.bo_fixed_seed:
                 seed = self.seed
             else:
                 seed = int(np.random.randint(1, 2147483647, 1)[0])
 
-            w = self.BO.policy_model.random_weights()
-
-            rl_msg = DMP()
+            rl_msg = ActiveRLRequest()
             rl_msg.interactive_run = 2
-            rl_msg.p_x = w[:, 0]
-            rl_msg.p_y = w[:, 1]
-            rl_msg.nr_bfs = self.bo_nr_weights
+            rl_msg.weights = self.BO.policy_model.random_weights().flatten('F').tolist()
+            rl_msg.policy = self.BO.policy_model.rollout().flatten('F').tolist()
+            rl_msg.nr_weights = self.bo_nr_weights
             rl_msg.nr_steps = self.bo_steps
             rl_msg.nr_dims = self.bo_nr_dims
 
             self.active_rl_pub.publish(rl_msg)
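Both arrays above are flattened column-major ('F') and shipped as flat lists, with nr_weights, nr_steps and nr_dims traveling alongside so the subscriber can restore the 2-D shapes. A minimal round-trip sketch, with shapes assumed from this diff:

import numpy as np

nr_steps, nr_dims = 5, 2

# Sender side, as in rl_msg.policy = ...rollout().flatten('F').tolist()
policy = np.arange(nr_steps * nr_dims, dtype=float).reshape(nr_steps, nr_dims)
flat = policy.flatten('F').tolist()  # column-major: all x values first, then all y values

# Receiver side: rebuild with the transmitted shape and the same order
restored = np.array(flat).reshape((nr_steps, nr_dims), order='F')
assert np.array_equal(policy, restored)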
@@ -271,16 +263,13 @@ class ActiveBOTopic(Node):
 
             np.savetxt(path, data, delimiter=',')
 
-            w = self.best_weights[best_policy_idx, :, :]
-
-            rl_msg = DMP()
+            rl_msg = ActiveRLRequest()
             rl_msg.interactive_run = 1
-            rl_msg.p_x = w[:, 0]
-            rl_msg.p_y = w[:, 1]
-            rl_msg.nr_bfs = self.bo_nr_weights
+            rl_msg.weights = self.BO.policy_model.random_weights().flatten('F').tolist()
+            rl_msg.policy = self.BO.policy_model.rollout().flatten('F').tolist()
+            rl_msg.nr_weights = self.bo_nr_weights
             rl_msg.nr_steps = self.bo_steps
             rl_msg.nr_dims = self.bo_nr_dims
 
             self.active_rl_pub.publish(rl_msg)
@@ -328,14 +317,13 @@ class ActiveBOTopic(Node):
 
             self.BO.policy_model.set_weights(old_weights)
 
-            w = self.BO.policy_model.weights
-
-            rl_msg = DMP()
+            rl_msg = ActiveRLRequest()
             rl_msg.interactive_run = 0
-            rl_msg.p_x = w[:, 0]
-            rl_msg.p_y = w[:, 1]
-            rl_msg.nr_bfs = self.bo_nr_weights
+            rl_msg.weights = self.BO.policy_model.random_weights().flatten('F').tolist()
+            rl_msg.policy = self.BO.policy_model.rollout().flatten('F').tolist()
+            rl_msg.nr_weights = self.bo_nr_weights
             rl_msg.nr_steps = self.bo_steps
             rl_msg.nr_dims = self.bo_nr_dims
 
             self.active_rl_pub.publish(rl_msg)
@@ -348,14 +336,13 @@ class ActiveBOTopic(Node):
             x_next = self.BO.next_observation()
             self.BO.policy_model.set_weights(np.around(x_next, decimals=8))
 
-            w = self.BO.policy_model.weights
-
-            rl_msg = DMP()
+            rl_msg = ActiveRLRequest()
             rl_msg.interactive_run = 1
-            rl_msg.p_x = w[:, 0]
-            rl_msg.p_y = w[:, 1]
-            rl_msg.nr_bfs = self.bo_nr_weights
+            rl_msg.weights = self.BO.policy_model.random_weights().flatten('F').tolist()
+            rl_msg.policy = self.BO.policy_model.rollout().flatten('F').tolist()
+            rl_msg.nr_weights = self.bo_nr_weights
             rl_msg.nr_steps = self.bo_steps
             rl_msg.nr_dims = self.bo_nr_dims
 
             self.active_rl_pub.publish(rl_msg)
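Per the commit title, the node is back on a Gaussian basis policy. The policy model itself is outside this diff; purely as a sketch, a Gaussian (RBF) policy exposing the three calls the node makes (random_weights, set_weights, rollout) could look like the following, where centers, widths and the sampling range are assumptions:

import numpy as np

class GaussianPolicySketch:
    # Illustrative Gaussian-basis policy: one weight per basis function per dimension.

    def __init__(self, nr_weights, nr_steps, nr_dims, seed=None):
        self.nr_weights = nr_weights
        self.nr_steps = nr_steps
        self.nr_dims = nr_dims
        self.rng = np.random.default_rng(seed)
        # Evenly spaced Gaussian centers over normalized time, shared width (assumed)
        self.centers = np.linspace(0.0, 1.0, nr_weights)
        self.width = 1.0 / nr_weights
        self.weights = np.zeros((nr_weights, nr_dims))

    def random_weights(self):
        # Sampling range is an assumption; returns the (nr_weights, nr_dims) array
        self.weights = self.rng.uniform(-1.0, 1.0, (self.nr_weights, self.nr_dims))
        return self.weights

    def set_weights(self, w):
        self.weights = np.asarray(w).reshape(self.nr_weights, self.nr_dims)

    def rollout(self):
        # Trajectory value at each step = weighted sum of Gaussian basis functions
        t = np.linspace(0.0, 1.0, self.nr_steps)[:, None]            # (nr_steps, 1)
        phi = np.exp(-0.5 * ((t - self.centers) / self.width) ** 2)  # (nr_steps, nr_weights)
        return phi @ self.weights                                    # (nr_steps, nr_dims)

With such a model, rollout().flatten('F').tolist() yields exactly the flat layout the message fields above carry.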