From 54a2eb9bbacddd7cb1bd586742fa5de0d8deb05b Mon Sep 17 00:00:00 2001 From: Niko Date: Mon, 5 Jun 2023 14:56:00 +0200 Subject: [PATCH] debugging --- plotter/reward_plotter.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/plotter/reward_plotter.py b/plotter/reward_plotter.py index 257ba08..a36e3be 100644 --- a/plotter/reward_plotter.py +++ b/plotter/reward_plotter.py @@ -2,6 +2,7 @@ import numpy as np import matplotlib.pyplot as plt import os + def plot_csv(paths, x_axis, y_axis): for path_ in paths: data = np.genfromtxt(path_, delimiter=',', skip_header=1, dtype=float) @@ -13,7 +14,10 @@ def plot_csv(paths, x_axis, y_axis): # Extract the first part of the filename and use it as a label label = os.path.basename(path_).split('-')[0:5] - label = f"{label[1]}, {label[2]}, {float(label[3].replace('_','.'))}, nrbfs = {int(label[4])}" + label = f"{label[1]}," \ + f" {label[2]}," \ + f" {label[3].replace('_', '.') if label[3] != '' else 'base'}," \ + f" nrbfs = {int(label[4])}" plt.plot(x, mean, label=label) plt.fill_between( @@ -26,16 +30,23 @@ def plot_csv(paths, x_axis, y_axis): plt.xlim([0, mean.shape[0]]) plt.ylabel(y_axis) plt.grid(True) - plt.legend(loc="best") + plt.legend(loc="lower right") plt.show() if __name__ == '__main__': - - filenames = ['mc-ei-random-1_0-5-1685622201_6965265.csv', - 'mc-pi-random-1_0-5-1685622464_9843714.csv', - 'mc-cb-random-1_0-5-1685622728_8990934.csv'] + filenames = ['BO/mc-ei-bo--5-1685952362_3531659.csv', + 'random-0_95/mc-ei-random-0_95-5-1685956146_775975.csv', + 'regular-10_0/mc-ei-regular-10-5-1685968651_7080765.csv', + ] home_dir = os.path.expanduser('~') - file_path = os.path.join(home_dir, 'Documents/IntRLResults') + file_path = os.path.join(home_dir, 'Documents/IntRLResults/mc-e50r10') paths = [os.path.join(file_path, filename) for filename in filenames] plot_csv(paths, 'Episodes', 'Reward') + # + # filenames2 = ['mc-ei-random-1_0-10-1685708208_6402516.csv', + # 'mc-pi-random-1_0-10-1685709168_7113624.csv', + # 'mc-cb-random-1_0-10-1685714487_4603446.csv'] + # + # paths2 = [os.path.join(file_path, filename) for filename in filenames2] + # plot_csv(paths2, 'Episodes', 'Reward')