diff --git a/plotter/reward_plotter.py b/plotter/reward_plotter.py index 3b546ea..257ba08 100644 --- a/plotter/reward_plotter.py +++ b/plotter/reward_plotter.py @@ -12,8 +12,8 @@ def plot_csv(paths, x_axis, y_axis): x = np.linspace(0, mean.shape[0], mean.shape[0]) # Extract the first part of the filename and use it as a label - label = os.path.basename(path_).split('-')[0:3] - label = f"{label[0]} {float(label[1].replace('_','.'))}, nrbfs = {int(label[2])}" + label = os.path.basename(path_).split('-')[0:5] + label = f"{label[1]}, {label[2]}, {float(label[3].replace('_','.'))}, nrbfs = {int(label[4])}" plt.plot(x, mean, label=label) plt.fill_between( @@ -32,7 +32,9 @@ def plot_csv(paths, x_axis, y_axis): if __name__ == '__main__': - filenames = ['random-1_0-5-1685552722_2243946.csv'] + filenames = ['mc-ei-random-1_0-5-1685622201_6965265.csv', + 'mc-pi-random-1_0-5-1685622464_9843714.csv', + 'mc-cb-random-1_0-5-1685622728_8990934.csv'] home_dir = os.path.expanduser('~') file_path = os.path.join(home_dir, 'Documents/IntRLResults') paths = [os.path.join(file_path, filename) for filename in filenames]