acrobot
error from pendulum fixed
This commit is contained in:
parent
7ea78d5e0f
commit
c6a99c6c3b
@ -1,4 +1,5 @@
|
|||||||
string env
|
string env
|
||||||
|
bool fixed_seed
|
||||||
string metric
|
string metric
|
||||||
uint16 nr_weights
|
uint16 nr_weights
|
||||||
uint16 max_steps
|
uint16 max_steps
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
string env
|
string env
|
||||||
|
uint32 seed
|
||||||
float32[] policy
|
float32[] policy
|
||||||
float32[] weights
|
float32[] weights
|
@ -196,7 +196,7 @@ class AcrobotEnv(core.Env):
|
|||||||
def step(self, a):
|
def step(self, a):
|
||||||
s = self.state
|
s = self.state
|
||||||
assert s is not None, "Call reset before using AcrobotEnv object."
|
assert s is not None, "Call reset before using AcrobotEnv object."
|
||||||
torque = self.AVAIL_TORQUE[a]
|
torque = a
|
||||||
|
|
||||||
# Add noise to the force action
|
# Add noise to the force action
|
||||||
if self.torque_noise_max > 0:
|
if self.torque_noise_max > 0:
|
||||||
|
@ -227,8 +227,19 @@ class PendulumEnv(gym.Env):
|
|||||||
self.surf, rod_end[0], rod_end[1], int(rod_width / 2), (204, 77, 77)
|
self.surf, rod_end[0], rod_end[1], int(rod_width / 2), (204, 77, 77)
|
||||||
)
|
)
|
||||||
|
|
||||||
fname = path.join(path.dirname(__file__), "../../resource/clockwise.png")
|
try:
|
||||||
|
import ament_index_python
|
||||||
|
except ImportError:
|
||||||
|
raise DependencyNotInstalled(
|
||||||
|
"ament_index_python is not installed`"
|
||||||
|
)
|
||||||
|
|
||||||
|
package_name = 'active_bo_ros'
|
||||||
|
|
||||||
|
package_path = ament_index_python.get_package_share_directory(package_name)
|
||||||
|
fname = path.join(package_path, 'assets', 'clockwise.png')
|
||||||
img = pygame.image.load(fname)
|
img = pygame.image.load(fname)
|
||||||
|
|
||||||
if self.last_u is not None:
|
if self.last_u is not None:
|
||||||
scale_img = pygame.transform.smoothscale(
|
scale_img = pygame.transform.smoothscale(
|
||||||
img,
|
img,
|
||||||
|
@ -64,7 +64,6 @@ class ActiveBOTopic(Node):
|
|||||||
|
|
||||||
# RL Environments and BO
|
# RL Environments and BO
|
||||||
self.env = None
|
self.env = None
|
||||||
self.distance_penalty = 0
|
|
||||||
|
|
||||||
self.BO = None
|
self.BO = None
|
||||||
self.nr_init = 3
|
self.nr_init = 3
|
||||||
|
@ -65,7 +65,6 @@ class ActiveRLService(Node):
|
|||||||
# RL Environments
|
# RL Environments
|
||||||
self.env = None
|
self.env = None
|
||||||
|
|
||||||
self.distance_penalty = 0
|
|
||||||
self.best_pol_shown = False
|
self.best_pol_shown = False
|
||||||
|
|
||||||
# Main loop timer object
|
# Main loop timer object
|
||||||
@ -114,11 +113,8 @@ class ActiveRLService(Node):
|
|||||||
def next_image(self, policy):
|
def next_image(self, policy):
|
||||||
action = policy[self.rl_step]
|
action = policy[self.rl_step]
|
||||||
action_clipped = action.clip(min=-1.0, max=1.0)
|
action_clipped = action.clip(min=-1.0, max=1.0)
|
||||||
self.get_logger().info(str(action_clipped) + str(type(action_clipped)))
|
|
||||||
output = self.env.step(action_clipped.astype(np.float32))
|
output = self.env.step(action_clipped.astype(np.float32))
|
||||||
|
|
||||||
self.get_logger().info(str(output))
|
|
||||||
|
|
||||||
self.rl_reward += output[1]
|
self.rl_reward += output[1]
|
||||||
done = output[2]
|
done = output[2]
|
||||||
self.rl_step += 1
|
self.rl_step += 1
|
||||||
@ -141,8 +137,6 @@ class ActiveRLService(Node):
|
|||||||
self.image_pub.publish(feedback_msg)
|
self.image_pub.publish(feedback_msg)
|
||||||
|
|
||||||
if not done and self.rl_step == len(policy):
|
if not done and self.rl_step == len(policy):
|
||||||
distance = -(self.env.goal_position - output[0][0])
|
|
||||||
self.rl_reward += distance * self.distance_penalty
|
|
||||||
done = True
|
done = True
|
||||||
|
|
||||||
return done
|
return done
|
||||||
|
@ -20,7 +20,6 @@ class RLService(Node):
|
|||||||
self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1)
|
self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1)
|
||||||
|
|
||||||
self.env = None
|
self.env = None
|
||||||
self.distance_penalty = 0
|
|
||||||
|
|
||||||
def rl_callback(self, request, response):
|
def rl_callback(self, request, response):
|
||||||
|
|
||||||
@ -47,7 +46,6 @@ class RLService(Node):
|
|||||||
for i in range(len(policy)):
|
for i in range(len(policy)):
|
||||||
action = policy[i]
|
action = policy[i]
|
||||||
action_clipped = action.clip(min=-1.0, max=1.0)
|
action_clipped = action.clip(min=-1.0, max=1.0)
|
||||||
self.get_logger().info(str(action_clipped) + str(type(action_clipped)))
|
|
||||||
output = self.env.step(action_clipped.astype(np.float32))
|
output = self.env.step(action_clipped.astype(np.float32))
|
||||||
|
|
||||||
reward += output[1]
|
reward += output[1]
|
||||||
@ -72,10 +70,6 @@ class RLService(Node):
|
|||||||
if done:
|
if done:
|
||||||
break
|
break
|
||||||
|
|
||||||
if not done and i == len(policy):
|
|
||||||
distance = -(self.env.goal_position - output[0][0])
|
|
||||||
reward += distance * self.distance_penalty
|
|
||||||
|
|
||||||
response.reward = reward
|
response.reward = reward
|
||||||
response.final_step = step_count
|
response.final_step = step_count
|
||||||
|
|
||||||
|
Before Width: | Height: | Size: 6.8 KiB After Width: | Height: | Size: 6.8 KiB |
@ -18,5 +18,6 @@
|
|||||||
|
|
||||||
<export>
|
<export>
|
||||||
<build_type>ament_python</build_type>
|
<build_type>ament_python</build_type>
|
||||||
|
<your_package_name resource="assets/clockwise.png"/>
|
||||||
</export>
|
</export>
|
||||||
</package>
|
</package>
|
||||||
|
@ -17,6 +17,7 @@ setup(
|
|||||||
['resource/' + package_name]),
|
['resource/' + package_name]),
|
||||||
('share/' + package_name, ['package.xml']),
|
('share/' + package_name, ['package.xml']),
|
||||||
(os.path.join('share', package_name), glob('launch/*.launch.py')),
|
(os.path.join('share', package_name), glob('launch/*.launch.py')),
|
||||||
|
('share/' + package_name + '/assets', ['assets/clockwise.png']),
|
||||||
],
|
],
|
||||||
install_requires=['setuptools', 'gym', 'numpy'],
|
install_requires=['setuptools', 'gym', 'numpy'],
|
||||||
zip_safe=True,
|
zip_safe=True,
|
||||||
|
Loading…
Reference in New Issue
Block a user