acrobot
error from pendulum fixed
This commit is contained in:
parent
7ea78d5e0f
commit
c6a99c6c3b
@ -1,4 +1,5 @@
|
||||
string env
|
||||
bool fixed_seed
|
||||
string metric
|
||||
uint16 nr_weights
|
||||
uint16 max_steps
|
||||
|
@ -1,3 +1,4 @@
|
||||
string env
|
||||
uint32 seed
|
||||
float32[] policy
|
||||
float32[] weights
|
@ -196,7 +196,7 @@ class AcrobotEnv(core.Env):
|
||||
def step(self, a):
|
||||
s = self.state
|
||||
assert s is not None, "Call reset before using AcrobotEnv object."
|
||||
torque = self.AVAIL_TORQUE[a]
|
||||
torque = a
|
||||
|
||||
# Add noise to the force action
|
||||
if self.torque_noise_max > 0:
|
||||
|
@ -227,8 +227,19 @@ class PendulumEnv(gym.Env):
|
||||
self.surf, rod_end[0], rod_end[1], int(rod_width / 2), (204, 77, 77)
|
||||
)
|
||||
|
||||
fname = path.join(path.dirname(__file__), "../../resource/clockwise.png")
|
||||
try:
|
||||
import ament_index_python
|
||||
except ImportError:
|
||||
raise DependencyNotInstalled(
|
||||
"ament_index_python is not installed`"
|
||||
)
|
||||
|
||||
package_name = 'active_bo_ros'
|
||||
|
||||
package_path = ament_index_python.get_package_share_directory(package_name)
|
||||
fname = path.join(package_path, 'assets', 'clockwise.png')
|
||||
img = pygame.image.load(fname)
|
||||
|
||||
if self.last_u is not None:
|
||||
scale_img = pygame.transform.smoothscale(
|
||||
img,
|
||||
|
@ -64,7 +64,6 @@ class ActiveBOTopic(Node):
|
||||
|
||||
# RL Environments and BO
|
||||
self.env = None
|
||||
self.distance_penalty = 0
|
||||
|
||||
self.BO = None
|
||||
self.nr_init = 3
|
||||
|
@ -65,7 +65,6 @@ class ActiveRLService(Node):
|
||||
# RL Environments
|
||||
self.env = None
|
||||
|
||||
self.distance_penalty = 0
|
||||
self.best_pol_shown = False
|
||||
|
||||
# Main loop timer object
|
||||
@ -114,11 +113,8 @@ class ActiveRLService(Node):
|
||||
def next_image(self, policy):
|
||||
action = policy[self.rl_step]
|
||||
action_clipped = action.clip(min=-1.0, max=1.0)
|
||||
self.get_logger().info(str(action_clipped) + str(type(action_clipped)))
|
||||
output = self.env.step(action_clipped.astype(np.float32))
|
||||
|
||||
self.get_logger().info(str(output))
|
||||
|
||||
self.rl_reward += output[1]
|
||||
done = output[2]
|
||||
self.rl_step += 1
|
||||
@ -141,8 +137,6 @@ class ActiveRLService(Node):
|
||||
self.image_pub.publish(feedback_msg)
|
||||
|
||||
if not done and self.rl_step == len(policy):
|
||||
distance = -(self.env.goal_position - output[0][0])
|
||||
self.rl_reward += distance * self.distance_penalty
|
||||
done = True
|
||||
|
||||
return done
|
||||
|
@ -20,7 +20,6 @@ class RLService(Node):
|
||||
self.publisher = self.create_publisher(ImageFeedback, 'rl_feedback', 1)
|
||||
|
||||
self.env = None
|
||||
self.distance_penalty = 0
|
||||
|
||||
def rl_callback(self, request, response):
|
||||
|
||||
@ -47,7 +46,6 @@ class RLService(Node):
|
||||
for i in range(len(policy)):
|
||||
action = policy[i]
|
||||
action_clipped = action.clip(min=-1.0, max=1.0)
|
||||
self.get_logger().info(str(action_clipped) + str(type(action_clipped)))
|
||||
output = self.env.step(action_clipped.astype(np.float32))
|
||||
|
||||
reward += output[1]
|
||||
@ -72,10 +70,6 @@ class RLService(Node):
|
||||
if done:
|
||||
break
|
||||
|
||||
if not done and i == len(policy):
|
||||
distance = -(self.env.goal_position - output[0][0])
|
||||
reward += distance * self.distance_penalty
|
||||
|
||||
response.reward = reward
|
||||
response.final_step = step_count
|
||||
|
||||
|
Before Width: | Height: | Size: 6.8 KiB After Width: | Height: | Size: 6.8 KiB |
@ -18,5 +18,6 @@
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
<your_package_name resource="assets/clockwise.png"/>
|
||||
</export>
|
||||
</package>
|
||||
|
@ -17,6 +17,7 @@ setup(
|
||||
['resource/' + package_name]),
|
||||
('share/' + package_name, ['package.xml']),
|
||||
(os.path.join('share', package_name), glob('launch/*.launch.py')),
|
||||
('share/' + package_name + '/assets', ['assets/clockwise.png']),
|
||||
],
|
||||
install_requires=['setuptools', 'gym', 'numpy'],
|
||||
zip_safe=True,
|
||||
|
Loading…
Reference in New Issue
Block a user