# Python source file (27 lines, 646 B).
import logging

import numpy as np


class ImprovementQuery:
    """Check whether a sequence of rewards has stopped improving.

    The trend is measured over the most recent ``period`` rewards as the
    slope of the straight line through the first and last reward of that
    window. :meth:`query` reports ``True`` once this slope drops below
    ``threshold``.

    Attributes:
        threshold: Slope below which improvement is considered stalled.
        period: Number of most recent rewards in the window; at least 2.
        rewards: Rewards ordered oldest to newest (assumed one-dimensional).
            It is re-read on every call, so callers may reassign it as new
            rewards arrive.
    """

    def __init__(self, threshold, period, rewards):
        """Store the query parameters.

        Args:
            threshold: Slope below which :meth:`query` returns ``True``.
            period: Window length in rewards; must be at least 2 because a
                window needs two points to define a slope.
            rewards: Array-like of rewards, oldest first.

        Raises:
            ValueError: If ``period`` is smaller than 2.
        """
        if period < 2:
            raise ValueError(f"period must be at least 2, got {period!r}")
        self.threshold = threshold
        self.period = period
        self.rewards = rewards

    def slope(self):
        """Return the average per-step change over the last ``period`` rewards.

        Returns:
            ``(rewards[-1] - rewards[-period]) / (period - 1)``, or ``None``
            when fewer than ``period`` rewards are available.
        """
        # asarray accepts plain lists/tuples and does not copy an ndarray.
        rewards = np.asarray(self.rewards)
        if rewards.shape[0] < self.period:
            return None
        first = rewards[-self.period]
        last = rewards[-1]
        # A window of `period` points spans `period - 1` steps; dividing by
        # `period` would understate the slope by a factor of
        # (period - 1) / period.
        return (last - first) / (self.period - 1)

    def query(self):
        """Return ``True`` if the recent slope is below ``threshold``.

        Returns ``False`` while fewer than ``period`` rewards are available,
        so a short history never counts as a plateau.
        """
        slope = self.slope()
        if slope is None:
            return False
        # Diagnostics go through logging rather than print so library callers
        # do not get unsolicited stdout output.
        logging.getLogger(__name__).debug("ImprovementQuery slope: %s", slope)
        return slope < self.threshold


if __name__ == "__main__":
    # Demo: check a steadily rising reward curve over its last five points.
    demo_rewards = np.array([0, 1, 2, 3, 4, 5])
    plateau_check = ImprovementQuery(
        threshold=0.05,
        period=5,
        rewards=demo_rewards,
    )
    print(plateau_check.query())