DLR-RM/stable-baselines3

Implement sampling and training asynchronously using the SAC algorithm

Open

#715 opened on Jan 2, 2022

Labels: help wanted, question

Description

Question

I'm trying to run sampling and training asynchronously with the SAC algorithm; my attempt is shown in the code below. It always fails with an error, apparently because the training and evaluation modes get mixed up: the training mode (True or False) is set on the policy, and the policy is shared between the train and collect_rollouts methods. Is it possible to run collect_rollouts asynchronously?
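To illustrate the concern about the shared mode flag, here is a minimal sketch of one possible mitigation: guarding every use of the policy with a single threading.Lock so the two threads cannot flip the mode under each other. ToyPolicy, policy_lock, collect and train are hypothetical stand-ins written for this example, not stable-baselines3 code, and the sketch only shows the locking pattern. Note that holding one lock around both phases serializes them, so it trades away most of the concurrency.

```python
import threading


class ToyPolicy:
    """Hypothetical stand-in for a policy with one shared training-mode flag."""

    def __init__(self):
        self.training = False

    def set_training_mode(self, mode):
        self.training = mode


policy = ToyPolicy()
policy_lock = threading.Lock()  # guards every read/write of the shared policy
violations = []  # records any step that saw the wrong mode


def collect(n_steps):
    # Rollout side: expects evaluation mode for the whole step.
    for _ in range(n_steps):
        with policy_lock:
            policy.set_training_mode(False)
            if policy.training:
                violations.append("collect")


def train(n_steps):
    # Training side: expects training mode for the whole step.
    for _ in range(n_steps):
        with policy_lock:
            policy.set_training_mode(True)
            if not policy.training:
                violations.append("train")


threads = [
    threading.Thread(target=collect, args=(10000,)),
    threading.Thread(target=train, args=(10000,)),
]
for t in threads:
    t.start()
for t in threads:
    t.join()

print("mode violations:", len(violations))
```

In a real setup the bodies of the two `with` blocks would be the rollout step and the gradient step; whether that is enough to make the library calls safe across threads is not something this sketch establishes.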

Additional context

Reference code:

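import queue
import threading
import time

# Queue shared by the rollout thread (producer) and the training thread (consumer)
_rollouts_ = queue.Queue()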

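def train_async(sac, max_steps):
	# Consumer: runs gradient steps each time a rollout arrives (max_steps is currently unused)
	global _rollouts_
	iteration = 0
	while True:
		if not _rollouts_.empty():
			rollout = _rollouts_.get()
			if rollout is not None:
				iteration += 1
				# A negative gradient_steps means: as many steps as were collected in the rollout
				gradient_steps = sac.gradient_steps if sac.gradient_steps >= 0 else rollout.episode_timesteps
				if gradient_steps > 0:
					sac.train(gradient_steps, sac.batch_size)
			else:
				# None is the sentinel sent by rollouts_async when collection stops
				print("Training ending")
				break
		else:
			print("waiting for rollouts....")
			time.sleep(1)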

def rollouts_async(sac, max_steps, callback, log_interval=None):
	# Producer: collects rollouts and hands each result to the training thread
	steps = 0  # counts collect_rollouts calls, not environment timesteps
	global _rollouts_
	while True:
		rollout = sac.collect_rollouts(sac.env, callback, sac.train_freq, sac.replay_buffer, sac.action_noise, sac.learning_starts, log_interval)
		if rollout.continue_training is False:
			# Send the None sentinel so train_async stops as well
			_rollouts_.put(None)
			break
		else:
			_rollouts_.put(rollout)
			steps += 1
			if steps >= max_steps:
				_rollouts_.put(None)
				callback.on_training_end()
				break

def learn_async(sac, total_timesteps=1000000, callback=None, log_interval=None, tb_log_name="run", reset_num_timesteps=True):
	total_timesteps, callback = sac._setup_learn(total_timesteps, None, callback, 0, 0, None, reset_num_timesteps, tb_log_name)
	callback.on_training_start(locals(), globals())
	# Run collection and training in two threads that share the same sac object
	t1 = threading.Thread(target=rollouts_async, args=(sac, total_timesteps, callback, log_interval))
	t1.start()
	t2 = threading.Thread(target=train_async, args=(sac, total_timesteps))
	t2.start()
	t1.join()
	t2.join()
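The hand-off between the two threads above polls the queue and sleeps when it is empty. As a side note, the same producer/consumer pattern with a None sentinel can be written with a blocking `Queue.get()`, which removes the polling loop. The sketch below uses only the standard library; `producer`, `consumer` and the integer items are hypothetical placeholders for the rollout and training steps, so it shows the queue pattern only and says nothing about thread safety of the SAC object itself.

```python
import queue
import threading

SENTINEL = None  # same convention as the reference code: None means "stop"


def producer(q, n_items):
    # Stand-in for the rollout thread: push work items, then the sentinel.
    for i in range(n_items):
        q.put(i)
    q.put(SENTINEL)


def consumer(q, results):
    # Stand-in for the training thread.
    while True:
        item = q.get()  # blocks until an item is available, no sleep loop needed
        if item is SENTINEL:
            break
        results.append(item * 2)  # placeholder for the work done per item


q = queue.Queue()
results = []
threads = [
    threading.Thread(target=producer, args=(q, 5)),
    threading.Thread(target=consumer, args=(q, results)),
]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(results)
```

Because a single consumer reads a FIFO queue, items are processed in the order they were produced.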

Error:

Traceback (most recent call last):
  File "C:\Users\gilza\anaconda3\lib\threading.py", line 973, in _bootstrap_inner
    self.run()
  File "C:\Users\gilza\anaconda3\lib\threading.py", line 910, in run
    self._target(*self._args, **self._kwargs)
  File "C:\Users\gilza\doc\lab\nav\NavProAI4U\scripts\sb3sacutils.py", line 134, in rollouts_async
    rollout = sac.collect_rollouts(sac.env, callback, sac.train_freq, sac.replay_buffer, sac.action_noise, sac.learning_starts, log_interval)
  File "C:\Users\gilza\anaconda3\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 589, in collect_rollouts
    self._store_transition(replay_buffer, buffer_action, new_obs, reward, done, infos)
  File "C:\Users\gilza\anaconda3\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 498, in _store_transition
    replay_buffer.add(
  File "C:\Users\gilza\anaconda3\lib\site-packages\stable_baselines3\common\buffers.py", line 562, in add
    self.actions[self.pos] = np.array(action).copy()
ValueError: could not broadcast input array from shape (256,4) into shape (4,)
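For readers unfamiliar with this ValueError, the sketch below reproduces the same kind of message with plain NumPy: a batch of 256 actions is assigned to a buffer slot that holds a single 4-dimensional action. The array names and the buffer size of 10 are hypothetical. The 256 matches SAC's default batch size, which may hint that an action computed for a training batch reached the rollout path, but that reading is an assumption and is not confirmed by the traceback.

```python
import numpy as np

# Hypothetical buffer: 10 slots, each storing one 4-dimensional action
buffer_actions = np.zeros((10, 4), dtype=np.float32)

single_action = np.ones((4,), dtype=np.float32)
batch_actions = np.ones((256, 4), dtype=np.float32)

# Storing one action in one slot works
buffer_actions[0] = single_action

# Storing a whole batch in one slot cannot be broadcast and raises ValueError
try:
    buffer_actions[1] = batch_actions
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(error_message)
```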

Checklist

  • I have read the documentation (required) OK
  • I have checked that there is no similar issue in the repo (required) OK
