# ActiveBOToytask/plotter/reward_plotter.py
# (Exported from a repository blame view: 51 lines, 1.6 KiB, Python;
#  earliest commit 2023-06-01 08:57:18 +00:00.)
import numpy as np
import matplotlib.pyplot as plt
import os


def _label_from_path(path_):
    """Build a legend label from a result file's name.

    The file name is split on '-'. Fields 1, 2 and 3 are copied into the
    label (field 3 with '_' turned into '.', or 'base' when it is empty), and
    field 4 is rendered as ``nrbfs = <int>``. For example,
    ``franka-pei-random-1_0-6-1694787936_385925.csv`` becomes
    ``"pei, random, 1.0, nrbfs = 6"``.

    Names that do not follow this pattern (too few fields, or a non-integer
    field 4) fall back to the bare file name instead of raising.
    """
    name = os.path.basename(path_)
    parts = name.split('-')[0:5]
    try:
        scale = parts[3].replace('_', '.') if parts[3] != '' else 'base'
        return f"{parts[1]}, {parts[2]}, {scale}, nrbfs = {int(parts[4])}"
    except (IndexError, ValueError):
        return name


def plot_csv(paths, x_axis, y_axis):
    """Plot the column-wise mean of each CSV with a +/-1.96 std band.

    Each file is loaded with ``np.genfromtxt`` and averaged along axis 0, so
    every column becomes one point on the x axis (presumably one row per run
    and one column per episode -- confirm against whatever writes these
    files). All curves are drawn on one shared figure, which is shown at the
    end.

    :param paths: iterable of CSV file paths to plot.
    :param x_axis: label for the x axis.
    :param y_axis: label for the y axis.
    """
    longest = 0
    for path_ in paths:
        # atleast_2d: a single-row file loads as a 1-D array; treat it as one
        # row so mean/std still produce one value per column.
        data = np.atleast_2d(
            np.genfromtxt(path_, delimiter=',', skip_header=0, dtype=float)
        )

        mean = np.mean(data, axis=0)
        std = np.std(data, axis=0)

        # One x value per column at integer positions 0 .. n-1.
        x = np.arange(mean.shape[0])
        longest = max(longest, mean.shape[0])

        plt.plot(x, mean, label=_label_from_path(path_))
        # Band of +/-1.96 standard deviations across rows. This is a spread
        # band, not a confidence interval of the mean (that would use
        # std / sqrt(number of rows)).
        plt.fill_between(
            x,
            mean - 1.96 * std,
            mean + 1.96 * std,
            alpha=0.5
        )

    # Axis styling is applied once, after all curves are drawn. The x range
    # covers the longest curve, not just the last file plotted.
    plt.xlabel(x_axis)
    if longest > 1:
        plt.xlim([0, longest - 1])
    plt.ylabel(y_axis)
    plt.grid(True)
    plt.legend(loc="upper left")
    plt.show()
if __name__ == '__main__':
    # Directory holding the result CSVs, relative to the user's home.
    file_path = os.path.join(os.path.expanduser('~'),
                             'Documents/IntRLResults/Franka-Results')

    # Result files to overlay on a single reward plot.
    csv_names = (
        'franka-pei-random-1_0-6-1694787936_385925.csv',
    )

    plot_csv([os.path.join(file_path, name) for name in csv_names],
             'Episodes', 'Reward')

    # Disabled alternative comparison, kept for reference:
    #
    # csv_names2 = ('mc-ei-random-1_0-10-1685708208_6402516.csv',
    #               'mc-pi-random-1_0-10-1685709168_7113624.csv',
    #               'mc-cb-random-1_0-10-1685714487_4603446.csv')
    #
    # plot_csv([os.path.join(file_path, name) for name in csv_names2],
    #          'Episodes', 'Reward')