diff --git a/train.py b/train.py index 027ee42..e69f81a 100644 --- a/train.py +++ b/train.py @@ -318,7 +318,7 @@ def main(): ) # stack several consecutive frames together - if args.encoder_type == 'pixel': + if args.encoder_type.startswith('pixel'): env = utils.FrameStack(env, k=args.frame_stack) eval_env = utils.FrameStack(eval_env, k=args.frame_stack)