ActiveBOToytask/QueryTesting/ImprovementQuery.py

27 lines
646 B
Python
Raw Normal View History

2023-07-05 11:24:03 +00:00
import numpy as np
class ImprovementQuery:
    """Decide whether reward improvement over a recent window has stalled.

    The window is the last ``period`` entries of ``rewards``. A per-step
    slope is estimated from the first and last entry of that window, and
    ``query()`` reports whether it falls below ``threshold``.
    """

    def __init__(self, threshold, period, rewards):
        """Store the stall criterion and the reward history.

        Args:
            threshold: Slope below which improvement is considered stalled.
            period: Number of most recent rewards in the window; must be >= 1.
            rewards: 1-D reward history (a NumPy array or any sequence that
                supports ``len()`` and negative indexing).

        Raises:
            ValueError: If ``period`` is less than 1.
        """
        if period < 1:
            # period == 0 would index rewards[-0] (the *first* element) and
            # divide by zero; negative periods index from the front.
            raise ValueError(f"period must be >= 1, got {period}")
        self.threshold = threshold
        self.period = period
        self.rewards = rewards

    def _window_slope(self):
        """Return the per-step change across the last ``period`` rewards."""
        if self.period == 1:
            # A single sample spans no interval; treat it as no change.
            return 0.0
        first = self.rewards[-self.period]
        last = self.rewards[-1]
        # ``period`` samples are separated by ``period - 1`` steps.
        return (last - first) / (self.period - 1)

    def query(self):
        """Return True if the recent per-step slope is below ``threshold``.

        Returns False while fewer than ``period`` rewards are available.
        """
        if len(self.rewards) < self.period:
            return False
        return self._window_slope() < self.threshold
if __name__ == "__main__":
    # Small demo: a steadily increasing reward history checked over a
    # 5-step window with a 0.05 stall threshold.
    demo_rewards = np.array([0, 1, 2, 3, 4, 5])
    demo_query = ImprovementQuery(threshold=0.05, period=5, rewards=demo_rewards)
    print(demo_query.query())