DLR-RM/stable-baselines3
GitHub で見るImplement sampling and training asynchronously using the SAC algorithm
Open
#715 opened on 2022年1月2日
help wantedquestion
説明
Question
I'm trying to implement sampling and training asynchronously using the SAC algorithm. I made the attempt shown in the code below. But I always get an error because there seems to be a confusion between training and evaluation modes. The training mode (False or True) is configured in the policy. And this is shared between the train and collect_rollouts methods. Is it possible to do collect_rollouts asynchronously?
Additional context
Reference code:
_rollouts_ = queue.Queue()
def train_async(sac, max_steps):
global _rollouts_
iteraction = 0
while True:
if not _rollouts_.empty():
rollout = _rollouts_.get()
if rollout is not None:
iteraction += 1
gradient_steps = sac.gradient_steps if sac.gradient_steps >= 0 else rollout.episode_timesteps
if gradient_steps > 0:
sac.train(gradient_steps, sac.batch_size)
else:
print("Training ending")
break
else:
print("waiting for rollouts....")
time.sleep(1)
def rollouts_async(sac, max_steps, callback, log_interval=None):
steps = 0
global _rollouts_
while True:
rollout = sac.collect_rollouts(sac.env, callback, sac.train_freq, sac.replay_buffer, sac.action_noise, sac.learning_starts, log_interval)
if rollout.continue_training is False:
_rollouts_.put(None)
break
else:
_rollouts_.put( rollout )
steps += 1
if steps >= max_steps:
_rollouts_.put(None)
callback.on_training_end()
break
def learn_async(sac, total_timesteps = 1000000, callback=None, log_interval=None, tb_log_name="run", reset_num_timesteps=True):
total_timesteps, callback = sac._setup_learn(total_timesteps, None, callback, 0, 0, None, reset_num_timesteps, tb_log_name)
callback.on_training_start(locals(), globals())
t1 = threading.Thread(target=rollouts_async, args=(sac, total_timesteps, callback, log_interval))
t1.start()
t2 = threading.Thread(target=train_async, args=(sac, total_timesteps))
t2.start()
t1.join()
t2.join()
Error:
Traceback (most recent call last):
File "C:\Users\gilza\anaconda3\lib\threading.py", line 973, in _bootstrap_inner
self.run()
File "C:\Users\gilza\anaconda3\lib\threading.py", line 910, in run
self._target(*self._args, **self._kwargs)
File "C:\Users\gilza\doc\lab\nav\NavProAI4U\scripts\sb3sacutils.py", line 134, in rollouts_async
rollout = sac.collect_rollouts(sac.env, callback, sac.train_freq, sac.replay_buffer, sac.action_noise, sac.learning_starts, log_interval)
File "C:\Users\gilza\anaconda3\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 589, in collect_rollouts
self._store_transition(replay_buffer, buffer_action, new_obs, reward, done, infos)
File "C:\Users\gilza\anaconda3\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 498, in _store_transition
replay_buffer.add(
File "C:\Users\gilza\anaconda3\lib\site-packages\stable_baselines3\common\buffers.py", line 562, in add
self.actions[self.pos] = np.array(action).copy()
ValueError: could not broadcast input array from shape (256,4) into shape (4,)
Checklist
- I have read the documentation (required) OK
- I have checked that there is no similar issue in the repo (required) OK