DLR-RM/stable-baselines3

Implement sampling and training asynchronously using the SAC algorithm

Open

#715 opened on January 2, 2022

Labels: help wanted, question

Description

Question

I'm trying to run sampling and training asynchronously with the SAC algorithm; my attempt is shown in the code below. However, I always get an error, apparently because the training and evaluation modes get mixed up: the training mode (`True` or `False`) is set on the policy, and that same policy is shared by the `train` and `collect_rollouts` methods. Is it possible to run `collect_rollouts` asynchronously?
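The mode flag is ordinary attribute state on the policy object, not per-thread state, so whichever thread sets it last wins for both threads. In SB3, `SAC.train()` switches the policy to training mode and `collect_rollouts()` switches it back to evaluation mode (via `policy.set_training_mode`), so two threads calling them on one object keep overwriting each other. A minimal pure-Python sketch of that effect, using a hypothetical `ToyPolicy` stand-in rather than the real SB3 class:

```python
import threading


class ToyPolicy:
    """Hypothetical stand-in for a policy whose mode is a plain attribute."""

    def __init__(self) -> None:
        self.training = False

    def set_training_mode(self, mode: bool) -> None:
        # Mutates shared object state: every thread holding this object sees it.
        self.training = mode


policy = ToyPolicy()
seen_by_collector = []


def trainer() -> None:
    # The training thread switches the shared policy to training mode.
    policy.set_training_mode(True)


def collector() -> None:
    # The collection thread expects eval mode but reads whatever is current.
    seen_by_collector.append(policy.training)


policy.set_training_mode(False)  # the mode the collector intends to use
t = threading.Thread(target=trainer)
t.start()
t.join()  # force the trainer's write to happen before the collector's read
c = threading.Thread(target=collector)
c.start()
c.join()
print(seen_by_collector)  # [True]
```

Making the flag thread-local (for example with `threading.local()`) would not be enough on its own, since the network weights and any other state stored on the model would still be shared between the two threads.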

Additional context

Reference code:

```python
import queue
import threading
import time

# Rollouts handed from the collector thread to the training thread;
# None is used as a sentinel meaning "no more rollouts".
_rollouts_ = queue.Queue()


def train_async(sac, max_steps):  # max_steps is currently unused
    iteration = 0
    while True:
        if not _rollouts_.empty():
            rollout = _rollouts_.get()
            if rollout is not None:
                iteration += 1
                # A negative gradient_steps means "as many gradient steps as
                # environment steps collected in this rollout".
                gradient_steps = sac.gradient_steps if sac.gradient_steps >= 0 else rollout.episode_timesteps
                if gradient_steps > 0:
                    # Runs on the same policy the collector thread is using.
                    sac.train(gradient_steps, sac.batch_size)
            else:
                print("Training ending")
                break
        else:
            print("waiting for rollouts....")
            time.sleep(1)


def rollouts_async(sac, max_steps, callback, log_interval=None):
    steps = 0  # counts collect_rollouts() calls, not environment timesteps
    while True:
        # Steps the environment with the shared policy and fills sac.replay_buffer.
        rollout = sac.collect_rollouts(
            sac.env, callback, sac.train_freq, sac.replay_buffer,
            sac.action_noise, sac.learning_starts, log_interval,
        )
        if rollout.continue_training is False:
            break
        _rollouts_.put(rollout)
        steps += 1
        if steps >= max_steps:
            break
    _rollouts_.put(None)  # tell the training thread to stop
    callback.on_training_end()


def learn_async(sac, total_timesteps=1000000, callback=None, log_interval=None, tb_log_name="run", reset_num_timesteps=True):
    # Same setup step that SB3's own learn() performs.
    total_timesteps, callback = sac._setup_learn(total_timesteps, None, callback, 0, 0, None, reset_num_timesteps, tb_log_name)
    callback.on_training_start(locals(), globals())
    # One thread collects experience, the other trains, both on the same SAC object.
    t1 = threading.Thread(target=rollouts_async, args=(sac, total_timesteps, callback, log_interval))
    t1.start()
    t2 = threading.Thread(target=train_async, args=(sac, total_timesteps))
    t2.start()
    t1.join()
    t2.join()
```
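One possible mitigation, assuming the root problem is that both threads call into the same SAC object at the same time, is to guard every call into the shared model with a `threading.Lock` and to replace the `empty()`/`sleep()` polling with a blocking `Queue.get()`. The sketch below shows that producer/consumer pattern with a hypothetical `SharedModel` and hypothetical `rollout_worker`/`train_worker` functions; it is not SB3 code.

```python
import queue
import threading


class SharedModel:
    """Hypothetical model that is NOT thread-safe: calls mutate shared state."""

    def __init__(self) -> None:
        self.calls = 0  # bumped by every collection and every gradient step

    def collect(self) -> int:
        self.calls += 1
        return 1  # pretend one environment step was collected

    def train(self, gradient_steps: int) -> None:
        self.calls += gradient_steps


model = SharedModel()
model_lock = threading.Lock()  # serializes every call into the shared model
rollouts = queue.Queue()


def rollout_worker(max_rollouts: int) -> None:
    for _ in range(max_rollouts):
        with model_lock:  # never use the model while train() is running
            steps = model.collect()
        rollouts.put(steps)
    rollouts.put(None)  # sentinel: no more rollouts


def train_worker(gradient_steps: int, done: list) -> None:
    while True:
        item = rollouts.get()  # blocks instead of polling with sleep()
        if item is None:
            break
        with model_lock:
            model.train(gradient_steps)
        done.append(item)


trained = []
producer = threading.Thread(target=rollout_worker, args=(5,))
consumer = threading.Thread(target=train_worker, args=(2, trained))
producer.start()
consumer.start()
producer.join()
consumer.join()
print(len(trained), model.calls)  # 5 15
```

Because SB3's `collect_rollouts` does both action selection and environment stepping, wrapping it whole in such a lock would effectively serialize the two threads. Real overlap would need the collector to act with its own copy of the actor whose weights are refreshed from the trained one from time to time.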

Error:

```
Traceback (most recent call last):
  File "C:\Users\gilza\anaconda3\lib\threading.py", line 973, in _bootstrap_inner
    self.run()
  File "C:\Users\gilza\anaconda3\lib\threading.py", line 910, in run
    self._target(*self._args, **self._kwargs)
  File "C:\Users\gilza\doc\lab\nav\NavProAI4U\scripts\sb3sacutils.py", line 134, in rollouts_async
    rollout = sac.collect_rollouts(sac.env, callback, sac.train_freq, sac.replay_buffer, sac.action_noise, sac.learning_starts, log_interval)
  File "C:\Users\gilza\anaconda3\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 589, in collect_rollouts
    self._store_transition(replay_buffer, buffer_action, new_obs, reward, done, infos)
  File "C:\Users\gilza\anaconda3\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 498, in _store_transition
    replay_buffer.add(
  File "C:\Users\gilza\anaconda3\lib\site-packages\stable_baselines3\common\buffers.py", line 562, in add
    self.actions[self.pos] = np.array(action).copy()
ValueError: could not broadcast input array from shape (256,4) into shape (4,)
```
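The shapes in the error are telling: the replay buffer has a slot of shape `(4,)` for one action (an action dimension of 4 is assumed from the traceback), but it received an array of shape `(256, 4)`, and 256 is SAC's default `batch_size`. This suggests, though it is not proven here, that the action stored by the collector thread was a batch-sized result produced by the concurrently running `train()`. A minimal NumPy reproduction of the same `ValueError`, with hypothetical sizes:

```python
import numpy as np

buffer_size, action_dim, batch_size = 1000, 4, 256

# Replay-buffer-style storage: one row of shape (action_dim,) per transition.
actions = np.zeros((buffer_size, action_dim), dtype=np.float32)

single_action = np.ones(action_dim)                   # shape (4,): what add() expects
batch_of_actions = np.ones((batch_size, action_dim))  # shape (256, 4): a training batch

actions[0] = single_action  # fits the (4,) slot

message = ""
try:
    actions[1] = batch_of_actions  # a (256, 4) array cannot fit into a (4,) slot
except ValueError as err:
    message = str(err)

print(message)
```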

Checklist

- [x] I have read the documentation (required)
- [x] I have checked that there is no similar issue in the repo (required)
