Commit 26a64db6d6 ("debugging")
Parent: 3e278fb0fc
@@ -7,6 +7,11 @@ from active_bo_ros.AcquisitionFunctions.ExpectedImprovement import ExpectedImprovement
 from active_bo_ros.AcquisitionFunctions.ProbabilityOfImprovement import ProbabilityOfImprovement
 from active_bo_ros.AcquisitionFunctions.ConfidenceBound import ConfidenceBound
 
+from sklearn.exceptions import ConvergenceWarning
+import warnings
+
+warnings.filterwarnings('ignore', category=ConvergenceWarning)
+
 
 class BayesianOptimization:
     def __init__(self, env, nr_steps, nr_init=3, acq='ei', nr_weights=6, policy_seed=None):
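Note on the hunk above: the filter mutes sklearn's ConvergenceWarning process-wide, the warning scikit-learn's Gaussian process regressor typically emits when its kernel hyperparameter optimizer stops early. A minimal self-contained sketch of the same pattern (the toy data and regressor call are illustrative, not taken from this repository):

    import warnings

    import numpy as np
    from sklearn.exceptions import ConvergenceWarning
    from sklearn.gaussian_process import GaussianProcessRegressor

    # Silence ConvergenceWarning before any estimator is fitted.
    warnings.filterwarnings('ignore', category=ConvergenceWarning)

    # Tiny illustrative dataset, not from this repo.
    X = np.linspace(0.0, 1.0, 5).reshape(-1, 1)
    y = np.sin(X).ravel()

    gp = GaussianProcessRegressor().fit(X, y)
    print(gp.predict(X))  # fits and predicts without convergence noise

Because the filter is module-level, it applies to every GP fit the optimizer performs afterwards; a narrower alternative would be a warnings.catch_warnings() context around the fit call only.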
@@ -208,6 +208,7 @@ class ActiveRLService(Node):
             self.best_pol_shown = False
             self.eval_response_received = False
             self.rl_pending = False
+
         elif self.interactive_run == 1:
             if not self.policy_sent:
                 self.rl_step = 0
@@ -129,7 +129,7 @@ class ActiveBOTopic(Node):
         self.seed_array = np.zeros((1, self.bo_runs))
 
         # initialize
-        self.reward = np.zeros((self.bo_episodes, self.bo_runs))
+        self.reward = np.zeros((self.bo_episodes+self.nr_init, self.bo_runs))
         self.best_pol_reward = np.zeros((1, self.bo_runs))
         self.best_policy = np.zeros((self.bo_steps, self.bo_runs))
         self.best_weights = np.zeros((self.bo_nr_weights, self.bo_runs))
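The resized buffer makes room for the nr_init initialization rollouts in addition to the bo_episodes optimization episodes, which matches the later full-column write self.reward[:, self.current_run] = self.BO.best_reward.T in the last hunk. A quick shape check under assumed values (the numbers are illustrative):

    import numpy as np

    bo_episodes, nr_init, bo_runs = 100, 3, 10  # illustrative values
    reward = np.zeros((bo_episodes + nr_init, bo_runs))
    print(reward.shape)  # (103, 10): 3 init rollouts + 100 BO episodes per run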
@@ -194,6 +194,7 @@ class ActiveBOTopic(Node):
             # self.BO.initialize()
             self.init_pending = True
             self.get_logger().info('BO Initialization is starting!')
+            self.get_logger().info(f'{self.rl_pending}')
 
         if self.init_pending and not self.rl_pending:
 
@@ -275,6 +276,7 @@ class ActiveBOTopic(Node):
             active_rl_request.policy = self.best_policy[:, best_policy_idx].tolist()
             active_rl_request.weights = self.best_weights[:, best_policy_idx].tolist()
             active_rl_request.interactive_run = 1
+            active_rl_request.display_run = True
 
             self.active_rl_pub.publish(active_rl_request)
 
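The new display_run field only works if the corresponding .msg definition declares it (e.g. bool display_run); the message package and type name are not visible in this diff. A hedged sketch of the publishing side under that assumption (import path, type name, and local variables below are hypothetical):

    from active_bo_msgs.msg import ActiveRL  # hypothetical import path and type name

    def build_display_request(best_policy, best_weights):
        msg = ActiveRL()
        msg.policy = best_policy.tolist()
        msg.weights = best_weights.tolist()
        msg.interactive_run = 1
        msg.display_run = True  # new flag; presumably tells the RL side to render this rollout
        return msg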
@@ -365,13 +367,14 @@ class ActiveBOTopic(Node):
 
             self.reward[:, self.current_run] = self.BO.best_reward.T
 
-            self.BO = None
+            if self.current_run < self.bo_runs - 1:
+                self.BO = None
 
             self.current_episode = 0
             if self.bo_fixed_seed:
                 self.seed_array[0, self.current_run] = self.seed
                 self.seed = int(np.random.randint(1, 2147483647, 1)[0])
-                self.get_logger().info(f'{self.seed}')
+                # self.get_logger().info(f'{self.seed}')
             self.current_run += 1
             self.get_logger().info(f'Current Run: {self.current_run}')
 
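Two things worth noting in this last hunk: the BO instance is now discarded only while more runs remain, so the final run's fitted model survives the loop, and the per-seed log line is commented out. On the reseeding draw itself, np.random.randint's upper bound is exclusive, so the result stays within the positive signed 32-bit range:

    import numpy as np

    # Upper bound 2147483647 (2**31 - 1) is exclusive, so the draw lies in
    # [1, 2147483646] -- always a valid positive 32-bit seed.
    seed = int(np.random.randint(1, 2147483647, 1)[0])
    assert 1 <= seed <= 2**31 - 2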