Dateien hochladen nach „plotter“
This commit is contained in:
parent
32703efe74
commit
82d2b61e6b
66
plotter/reward_plotter_final.py
Normal file
66
plotter/reward_plotter_final.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def plot_csv_with_titles(paths_dict, x_axis, y_axis, subplot_titles, y_limits=(0, 200)):
|
||||||
|
# Adjustments for dynamic subplot creation based on the dictionary size
|
||||||
|
num_subplots = len(paths_dict)
|
||||||
|
fig, axs = plt.subplots(1, num_subplots, figsize=(4*num_subplots, 4))
|
||||||
|
|
||||||
|
# If only one subplot, axs is not an array, so we need to convert it
|
||||||
|
if num_subplots == 1:
|
||||||
|
axs = [axs]
|
||||||
|
|
||||||
|
for idx, (_, file_list) in enumerate(paths_dict.items()):
|
||||||
|
for path_, color in file_list:
|
||||||
|
data = np.genfromtxt(path_, delimiter=',', skip_header=0, dtype=float)
|
||||||
|
|
||||||
|
mean = np.mean(data, axis=1)
|
||||||
|
std = np.std(data, axis=1)
|
||||||
|
x = np.linspace(0, mean.shape[0], mean.shape[0])
|
||||||
|
|
||||||
|
axs[idx].plot(x, mean, color=color)
|
||||||
|
axs[idx].fill_between(
|
||||||
|
x,
|
||||||
|
mean - 1.96 * std,
|
||||||
|
mean + 1.96 * std,
|
||||||
|
color=color,
|
||||||
|
alpha=0.5
|
||||||
|
)
|
||||||
|
axs[idx].set_title(subplot_titles[idx])
|
||||||
|
axs[idx].set_xlabel(x_axis)
|
||||||
|
axs[idx].set_xlim([0, mean.shape[0]])
|
||||||
|
axs[idx].grid(True)
|
||||||
|
axs[idx].set_ylim(y_limits)
|
||||||
|
|
||||||
|
# Only label the y-axis for the leftmost plot
|
||||||
|
if idx == 0:
|
||||||
|
axs[idx].set_ylabel(y_axis)
|
||||||
|
else:
|
||||||
|
axs[idx].set_ylabel('')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
filepaths = [
|
||||||
|
"/mnt/data/cp-ei-regular-20_0-15-1686582970_1112866.csv",
|
||||||
|
"/mnt/data/cp-ei-random-0_95-15-1686579274_2881138.csv",
|
||||||
|
"/mnt/data/cp-cb-random-1_0-15-1686575989_8880587.csv",
|
||||||
|
"/mnt/data/cp-pi-random-1_0-15-1686575712_588163.csv"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Demonstrating the adjusted function with subplot titles
|
||||||
|
titles = ["Plot 1", "Plot 2", "Plot 3", "Plot 4"]
|
||||||
|
|
||||||
|
data_dict_colored = {
|
||||||
|
'subplot1': [(filepaths[0], 'blue'), (filepaths[1], 'green')],
|
||||||
|
'subplot2': [(filepaths[2], 'red')],
|
||||||
|
'subplot3': [(filepaths[3], 'purple')],
|
||||||
|
'subplot4': []
|
||||||
|
}
|
||||||
|
|
||||||
|
plot_csv_with_titles(data_dict_colored, 'Episodes', 'Reward', titles)
|
Loading…
Reference in New Issue
Block a user