debugging regular
This commit is contained in:
parent
334f64e22d
commit
3a8acb6807
@ -1,13 +1,12 @@
|
||||
class RegularQuery:
|
||||
def __init__(self, regular):
|
||||
self.regular = regular
|
||||
self.counter = 0
|
||||
def __init__(self, regular, episode):
|
||||
self.regular = int(regular)
|
||||
self.counter = episode
|
||||
|
||||
def query(self):
|
||||
if self.counter < self.regular:
|
||||
self.counter += 1
|
||||
return False
|
||||
|
||||
if self.counter % self.regular == 0 and self.counter != 0:
|
||||
return True
|
||||
|
||||
else:
|
||||
self.counter = 0
|
||||
return True
|
||||
return False
|
||||
|
@ -212,7 +212,7 @@ class ActiveBOTopic(Node):
|
||||
home_dir = os.path.expanduser('~')
|
||||
file_path = os.path.join(home_dir, 'Documents/IntRLResults')
|
||||
filename = env + '-' + acq + '-' + self.bo_metric + '-' \
|
||||
+ str(self.bo_metric_parameter) + '-' \
|
||||
+ str(round(self.bo_metric_parameter, 2)) + '-' \
|
||||
+ str(self.bo_nr_weights) + '-' + str(time.time())
|
||||
filename = filename.replace('.', '_') + '.csv'
|
||||
path = os.path.join(file_path, filename)
|
||||
@ -251,7 +251,7 @@ class ActiveBOTopic(Node):
|
||||
user_query = RandomQuery(self.bo_metric_parameter)
|
||||
|
||||
elif self.bo_metric == "regular":
|
||||
user_query = RegularQuery(self.bo_metric_parameter)
|
||||
user_query = RegularQuery(self.bo_metric_parameter, self.current_episode)
|
||||
|
||||
elif self.bo_metric == "max acquisition":
|
||||
user_query = MaxAcqQuery(self.bo_metric_parameter,
|
||||
|
@ -186,7 +186,7 @@ class ActiveRLService(Node):
|
||||
|
||||
if done:
|
||||
rl_response = ActiveRLResponse()
|
||||
rl_response.weights = self.rl_weights
|
||||
rl_response.weights = self.eval_weights
|
||||
rl_response.reward = self.rl_reward
|
||||
rl_response.final_step = self.rl_step
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user