preparation for 2d policy

This commit is contained in:
Niko Feith 2023-08-28 16:37:20 +02:00
parent 4770ffb12d
commit 9a70229171
3 changed files with 46 additions and 17 deletions

View File

@ -144,8 +144,19 @@ const active_rl_eval_sub = new ROS.Topic({
const pendingRequest = ref(false); const pendingRequest = ref(false);
active_rl_eval_sub.subscribe((msg) => { active_rl_eval_sub.subscribe((msg) => {
pstore.setPolicy(msg.policy); const nr_steps = msg.nr_steps;
pstore.setWeights(msg.weights); const nr_weights = msg.nr_weights;
const pol_x = msg.policy.slice(0, nr_steps);
const pol_y = msg.policy.slice(nr_steps);
const weights_x = msg.weights.slice(0, nr_weights);
const weights_y = msg.weights.slice(nr_weights);
pstore.setPolicy(pol_x);
pstore.setPolicy_y(pol_y);
pstore.setWeights(weights_x);
pstore.setWeights(weights_y);
pendingRequest.value = true; pendingRequest.value = true;
}); });
@ -162,10 +173,13 @@ watch(
return; return;
} }
console.log("Button pressed!"); console.log("Button pressed!");
const policy = pstore.policy.concat(pstore.policy_y);
const weights = pstore.weights.concat(pstore.weights_y);
const weights_fixed = pstore.weights_fixed.concat(pstore.weights_fixed_y);
const active_eval_response = new ROS.Message({ const active_eval_response = new ROS.Message({
policy: pstore.policy, policy: policy,
weights: pstore.weights, weights: weights,
overwrite_weight: pstore.weights_fixed, overwrite_weight: weights_fixed,
}); });
console.log(active_eval_response); console.log(active_eval_response);
@ -178,12 +192,6 @@ watch(
const active_bo_pending = ref(false); const active_bo_pending = ref(false);
const active_bo_request = new ROS.Topic({
ros: ros,
name: "/active_bo_request",
messageType: "active_bo_msgs/msg/ActiveBORequest",
});
const active_bo_response = new ROS.Topic({ const active_bo_response = new ROS.Topic({
ros: ros, ros: ros,
name: "/active_bo_response", name: "/active_bo_response",
@ -191,13 +199,30 @@ const active_bo_response = new ROS.Topic({
}); });
active_bo_response.subscribe((msg) => { active_bo_response.subscribe((msg) => {
pstore.setPolicy(msg.best_policy); const nr_steps = msg.nr_steps;
pstore.setWeights(msg.best_weights); const nr_weights = msg.nr_weights;
const pol_x = msg.best_policy.slice(0, nr_steps);
const pol_y = msg.best_policy.slice(nr_steps);
const weights_x = msg.best_weights.slice(0, nr_weights);
const weights_y = msg.best_weights.slice(nr_weights);
pstore.setPolicy(pol_x);
pstore.setPolicy_y(pol_y);
pstore.setWeights(weights_x);
pstore.setWeights(weights_y);
rstore.setMean(msg.reward_mean); rstore.setMean(msg.reward_mean);
rstore.setStd(msg.reward_std); rstore.setStd(msg.reward_std);
active_bo_pending.value = false; active_bo_pending.value = false;
}); });
const active_bo_request = new ROS.Topic({
ros: ros,
name: "/active_bo_request",
messageType: "active_bo_msgs/msg/ActiveBORequest",
});
watch( watch(
() => cstore.getRunner, () => cstore.getRunner,
() => { () => {
@ -207,6 +232,7 @@ watch(
fixed_seed: cstore.fixed_seed, fixed_seed: cstore.fixed_seed,
nr_weights: pstore.nr_weights, nr_weights: pstore.nr_weights,
max_steps: pstore.max_steps, max_steps: pstore.max_steps,
nr_dims: cstore.env_dim[cstore.env],
nr_episodes: cstore.nr_episodes, nr_episodes: cstore.nr_episodes,
nr_runs: cstore.nr_runs, nr_runs: cstore.nr_runs,
acquisition_function: cstore.acq_fun, acquisition_function: cstore.acq_fun,

View File

@ -1,10 +1,14 @@
import { defineStore } from "pinia"; import { defineStore } from "pinia";
import {state} from "vue-tsc/out/shared";
export const useCStore = defineStore("Control Store", { export const useCStore = defineStore("Control Store", {
state: () => { state: () => {
return { return {
env: "Mountain Car", env: "Reacher",
envs: ["Mountain Car", "Cartpole", "Acrobot", "Pendulum"], envs: ["Reacher"],
env_dim: {
Reacher: 2,
},
metric: "random", metric: "random",
metrics: ["random", "regular", "improvement", "max_acquisition"], metrics: ["random", "regular", "improvement", "max_acquisition"],
metrics_label: { metrics_label: {
@ -39,6 +43,7 @@ export const useCStore = defineStore("Control Store", {
}, },
getters: { getters: {
getEnv: (state) => state.env, getEnv: (state) => state.env,
getEnvdims: (state) => state.env_dims,
getMetric: (state) => state.metric, getMetric: (state) => state.metric,
getNrEpisodes: (state) => state.nr_episodes, getNrEpisodes: (state) => state.nr_episodes,
getNrRuns: (state) => state.nr_runs, getNrRuns: (state) => state.nr_runs,

View File

@ -57,8 +57,6 @@ export const usePStore = defineStore("Policy Store", {
}, },
resetWeights_Fixed() { resetWeights_Fixed() {
this.weights_fixed = Array(this.nr_weights).fill(false); this.weights_fixed = Array(this.nr_weights).fill(false);
},
resetWeights_Fixed_y() {
this.weights_fixed_y = Array(this.nr_weights).fill(false); this.weights_fixed_y = Array(this.nr_weights).fill(false);
}, },
}, },