From 46f779c1a01e58e25f606bb2f29311592311fc73 Mon Sep 17 00:00:00 2001 From: Gaiejj Date: Fri, 3 May 2024 02:37:30 +0800 Subject: [PATCH] chore: fix code style --- omnisafe/adapter/offpolicy_latent_adapter.py | 272 +++++++++++++++ omnisafe/algorithms/__init__.py | 1 + omnisafe/algorithms/off_policy/__init__.py | 14 +- omnisafe/algorithms/off_policy/ddpg.py | 4 +- omnisafe/algorithms/off_policy/safe_slac.py | 293 ++++++++++++++++ omnisafe/common/buffer/__init__.py | 3 +- omnisafe/common/buffer/base.py | 51 +++ omnisafe/common/buffer/offpolicy_buffer.py | 79 ++++- omnisafe/common/latent.py | 344 +++++++++++++++++++ omnisafe/configs/off-policy/SafeSLAC.yaml | 148 ++++++++ omnisafe/envs/__init__.py | 1 + omnisafe/envs/safety_gymnasium_vision_env.py | 194 +++++++++++ omnisafe/utils/model.py | 45 ++- omnisafe/utils/tools.py | 82 ++++- 14 files changed, 1521 insertions(+), 10 deletions(-) create mode 100644 omnisafe/adapter/offpolicy_latent_adapter.py create mode 100644 omnisafe/algorithms/off_policy/safe_slac.py create mode 100644 omnisafe/common/latent.py create mode 100644 omnisafe/configs/off-policy/SafeSLAC.yaml create mode 100644 omnisafe/envs/safety_gymnasium_vision_env.py diff --git a/omnisafe/adapter/offpolicy_latent_adapter.py b/omnisafe/adapter/offpolicy_latent_adapter.py new file mode 100644 index 000000000..105ee8ad1 --- /dev/null +++ b/omnisafe/adapter/offpolicy_latent_adapter.py @@ -0,0 +1,272 @@ +# Copyright 2023 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""OffPolicy Latent Adapter for OmniSafe.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import torch +from gymnasium.spaces import Box + +from omnisafe.adapter.online_adapter import OnlineAdapter +from omnisafe.common.buffer import OffPolicySequenceBuffer +from omnisafe.common.latent import CostLatentModel +from omnisafe.common.logger import Logger +from omnisafe.envs.wrapper import ( + ActionRepeat, + ActionScale, + AutoReset, + CostNormalize, + ObsNormalize, + RewardNormalize, + TimeLimit, + Unsqueeze, +) +from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic +from omnisafe.utils.config import Config +from omnisafe.utils.model import ObservationConcator + + +class OffPolicyLatentAdapter(OnlineAdapter): + _current_obs: torch.Tensor + _ep_ret: torch.Tensor + _ep_cost: torch.Tensor + _ep_len: torch.Tensor + + def __init__( # pylint: disable=too-many-arguments + self, + env_id: str, + num_envs: int, + seed: int, + cfgs: Config, + ) -> None: + """Initialize a instance of :class:`OffPolicyAdapter`.""" + super().__init__(env_id, num_envs, seed, cfgs) + self._observation_concator: ObservationConcator = ObservationConcator( + self._cfgs.algo_cfgs.latent_dim_1 + self._cfgs.algo_cfgs.latent_dim_2, + self.action_space.shape, + self._cfgs.algo_cfgs.num_sequences, + device=self._device, + ) + self._current_obs, _ = self.reset() + self._max_ep_len: int = 1000 + self._reset_log() + self.z1 = None + self.z2 = None + self._reset_sequence_queue = False + + def _wrapper( + self, + obs_normalize: bool = True, + reward_normalize: bool = True, + cost_normalize: bool = True, + ) -> None: + """Wrapper the environment. + + .. hint:: + OmniSafe supports the following wrappers: + + +-----------------+--------------------------------------------------------+ + | Wrapper | Description | + +=================+========================================================+ + | TimeLimit | Limit the time steps of the environment. | + +-----------------+--------------------------------------------------------+ + | AutoReset | Reset the environment when the episode is done. | + +-----------------+--------------------------------------------------------+ + | ObsNormalize | Normalize the observation. | + +-----------------+--------------------------------------------------------+ + | RewardNormalize | Normalize the reward. | + +-----------------+--------------------------------------------------------+ + | CostNormalize | Normalize the cost. | + +-----------------+--------------------------------------------------------+ + | ActionScale | Scale the action. | + +-----------------+--------------------------------------------------------+ + | Unsqueeze | Unsqueeze the step result for single environment case. | + +-----------------+--------------------------------------------------------+ + + + Args: + obs_normalize (bool, optional): Whether to normalize the observation. Defaults to True. + reward_normalize (bool, optional): Whether to normalize the reward. Defaults to True. + cost_normalize (bool, optional): Whether to normalize the cost. Defaults to True. 
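+
+        .. note::
+            In addition to the wrappers listed above, this adapter applies
+            :class:`ActionRepeat` with ``times=2``, so each policy action is
+            repeated twice in the underlying vision environment.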
+ """ + if self._env.need_time_limit_wrapper: + self._env = TimeLimit(self._env, device=self._device, time_limit=1000) + if self._env.need_auto_reset_wrapper: + self._env = AutoReset(self._env, device=self._device) + if obs_normalize: + self._env = ObsNormalize(self._env, device=self._device) + if reward_normalize: + self._env = RewardNormalize(self._env, device=self._device) + if cost_normalize: + self._env = CostNormalize(self._env, device=self._device) + self._env = ActionScale(self._env, device=self._device, low=-1.0, high=1.0) + self._env = ActionRepeat(self._env, times=2, device=self._device) + + if self._env.num_envs == 1: + self._env = Unsqueeze(self._env, device=self._device) + + @property + def latent_space(self) -> Box: + """Get the latent space.""" + return Box( + low=-np.inf, + high=np.inf, + shape=(self._cfgs.algo_cfgs.latent_dim_1 + self._cfgs.algo_cfgs.latent_dim_2,), + ) + + def eval_policy( # pylint: disable=too-many-locals + self, + episode: int, + agent: ConstraintActorQCritic, + logger: Logger, + ) -> None: + for _ in range(episode): + ep_ret, ep_cost, ep_len = 0.0, 0.0, 0 + obs, _ = self._eval_env.reset() + obs = obs.to(self._device) + + done = False + while not done: + act = agent.step(obs, deterministic=True) + obs, reward, cost, terminated, truncated, info = self._eval_env.step(act) + obs, reward, cost, terminated, truncated = ( + torch.as_tensor(x, dtype=torch.float32, device=self._device) + for x in (obs, reward, cost, terminated, truncated) + ) + ep_ret += info.get('original_reward', reward).cpu() + ep_cost += info.get('original_cost', cost).cpu() + ep_len += 1 + done = bool(terminated[0].item()) or bool(truncated[0].item()) + + logger.store( + { + 'Metrics/TestEpRet': ep_ret, + 'Metrics/TestEpCost': ep_cost, + 'Metrics/TestEpLen': ep_len, + }, + ) + + def pre_process(self, latent_model, concated_obs): + with torch.no_grad(): + feature = latent_model.encoder(concated_obs.last_state) + + if self.z2 is None: + z1_mean, z1_std = latent_model.z1_posterior_init(feature) + self.z1 = z1_mean + torch.randn_like(z1_std) * z1_std + z2_mean, z2_std = latent_model.z2_posterior_init(self.z1) + self.z2 = z2_mean + torch.randn_like(z2_std) * z2_std + else: + z1_mean, z1_std = latent_model.z1_posterior( + torch.cat([feature.squeeze(), self.z2.squeeze(), concated_obs.last_action], dim=-1) + ) + self.z1 = z1_mean + torch.randn_like(z1_std) * z1_std + z2_mean, z2_std = latent_model.z2_posterior( + torch.cat([self.z1.squeeze(), self.z2.squeeze(), concated_obs.last_action], dim=-1) + ) + self.z2 = z2_mean + torch.randn_like(z2_std) * z2_std + + return torch.cat([self.z1, self.z2], dim=-1).squeeze() + + def rollout( # pylint: disable=too-many-locals + self, + rollout_step: int, + agent: ConstraintActorQCritic, + latent_model: CostLatentModel, + buffer: OffPolicySequenceBuffer, + logger: Logger, + use_rand_action: bool, + ) -> None: + for step in range(rollout_step): + if not self._reset_sequence_queue: + buffer.reset_sequence_queue(self._current_obs) + self._observation_concator.reset_episode(self._current_obs) + self._reset_sequence_queue = True + + if use_rand_action: + act = act = (torch.rand(self.action_space.shape) * 2 - 1).to(self._device) # type: ignore + else: + act = agent.step( + self.pre_process(latent_model, self._observation_concator), deterministic=False + ) + + next_obs, reward, cost, terminated, truncated, info = self.step(act) + step += info.get('num_step', 1) - 1 + + real_next_obs = next_obs.clone() + + self._observation_concator.append(next_obs, act) + + 
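+            # Accumulate per-episode return, cost and length, then check each
+            # sub-environment for termination so the latent state can be reset.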
self._log_value(reward=reward, cost=cost, info=info) + + for idx, done in enumerate(torch.logical_or(terminated, truncated)): + if done: + self._log_metrics(logger, idx) + self._reset_log(idx) + self.z1 = None + self.z2 = None + self._reset_sequence_queue = False + if 'final_observation' in info: + real_next_obs[idx] = info['final_observation'][idx] + + buffer.store( + obs=real_next_obs, + act=act, + reward=reward, + cost=cost, + done=torch.logical_and(terminated, torch.logical_xor(terminated, truncated)), + ) + + self._current_obs = next_obs + + def _log_value( + self, + reward: torch.Tensor, + cost: torch.Tensor, + info: dict[str, Any], + ) -> None: + self._ep_ret += info.get('original_reward', reward).cpu() + self._ep_cost += info.get('original_cost', cost).cpu() + self._ep_len += info.get('num_step', 1) + + def _log_metrics(self, logger: Logger, idx: int) -> None: + logger.store( + { + 'Metrics/EpRet': self._ep_ret[idx], + 'Metrics/EpCost': self._ep_cost[idx], + 'Metrics/EpLen': self._ep_len[idx], + }, + ) + + def _reset_log(self, idx: int | None = None) -> None: + if idx is None: + self._ep_ret = torch.zeros(self._env.num_envs) + self._ep_cost = torch.zeros(self._env.num_envs) + self._ep_len = torch.zeros(self._env.num_envs) + else: + self._ep_ret[idx] = 0.0 + self._ep_cost[idx] = 0.0 + self._ep_len[idx] = 0.0 + + def reset( + self, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, dict[str, Any]]: + obs, info = self._env.reset(seed=seed, options=options) + self._observation_concator.reset_episode(obs) + return obs, info diff --git a/omnisafe/algorithms/__init__.py b/omnisafe/algorithms/__init__.py index 39cb686ed..7965148e2 100644 --- a/omnisafe/algorithms/__init__.py +++ b/omnisafe/algorithms/__init__.py @@ -33,6 +33,7 @@ TD3PID, DDPGLag, SACLag, + SafeSLAC, TD3Lag, ) diff --git a/omnisafe/algorithms/off_policy/__init__.py b/omnisafe/algorithms/off_policy/__init__.py index c6e2c9435..363ffe984 100644 --- a/omnisafe/algorithms/off_policy/__init__.py +++ b/omnisafe/algorithms/off_policy/__init__.py @@ -20,9 +20,21 @@ from omnisafe.algorithms.off_policy.sac import SAC from omnisafe.algorithms.off_policy.sac_lag import SACLag from omnisafe.algorithms.off_policy.sac_pid import SACPID +from omnisafe.algorithms.off_policy.safe_slac import SafeSLAC from omnisafe.algorithms.off_policy.td3 import TD3 from omnisafe.algorithms.off_policy.td3_lag import TD3Lag from omnisafe.algorithms.off_policy.td3_pid import TD3PID -__all__ = ['DDPG', 'TD3', 'SAC', 'DDPGLag', 'TD3Lag', 'SACLag', 'DDPGPID', 'TD3PID', 'SACPID'] +__all__ = [ + 'DDPG', + 'TD3', + 'SAC', + 'DDPGLag', + 'TD3Lag', + 'SACLag', + 'DDPGPID', + 'TD3PID', + 'SACPID', + 'SafeSLAC', +] diff --git a/omnisafe/algorithms/off_policy/ddpg.py b/omnisafe/algorithms/off_policy/ddpg.py index e7199c32c..226e352c4 100644 --- a/omnisafe/algorithms/off_policy/ddpg.py +++ b/omnisafe/algorithms/off_policy/ddpg.py @@ -257,7 +257,7 @@ def learn(self) -> tuple[float, float, float]: for sample_step in range( epoch * self._samples_per_epoch, - (epoch + 1) * self._samples_per_epoch, + (epoch + 1) * self._samples_per_epoch + 1, ): step = sample_step * self._update_cycle * self._cfgs.train_cfgs.vector_env_nums @@ -305,7 +305,7 @@ def learn(self) -> tuple[float, float, float]: self._logger.store( { - 'TotalEnvSteps': step + 1, + 'TotalEnvSteps': step, 'Time/FPS': self._cfgs.algo_cfgs.steps_per_epoch / (time.time() - epoch_time), 'Time/Total': (time.time() - start_time), 'Time/Epoch': (time.time() - epoch_time), diff 
--git a/omnisafe/algorithms/off_policy/safe_slac.py b/omnisafe/algorithms/off_policy/safe_slac.py new file mode 100644 index 000000000..61b9522f2 --- /dev/null +++ b/omnisafe/algorithms/off_policy/safe_slac.py @@ -0,0 +1,293 @@ +# Copyright 2023 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implementation of the Safe Stochastic Latent Actor-Critic algorithm.""" + + +from __future__ import annotations + +import time + +import torch +from rich.progress import track +from torch import optim +from torch.nn.utils.clip_grad import clip_grad_norm_ + +from omnisafe.adapter.offpolicy_latent_adapter import OffPolicyLatentAdapter +from omnisafe.algorithms import registry +from omnisafe.algorithms.off_policy.sac_lag import SACLag +from omnisafe.common.buffer import OffPolicySequenceBuffer +from omnisafe.common.lagrange import Lagrange +from omnisafe.common.latent import CostLatentModel +from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic + + +@registry.register +# pylint: disable-next=too-many-instance-attributes, too-few-public-methods +class SafeSLAC(SACLag): + def _init(self) -> None: + if self._cfgs.algo_cfgs.auto_alpha: + self._target_entropy = -torch.prod(torch.Tensor(self._env.action_space.shape)).item() + self._log_alpha = torch.zeros(1, requires_grad=True, device=self._device) + + assert self._cfgs.model_cfgs.critic.lr is not None + self._alpha_optimizer = optim.Adam( + [self._log_alpha], + lr=self._cfgs.model_cfgs.critic.lr, + ) + else: + self._log_alpha = torch.log( + torch.tensor(self._cfgs.algo_cfgs.alpha, device=self._device), + ) + + self._lagrange: Lagrange = Lagrange(**self._cfgs.lagrange_cfgs) + + self._buf: OffPolicySequenceBuffer = OffPolicySequenceBuffer( + obs_space=self._env.observation_space, + act_space=self._env.action_space, + size=self._cfgs.algo_cfgs.size, + batch_size=self._cfgs.algo_cfgs.batch_size, + device=self._device, + num_sequences=self._cfgs.algo_cfgs.num_sequences, + ) + self._is_latent_model_init_learned = False + + def _init_env(self) -> None: + self._env: OffPolicyLatentAdapter = OffPolicyLatentAdapter( + self._env_id, + self._cfgs.train_cfgs.vector_env_nums, + self._seed, + self._cfgs, + ) + assert ( + self._cfgs.algo_cfgs.steps_per_epoch % self._cfgs.train_cfgs.vector_env_nums == 0 + ), 'The number of steps per epoch is not divisible by the number of environments.' + + assert ( + int(self._cfgs.train_cfgs.total_steps) % self._cfgs.algo_cfgs.steps_per_epoch == 0 + ), 'The total number of steps is not divisible by the number of steps per epoch.' 
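+        # Derive the epoch and update-cycle schedule from the validated configuration.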
+ self._epochs: int = int( + self._cfgs.train_cfgs.total_steps // self._cfgs.algo_cfgs.steps_per_epoch, + ) + self._epoch: int = 0 + self._steps_per_epoch: int = ( + self._cfgs.algo_cfgs.steps_per_epoch // self._cfgs.train_cfgs.vector_env_nums + ) + + self._update_cycle: int = self._cfgs.algo_cfgs.update_cycle + assert ( + self._steps_per_epoch % self._update_cycle == 0 + ), 'The number of steps per epoch is not divisible by the number of steps per sample.' + self._samples_per_epoch: int = self._steps_per_epoch // self._update_cycle + self._update_count: int = 0 + self._update_latent_count = 0 + + def _init_model(self) -> None: + self._cfgs.model_cfgs.critic['num_critics'] = 2 + + self._latent_model = CostLatentModel( + obs_shape=self._env.observation_space.shape, + act_shape=self._env.action_space.shape, + feature_dim=self._cfgs.algo_cfgs.feature_dim, + latent_dim_1=self._cfgs.algo_cfgs.latent_dim_1, + latent_dim_2=self._cfgs.algo_cfgs.latent_dim_2, + hidden_sizes=self._cfgs.algo_cfgs.hidden_sizes, + image_noise=self._cfgs.algo_cfgs.image_noise, + ).to(self._device) + self._update_latent_count = 0 + + self._actor_critic: ConstraintActorQCritic = ConstraintActorQCritic( + obs_space=self._env.latent_space, + act_space=self._env.action_space, + model_cfgs=self._cfgs.model_cfgs, + epochs=self._epochs, + ).to(self._device) + + self._actor_critic = torch.compile(self._actor_critic) + self._latent_model = torch.compile(self._latent_model) + + self._latent_model_optimizer = optim.Adam( + self._latent_model.parameters(), + lr=1e-4, + ) + + def learn(self) -> tuple[float, float, float]: + """This is main function for algorithm update. + + It is divided into the following steps: + + - :meth:`rollout`: collect interactive data from environment. + - :meth:`update`: perform actor/critic updates. + - :meth:`log`: epoch/update information for visualization and terminal log print. + + Returns: + ep_ret: average episode return in final epoch. + ep_cost: average episode cost in final epoch. + ep_len: average episode length in final epoch. 
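+
+        .. note::
+            The first call to :meth:`_update` additionally pre-trains the latent model
+            for ``latent_model_init_learning_steps`` gradient steps before any
+            actor/critic update is performed.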
+ """ + self._logger.log('INFO: Start training') + start_time = time.time() + step = 0 + for epoch in range(self._epochs): + self._epoch = epoch + rollout_time = 0.0 + update_time = 0.0 + epoch_time = time.time() + + for sample_step in range( + epoch * self._samples_per_epoch, + (epoch + 1) * self._samples_per_epoch + 1, + ): + step = sample_step * self._update_cycle * self._cfgs.train_cfgs.vector_env_nums + + rollout_start = time.time() + # set noise for exploration + if self._cfgs.algo_cfgs.use_exploration_noise: + self._actor_critic.actor.noise = self._cfgs.algo_cfgs.exploration_noise + + # collect data from environment + self._env.rollout( + rollout_step=self._update_cycle, + agent=self._actor_critic, + buffer=self._buf, + logger=self._logger, + latent_model=self._latent_model, + use_rand_action=(step <= self._cfgs.algo_cfgs.start_learning_steps), + ) + rollout_time += time.time() - rollout_start + + # update parameters + update_start = time.time() + if step > self._cfgs.algo_cfgs.start_learning_steps: + self._update() + # if we haven't updated the network, log 0 for the loss + else: + self._log_when_not_update() + update_time += time.time() - update_start + + eval_start = time.time() + self._env.eval_policy( + episode=self._cfgs.train_cfgs.eval_episodes, + agent=self._actor_critic, + logger=self._logger, + ) + eval_time = time.time() - eval_start + + self._logger.store({'Time/Update': update_time}) + self._logger.store({'Time/Rollout': rollout_time}) + self._logger.store({'Time/Evaluate': eval_time}) + + if ( + step > self._cfgs.algo_cfgs.start_learning_steps + and self._cfgs.model_cfgs.linear_lr_decay + ): + self._actor_critic.actor_scheduler.step() + + self._logger.store( + { + 'TotalEnvSteps': step, + 'Time/FPS': self._cfgs.algo_cfgs.steps_per_epoch / (time.time() - epoch_time), + 'Time/Total': (time.time() - start_time), + 'Time/Epoch': (time.time() - epoch_time), + 'Train/Epoch': epoch, + 'Train/LR': self._actor_critic.actor_scheduler.get_last_lr()[0], + }, + ) + + self._logger.dump_tabular() + + # save model to disk + if (epoch + 1) % self._cfgs.logger_cfgs.save_model_freq == 0: + self._logger.torch_save() + + ep_ret = self._logger.get_stats('Metrics/EpRet')[0] + ep_cost = self._logger.get_stats('Metrics/EpCost')[0] + ep_len = self._logger.get_stats('Metrics/EpLen')[0] + self._logger.close() + + return ep_ret, ep_cost, ep_len + + def _prepare_batch(self, obs_, action_): + with torch.no_grad(): + feature_ = self._latent_model.encoder(obs_) + z_ = torch.cat(self._latent_model.sample_posterior(feature_, action_)[2:4], dim=-1) + + z, next_z = z_[:, -2], z_[:, -1] + action = action_[:, -1] + + return z, next_z, action + + def _update(self) -> None: + if not self._is_latent_model_init_learned: + for _ in track( + range(self._cfgs.algo_cfgs.latent_model_init_learning_steps), + description='initial updating of latent model...', + ): + self._update_latent_model() + self._is_latent_model_init_learned = True + + Jc = self._logger.get_stats('Metrics/EpCost')[0] + if self._epoch > self._cfgs.algo_cfgs.warmup_epochs: + self._lagrange.update_lagrange_multiplier(Jc) + self._logger.store( + { + 'Metrics/LagrangeMultiplier': self._lagrange.lagrangian_multiplier.data.item(), + }, + ) + + for _ in range(self._cfgs.algo_cfgs.update_iters): + self._update_latent_model() + + data = self._buf.sample_batch(64) + self._update_count += 1 + obs_, act_, reward, cost, done = ( + data['obs'], + data['act'], + data['reward'][:, -1].squeeze(), + data['cost'][:, -1].squeeze(), + data['done'][:, -1].squeeze(), + 
) + obs, next_obs, act = self._prepare_batch(obs_, act_) + self._update_reward_critic(obs, act, reward, done, next_obs) + self._update_cost_critic(obs, act, cost, done, next_obs) + + if self._update_count % self._cfgs.algo_cfgs.policy_delay == 0: + self._update_actor(obs) + self._actor_critic.polyak_update(self._cfgs.algo_cfgs.polyak) + + def _update_latent_model( + self, + ): + data = self._buf.sample_batch(32) + obs, act, reward, cost, done = ( + data['obs'], + data['act'], + data['reward'], + data['cost'], + data['done'], + ) + + self._update_latent_count += 1 + loss_kld, loss_image, loss_reward, loss_cost = self._latent_model.calculate_loss( + obs, act, reward, done, cost + ) + + self._latent_model_optimizer.zero_grad() + (loss_kld + loss_image + loss_reward + loss_cost).backward() + if self._cfgs.algo_cfgs.max_grad_norm: + clip_grad_norm_( + self._latent_model.parameters(), + self._cfgs.algo_cfgs.max_grad_norm, + ) + self._latent_model_optimizer.step() diff --git a/omnisafe/common/buffer/__init__.py b/omnisafe/common/buffer/__init__.py index 669770849..4e47a06db 100644 --- a/omnisafe/common/buffer/__init__.py +++ b/omnisafe/common/buffer/__init__.py @@ -15,7 +15,7 @@ """Implementation of Buffer.""" from omnisafe.common.buffer.base import BaseBuffer -from omnisafe.common.buffer.offpolicy_buffer import OffPolicyBuffer +from omnisafe.common.buffer.offpolicy_buffer import OffPolicyBuffer, OffPolicySequenceBuffer from omnisafe.common.buffer.onpolicy_buffer import OnPolicyBuffer from omnisafe.common.buffer.vector_offpolicy_buffer import VectorOffPolicyBuffer from omnisafe.common.buffer.vector_onpolicy_buffer import VectorOnPolicyBuffer @@ -24,6 +24,7 @@ __all__ = [ 'BaseBuffer', 'OffPolicyBuffer', + 'OffPolicySequenceBuffer', 'OnPolicyBuffer', 'VectorOffPolicyBuffer', 'VectorOnPolicyBuffer', diff --git a/omnisafe/common/buffer/base.py b/omnisafe/common/buffer/base.py index 08864ecb0..0e521eacf 100644 --- a/omnisafe/common/buffer/base.py +++ b/omnisafe/common/buffer/base.py @@ -22,6 +22,7 @@ from gymnasium.spaces import Box from omnisafe.typing import DEVICE_CPU, OmnisafeSpace +from omnisafe.utils.tools import SequenceQueue class BaseBuffer(ABC): @@ -132,3 +133,53 @@ def store(self, **data: torch.Tensor) -> None: Args: data (torch.Tensor): The data to store. 
""" + + +class BaseSequenceBuffer(BaseBuffer): + def __init__( + self, + obs_space: OmnisafeSpace, + act_space: OmnisafeSpace, + size: int, + num_sequences: int, + device: torch.device = DEVICE_CPU, + ) -> None: + """Initialize an instance of :class:`BaseBuffer`.""" + self._device: torch.device = device + self._num_sequences = num_sequences + if isinstance(obs_space, Box): + obs_buf = [None] * size + else: + raise NotImplementedError + if isinstance(act_space, Box): + act_buf = torch.zeros( + (size, num_sequences, *act_space.shape), + dtype=torch.float32, + device=device, + ) + else: + raise NotImplementedError + + self.data: dict[str, torch.Tensor | list] = { + 'obs': obs_buf, + 'act': act_buf, + 'reward': torch.zeros(size, num_sequences, 1, dtype=torch.float32, device=device), + 'cost': torch.zeros(size, num_sequences, 1, dtype=torch.float32, device=device), + 'done': torch.zeros(size, num_sequences, 1, dtype=torch.float32, device=device), + } + + self.sequence_queue = SequenceQueue( + obs_space=obs_space, + num_sequences=num_sequences, + device=device, + ) + + self._size: int = size + self._observation_shape = obs_space.shape + + def add_field(self, name: str, shape: tuple[int, ...], dtype: torch.dtype) -> None: + self.data[name] = torch.zeros( + (self._size, self._num_sequences, *shape), + dtype=dtype, + device=self._device, + ) diff --git a/omnisafe/common/buffer/offpolicy_buffer.py b/omnisafe/common/buffer/offpolicy_buffer.py index 76516468f..9cd62dbbf 100644 --- a/omnisafe/common/buffer/offpolicy_buffer.py +++ b/omnisafe/common/buffer/offpolicy_buffer.py @@ -16,10 +16,11 @@ from __future__ import annotations +import numpy as np import torch from gymnasium.spaces import Box -from omnisafe.common.buffer.base import BaseBuffer +from omnisafe.common.buffer.base import BaseBuffer, BaseSequenceBuffer from omnisafe.typing import DEVICE_CPU, OmnisafeSpace @@ -115,3 +116,79 @@ def sample_batch(self) -> dict[str, torch.Tensor]: """ idxs = torch.randint(0, self._size, (self._batch_size,)) return {key: value[idxs] for key, value in self.data.items()} + + +class OffPolicySequenceBuffer(BaseSequenceBuffer): + def __init__( # pylint: disable=too-many-arguments + self, + obs_space: OmnisafeSpace, + act_space: OmnisafeSpace, + size: int, + batch_size: int, + num_sequences: int, + device: torch.device = DEVICE_CPU, + ) -> None: + """Initialize an instance of :class:`OffPolicySequenceBuffer`.""" + super().__init__(obs_space, act_space, size, num_sequences, device) + + self._ptr: int = 0 + self._size: int = 0 + self._max_size: int = size + self._batch_size: int = batch_size + + assert ( + self._max_size > self._batch_size + ), 'The size of the buffer must be larger than the batch size.' + + @property + def max_size(self) -> int: + """Return the max size of the buffer.""" + return self._max_size + + @property + def size(self) -> int: + """Return the current size of the buffer.""" + return self._size + + @property + def batch_size(self) -> int: + """Return the batch size of the buffer.""" + return self._batch_size + + def store(self, **data: torch.Tensor) -> None: + """Store data into the buffer. + + .. hint:: + The ReplayBuffer is a circular buffer. When the buffer is full, the oldest data will be + overwritten. + + Args: + data (torch.Tensor): The data to be stored. 
+ """ + self.sequence_queue.append(**data) + if self.sequence_queue.is_full(): + sequece_data = self.sequence_queue.get() + for key, value in sequece_data.items(): + self.data[key][self._ptr] = value + self._ptr = (self._ptr + 1) % self._max_size + self._size = min(self._size + 1, self._max_size) + + def sample_batch(self, batch_size: int | None) -> dict[str, torch.Tensor]: + """Sample a batch of data from the buffer. + + Returns: + The sampled batch of data. + """ + batch_size = batch_size or self._batch_size + idxs = torch.randint(0, self._size, (batch_size,)) + returns = {key: value[idxs] for key, value in self.data.items() if key != 'obs'} + obs = np.empty((batch_size, self._num_sequences + 1, *self._observation_shape)) + for i, idx in enumerate(idxs): + obs[i, ...] = self.data['obs'][idx] + obs = torch.tensor(obs, dtype=torch.float32, device=self._device) + returns.update({'obs': obs}) + return returns + + def reset_sequence_queue(self, obs: torch.Tensor) -> None: + """Reset the sequence queue.""" + self.sequence_queue.reset_sequence_queue(obs) diff --git a/omnisafe/common/latent.py b/omnisafe/common/latent.py new file mode 100644 index 000000000..18e8ee921 --- /dev/null +++ b/omnisafe/common/latent.py @@ -0,0 +1,344 @@ +import math + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + + +def calculate_kl_divergence(p_mean, p_std, q_mean, q_std): + var_ratio = (p_std / q_std).pow_(2) + t1 = ((p_mean - q_mean) / q_std).pow_(2) + return 0.5 * (var_ratio + t1 - 1 - var_ratio.log()) + + +def build_mlp( + input_dim, + output_dim, + hidden_sizes=None, + hidden_activation=nn.Tanh(), + output_activation=None, +): + if hidden_sizes is None: + hidden_sizes = [64, 64] + layers = [] + units = input_dim + for next_units in hidden_sizes: + layers.append(nn.Linear(units, next_units)) + layers.append(hidden_activation) + units = next_units + model = nn.Sequential(*layers) + model.add_module('last_linear', nn.Linear(units, output_dim)) + if output_activation is not None: + model.add_module('output_activation', output_activation) + return model + + +def initialize_weight(m): + if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + nn.init.xavier_uniform_(m.weight, gain=1.0) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +class FixedGaussian(torch.nn.Module): + """ + Fixed diagonal gaussian distribution. + """ + + def __init__(self, output_dim, std) -> None: + super().__init__() + self.output_dim = output_dim + self.std = std + + def forward(self, x): + mean = torch.zeros(x.size(0), self.output_dim, device=x.device) + std = torch.ones(x.size(0), self.output_dim, device=x.device).mul_(self.std) + return mean, std + + +class Gaussian(torch.nn.Module): + """ + Diagonal gaussian distribution with state dependent variances. + """ + + def __init__(self, input_dim, output_dim, hidden_sizes=(256, 256)) -> None: + super().__init__() + self.net = build_mlp( + input_dim=input_dim, + output_dim=2 * output_dim, + hidden_sizes=hidden_sizes, + hidden_activation=nn.ELU(), + ).apply(initialize_weight) + + def forward(self, x): + if x.ndim == 3: + B, S, _ = x.size() + x = self.net(x.view(B * S, _)).view(B, S, -1) + else: + x = self.net(x) + mean, std = torch.chunk(x, 2, dim=-1) + std = F.softplus(std) + 1e-5 + return mean, std + + +class Bernoulli(torch.nn.Module): + """ + Diagonal gaussian distribution with state dependent variances. 
+ """ + + def __init__(self, input_dim, output_dim, hidden_sizes=(256, 256)) -> None: + super().__init__() + self.net = build_mlp( + input_dim=input_dim, + output_dim=output_dim, + hidden_sizes=hidden_sizes, + hidden_activation=nn.ELU(), + ).apply(initialize_weight) + + def forward(self, x): + if x.ndim == 3: + B, S, _ = x.size() + x = self.net(x.view(B * S, _)).view(B, S, -1) + else: + x = self.net(x) + return torch.sigmoid(x) + + +class Decoder(torch.nn.Module): + """ + Decoder. + """ + + def __init__(self, input_dim=288, output_dim=3, std=1.0) -> None: + super().__init__() + + self.net = nn.Sequential( + # (32+256, 1, 1) -> (256, 4, 4) + nn.ConvTranspose2d(input_dim, 256, 4), + nn.LeakyReLU(inplace=True, negative_slope=0.2), + # (256, 4, 4) -> (128, 8, 8) + nn.ConvTranspose2d(256, 128, 3, 2, 1, 1), + nn.LeakyReLU(inplace=True, negative_slope=0.2), + # (128, 8, 8) -> (64, 16, 16) + nn.ConvTranspose2d(128, 64, 3, 2, 1, 1), + nn.LeakyReLU(inplace=True, negative_slope=0.2), + # (64, 16, 16) -> (32, 32, 32) + nn.ConvTranspose2d(64, 32, 3, 2, 1, 1), + nn.LeakyReLU(inplace=True, negative_slope=0.2), + # (32, 32, 32) -> (3, 64, 64) + nn.ConvTranspose2d(32, output_dim, 5, 2, 2, 1), + nn.LeakyReLU(inplace=True, negative_slope=0.2), + ).apply(initialize_weight) + self.std = std + + def forward(self, x): + B, S, latent_dim = x.size() + x = x.view(B * S, latent_dim, 1, 1) + x = self.net(x) + _, C, W, H = x.size() + x = x.view(B, S, C, W, H) + return x, torch.ones_like(x).mul_(self.std) + + +class Encoder(torch.nn.Module): + """ + Encoder. + """ + + def __init__(self, input_dim=3, output_dim=256) -> None: + super().__init__() + + self.net = nn.Sequential( + # (3, 64, 64) -> (32, 32, 32) + nn.Conv2d(input_dim, 32, 5, 2, 2), + nn.ELU(inplace=True), + # (32, 32, 32) -> (64, 16, 16) + nn.Conv2d(32, 64, 3, 2, 1), + nn.ELU(inplace=True), + # (64, 16, 16) -> (128, 8, 8) + nn.Conv2d(64, 128, 3, 2, 1), + nn.ELU(inplace=True), + # (128, 8, 8) -> (256, 4, 4) + nn.Conv2d(128, 256, 3, 2, 1), + nn.ELU(inplace=True), + # (256, 4, 4) -> (256, 1, 1) + nn.Conv2d(256, output_dim, 4), + nn.ELU(inplace=True), + ).apply(initialize_weight) + + def forward(self, x): + B, S, C, H, W = x.size() + x = x.view(B * S, C, H, W) + x = self.net(x) + return x.view(B, S, -1) + + +class CostLatentModel(torch.nn.Module): + """ + Stochastic latent variable model to estimate latent dynamics, reward and cost. 
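+
+    The model follows the two-level SLAC factorization: a small stochastic latent
+    ``z1`` (``latent_dim_1``) and a larger latent ``z2`` (``latent_dim_2``) with
+    Gaussian priors and posteriors, a Gaussian reward head, a Bernoulli cost head,
+    a convolutional image encoder and a transposed-convolution decoder.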
+ """ + + def __init__( + self, + obs_shape, + act_shape, + feature_dim=256, + latent_dim_1=32, + latent_dim_2=256, + hidden_sizes=(256, 256), + image_noise=0.1, + ) -> None: + super().__init__() + self.bceloss = torch.nn.BCELoss(reduction='none') + # p(z1(0)) = N(0, I) + self.z1_prior_init = FixedGaussian(latent_dim_1, 1.0) + # p(z2(0) | z1(0)) + self.z2_prior_init = Gaussian(latent_dim_1, latent_dim_2, hidden_sizes) + # p(z1(t+1) | z2(t), a(t)) + self.z1_prior = Gaussian( + latent_dim_2 + act_shape[0], + latent_dim_1, + hidden_sizes, + ) + # p(z2(t+1) | z1(t+1), z2(t), a(t)) + self.z2_prior = Gaussian( + latent_dim_1 + latent_dim_2 + act_shape[0], + latent_dim_2, + hidden_sizes, + ) + + # q(z1(0) | feat(0)) + self.z1_posterior_init = Gaussian(feature_dim, latent_dim_1, hidden_sizes) + # q(z2(0) | z1(0)) = p(z2(0) | z1(0)) + self.z2_posterior_init = self.z2_prior_init + # q(z1(t+1) | feat(t+1), z2(t), a(t)) + self.z1_posterior = Gaussian( + feature_dim + latent_dim_2 + act_shape[0], + latent_dim_1, + hidden_sizes, + ) + # q(z2(t+1) | z1(t+1), z2(t), a(t)) = p(z2(t+1) | z1(t+1), z2(t), a(t)) + self.z2_posterior = self.z2_prior + + # p(r(t) | z1(t), z2(t), a(t), z1(t+1), z2(t+1)) + self.reward = Gaussian( + 2 * latent_dim_1 + 2 * latent_dim_2 + act_shape[0], + 1, + hidden_sizes, + ) + + self.cost = Bernoulli( + 2 * latent_dim_1 + 2 * latent_dim_2 + act_shape[0], + 1, + hidden_sizes, + ) + + # feat(t) = Encoder(x(t)) + self.encoder = Encoder(obs_shape[0], feature_dim) + # p(x(t) | z1(t), z2(t)) + self.decoder = Decoder( + latent_dim_1 + latent_dim_2, + obs_shape[0], + std=np.sqrt(image_noise), + ) + self.apply(initialize_weight) + + def sample_prior(self, actions_, z2_post_): + # p(z1(0)) = N(0, I) + z1_mean_init, z1_std_init = self.z1_prior_init(actions_[:, 0]) + # p(z1(t) | z2(t-1), a(t-1)) + z1_mean_, z1_std_ = self.z1_prior( + torch.cat([z2_post_[:, : actions_.size(1)], actions_], dim=-1) + ) + # Concatenate initial and consecutive latent variables + z1_mean_ = torch.cat([z1_mean_init.unsqueeze(1), z1_mean_], dim=1) + z1_std_ = torch.cat([z1_std_init.unsqueeze(1), z1_std_], dim=1) + return (z1_mean_, z1_std_) + + def sample_posterior(self, features_, actions_): + # p(z1(0)) = N(0, I) + z1_mean, z1_std = self.z1_posterior_init(features_[:, 0]) + z1 = z1_mean + torch.randn_like(z1_std) * z1_std + # p(z2(0) | z1(0)) + z2_mean, z2_std = self.z2_posterior_init(z1) + z2 = z2_mean + torch.randn_like(z2_std) * z2_std + + z1_mean_ = [z1_mean] + z1_std_ = [z1_std] + z1_ = [z1] + z2_ = [z2] + + for t in range(1, actions_.size(1) + 1): + # q(z1(t) | feat(t), z2(t-1), a(t-1)) + z1_mean, z1_std = self.z1_posterior( + torch.cat([features_[:, t], z2, actions_[:, t - 1]], dim=1) + ) + z1 = z1_mean + torch.randn_like(z1_std) * z1_std + # q(z2(t) | z1(t), z2(t-1), a(t-1)) + z2_mean, z2_std = self.z2_posterior(torch.cat([z1, z2, actions_[:, t - 1]], dim=1)) + z2 = z2_mean + torch.randn_like(z2_std) * z2_std + + z1_mean_.append(z1_mean) + z1_std_.append(z1_std) + z1_.append(z1) + z2_.append(z2) + + z1_mean_ = torch.stack(z1_mean_, dim=1) + z1_std_ = torch.stack(z1_std_, dim=1) + z1_ = torch.stack(z1_, dim=1) + z2_ = torch.stack(z2_, dim=1) + return (z1_mean_, z1_std_, z1_, z2_) + + # + def calculate_loss(self, state_, action_, reward_, done_, cost_): + # Calculate the sequence of features. + feature_ = self.encoder(state_) + + # Sample from latent variable model. 
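+        # The posterior is conditioned on encoder features, while the prior reuses
+        # the posterior z2 samples together with the actions.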
+ z1_mean_post_, z1_std_post_, z1_, z2_ = self.sample_posterior(feature_, action_) + z1_mean_pri_, z1_std_pri_ = self.sample_prior(action_, z2_) + + # Calculate KL divergence loss. + loss_kld = ( + calculate_kl_divergence(z1_mean_post_, z1_std_post_, z1_mean_pri_, z1_std_pri_) + .mean(dim=0) + .sum() + ) + + # Prediction loss of images. + z_ = torch.cat([z1_, z2_], dim=-1) + state_mean_, state_std_ = self.decoder(z_) + state_noise_ = (state_ - state_mean_) / (state_std_ + 1e-8) + log_likelihood_ = (-0.5 * state_noise_.pow(2) - state_std_.log()) - 0.5 * math.log( + 2 * math.pi + ) + loss_image = -log_likelihood_.mean(dim=0).sum() + + # Prediction loss of rewards. + x = torch.cat([z_[:, :-1], action_, z_[:, 1:]], dim=-1) + B, S, X = x.shape + reward_mean_, reward_std_ = self.reward(x.view(B * S, X)) + reward_mean_ = reward_mean_.view(B, S, 1) + reward_std_ = reward_std_.view(B, S, 1) + reward_noise_ = (reward_ - reward_mean_) / (reward_std_ + 1e-8) + log_likelihood_reward_ = (-0.5 * reward_noise_.pow(2) - reward_std_.log()) - 0.5 * math.log( + 2 * math.pi + ) + loss_reward = -log_likelihood_reward_.mul_(1 - done_).mean(dim=0).sum() + + p = self.cost(x.view(B * S, X)).view(B, S, 1) + q = 1 - p + weight_p = 100 + binary_cost_ = torch.sign(cost_) + loss_cost = ( + -30 + * ( + weight_p * binary_cost_ * torch.log(p + 1e-6) + + (1 - binary_cost_) * torch.log(q + 1e-6) + ) + .mean(dim=0) + .sum() + ) + + return loss_kld, loss_image, loss_reward, loss_cost diff --git a/omnisafe/configs/off-policy/SafeSLAC.yaml b/omnisafe/configs/off-policy/SafeSLAC.yaml new file mode 100644 index 000000000..a11626fe3 --- /dev/null +++ b/omnisafe/configs/off-policy/SafeSLAC.yaml @@ -0,0 +1,148 @@ +# Copyright 2023 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +defaults: + # seed for random number generator + seed: 0 + # training configurations + train_cfgs: + # device to use for training, options: cpu, cuda, cuda:0, cuda:0,1, etc. 
+ device: cpu + # number of threads for torch + torch_threads: 16 + # number of vectorized environments + vector_env_nums: 1 + # number of parallel agent, similar to a3c + parallel: 1 + # total number of steps to train + total_steps: 1000000 + # number of evaluate episodes + eval_episodes: 0 + # algorithm configurations + algo_cfgs: + # number of times each action is repeated in the environment + action_repeat: 2 + # initial learning steps for the latent model + latent_model_init_learning_steps: 30000 + # number of sequences used for training + num_sequences: 10 + # amount of noise added to input images as a form of augmentation + image_noise: 0.4 + # list of sizes for hidden layers of dynamics model + hidden_sizes: [256, 256] + # dimensionality of feature vectors initially + feature_dim: 256 + # dimensionality of the first latent space + latent_dim_1: 32 + # dimensionality of the second latent space + latent_dim_2: 200 + # dimensionality of feature vectors + feature_dim: 200 + # number of steps to update the policy + steps_per_epoch: 2000 + # number of steps per sample + update_cycle: 100 + # number of iterations to update the policy + update_iters: 100 + # The size of replay buffer + size: 2000000 + # The size of batch + batch_size: 64 + # normalize reward + reward_normalize: False + # normalize cost + cost_normalize: False + # normalize observation + obs_normalize: False + # max gradient norm + max_grad_norm: 40.0 + # use critic norm + use_critic_norm: False + # critic norm coefficient + critic_norm_coeff: 0.001 + # The soft update coefficient + polyak: 0.005 + # The discount factor of GAE + gamma: 0.995 + # Actor perdorm random action before `start_learning_steps` steps + start_learning_steps: 30000 + # The delay step of policy update + policy_delay: 2 + # Whether to use the exploration noise + use_exploration_noise: False + # The exploration noise + exploration_noise: 0.1 + # The policy noise + policy_noise: 0.2 + # policy_noise_clip + policy_noise_clip: 0.5 + # The value of alpha + alpha: 0.004 + # Whether to use auto alpha + auto_alpha: False + # use cost + use_cost: True + # warm up epoch + warmup_epochs: 100 + # logger configurations + logger_cfgs: + # use wandb for logging + use_wandb: False + # wandb project name + wandb_project: omnisafe + # use tensorboard for logging + use_tensorboard: True + # save model frequency + save_model_freq: 100 + # save logger path + log_dir: "./runs" + # save model path + window_lens: 10 + # model configurations + model_cfgs: + # weight initialization mode + weight_initialization_mode: "kaiming_uniform" + # actor type + actor_type: gaussian_sac + # linear learning rate decay + linear_lr_decay: False + # Configuration of Actor network + actor: + # Size of hidden layers + hidden_sizes: [256, 256] + # Activation function + activation: relu + # The learning rate of Actor network + lr: 0.000005 + # Configuration of Critic network + critic: + # The number of critic networks + num_critics: 2 + # Size of hidden layers + hidden_sizes: [256, 256] + # Activation function + activation: relu + # The learning rate of Critic network + lr: 0.001 + # lagrangian configurations + lagrange_cfgs: + # Tolerance of constraint violation + cost_limit: 25.0 + # Initial value of lagrangian multiplier + lagrangian_multiplier_init: 0.000 + # Learning rate of lagrangian multiplier + lambda_lr: 0.0002 + # Type of lagrangian optimizer + lambda_optimizer: "Adam" diff --git a/omnisafe/envs/__init__.py b/omnisafe/envs/__init__.py index 57778938e..8f58b06f6 100644 --- 
a/omnisafe/envs/__init__.py +++ b/omnisafe/envs/__init__.py @@ -19,3 +19,4 @@ from omnisafe.envs.mujoco_env import MujocoEnv from omnisafe.envs.safety_gymnasium_env import SafetyGymnasiumEnv from omnisafe.envs.safety_gymnasium_modelbased import SafetyGymnasiumModelBased +from omnisafe.envs.safety_gymnasium_vision_env import SafetyGymnasiumVisionEnv diff --git a/omnisafe/envs/safety_gymnasium_vision_env.py b/omnisafe/envs/safety_gymnasium_vision_env.py new file mode 100644 index 000000000..778ddacb4 --- /dev/null +++ b/omnisafe/envs/safety_gymnasium_vision_env.py @@ -0,0 +1,194 @@ +# Copyright 2023 OmniSafe Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Environments in the Vision-Based Safety-Gymnasium.""" + +from __future__ import annotations + +import os +from typing import Any, ClassVar + +import numpy as np +import safety_gymnasium +import torch + +from omnisafe.envs.core import CMDP, env_register +from omnisafe.typing import DEVICE_CPU, Box + + +@env_register +class SafetyGymnasiumVisionEnv(CMDP): + need_auto_reset_wrapper: bool = False + need_time_limit_wrapper: bool = False + + _support_envs: ClassVar[list[str]] = [ + 'SafetyCarGoal1Vision-v0', + 'SafetyPointGoal1Vision-v0', + 'SafetyPointButton1Vision-v0', + 'SafetyPointPush1Vision-v0', + 'SafetyPointGoal2Vision-v0', + 'SafetyPointButton2Vision-v0', + 'SafetyPointPush2Vision-v0', + ] + + def __init__( + self, + env_id: str, + num_envs: int = 1, + device: torch.device = DEVICE_CPU, + **kwargs: Any, + ) -> None: + """Initialize an instance of :class:`SafetyGymnasiumVisionEnv`.""" + super().__init__(env_id) + self._num_envs = num_envs + self._device = torch.device(device) + if 'MUJOCO_GL' not in os.environ: + os.environ['MUJOCO_GL'] = 'osmesa' + self.need_time_limit_wrapper = True + self.need_auto_reset_wrapper = True + self._env = safety_gymnasium.make( + id=env_id, + autoreset=True, + render_mode='rgb_array', + camera_name='vision', + width=64, + height=64, + **kwargs, + ) + + self._observation_space = Box(shape=(3, 64, 64), low=0, high=255, dtype=np.uint8) + self._action_space = self._env.action_space + + self._metadata = self._env.metadata + + def step( + self, + action: torch.Tensor, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + dict[str, Any], + ]: + """Step the environment. + + .. note:: + OmniSafe uses auto reset wrapper to reset the environment when the episode is + terminated. So the ``obs`` will be the first observation of the next episode. And the + true ``final_observation`` in ``info`` will be stored in the ``final_observation`` key + of ``info``. + + Args: + action (torch.Tensor): Action to take. + + Returns: + observation: The agent's observation of the current environment. + reward: The amount of reward returned after previous action. + cost: The amount of cost returned after previous action. + terminated: Whether the episode has ended. 
+ truncated: Whether the episode has been truncated due to a time limit. + info: Some information logged by the environment. + """ + obs, reward, cost, terminated, truncated, info = self._env.step( + action.detach().cpu().numpy(), + ) + + reward, cost, terminated, truncated = ( + torch.as_tensor(x, dtype=torch.float32, device=self._device) + for x in (reward, cost, terminated, truncated) + ) + obs = ( + torch.as_tensor(obs['vision'].copy(), dtype=torch.uint8, device=self._device) + .float() + .div_(255.0) + .transpose(0, -1) + ) + if 'final_observation' in info: + info['final_observation'] = np.array( + [ + array if array is not None else np.zeros(obs.shape[-1]) + for array in info['final_observation']['vision'].copy() + ], + ) + info['final_observation'] = ( + torch.as_tensor( + info['final_observation'], + dtype=torch.int8, + device=self._device, + ) + .float() + .div_(255.0) + .transpose(0, -1) + ) + + return obs, reward, cost, terminated, truncated, info + + def reset( + self, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, dict[str, Any]]: + """Reset the environment. + + Args: + seed (int, optional): The random seed. Defaults to None. + options (dict[str, Any], optional): The options for the environment. Defaults to None. + + Returns: + observation: Agent's observation of the current environment. + info: Some information logged by the environment. + """ + obs, info = self._env.reset(seed=seed, options=options) + return ( + torch.as_tensor(obs['vision'].copy(), dtype=torch.uint8, device=self._device) + .float() + .div_(255.0) + .transpose(0, -1), + info, + ) + + def set_seed(self, seed: int) -> None: + """Set the seed for the environment. + + Args: + seed (int): Seed to set. + """ + self.reset(seed=seed) + + def sample_action(self) -> torch.Tensor: + """Sample a random action. + + Returns: + A random action. + """ + return torch.as_tensor( + self._env.action_space.sample(), + dtype=torch.float32, + device=self._device, + ) + + def render(self) -> Any: + """Compute the render frames as specified by :attr:`render_mode` during the initialization of the environment. + + Returns: + The render frames: we recommend to use `np.ndarray` + which could construct video by moviepy. 
+ """ + return self._env.render() + + def close(self) -> None: + """Close the environment.""" + self._env.close() diff --git a/omnisafe/utils/model.py b/omnisafe/utils/model.py index 4d01c5ff5..5e1b44a0c 100644 --- a/omnisafe/utils/model.py +++ b/omnisafe/utils/model.py @@ -16,10 +16,13 @@ from __future__ import annotations +from collections import deque + import numpy as np +import torch from torch import nn -from omnisafe.typing import Activation, InitFunction +from omnisafe.typing import DEVICE_CPU, Activation, InitFunction def initialize_layer(init_function: InitFunction, layer: nn.Linear) -> None: @@ -109,3 +112,43 @@ def build_mlp_network( initialize_layer(weight_initialization_mode, affine_layer) layers += [affine_layer, act_fn()] return nn.Sequential(*layers) + + +class ObservationConcator: + def __init__(self, state_shape, action_shape, num_sequences, device=DEVICE_CPU) -> None: + self.state_shape = state_shape + self.action_shape = action_shape + self.num_sequences = num_sequences + self.device = device + + def reset_episode(self, state): + self._state = deque(maxlen=self.num_sequences) + self._action = deque(maxlen=self.num_sequences - 1) + for _ in range(self.num_sequences - 1): + self._state.append( + torch.zeros(self.state_shape, dtype=torch.float32, device=self.device), + ) + self._action.append( + torch.zeros(self.action_shape, dtype=torch.float32, device=self.device), + ) + self._state.append(state) + + def append(self, state, action): + self._state.append(state) + self._action.append(action) + + @property + def state(self): + return self._state[None, ...] + + @property + def last_state(self): + return self._state[-1][None, ...] + + @property + def action(self): + return self._action.reshape(1, -1) + + @property + def last_action(self): + return self._action[-1] diff --git a/omnisafe/utils/tools.py b/omnisafe/utils/tools.py index 77710c3ee..8fc0818da 100644 --- a/omnisafe/utils/tools.py +++ b/omnisafe/utils/tools.py @@ -21,6 +21,7 @@ import os import random import sys +from collections import deque from typing import Any import numpy as np @@ -30,7 +31,7 @@ from rich.console import Console from torch.version import cuda as cuda_version -from omnisafe.typing import DEVICE_CPU +from omnisafe.typing import DEVICE_CPU, OmnisafeSpace def get_flat_params_from(model: torch.nn.Module) -> torch.Tensor: @@ -154,9 +155,9 @@ def seed_all(seed: int) -> None: torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) try: - torch.use_deterministic_algorithms(True) - torch.backends.cudnn.enabled = False - torch.backends.cudnn.benchmark = False + # torch.use_deterministic_algorithms(True) + # torch.backends.cudnn.enabled = False + # torch.backends.cudnn.benchmark = False if cuda_version is not None and float(cuda_version) >= 10.2: os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' os.environ['PYTHONHASHSEED'] = str(seed) @@ -366,3 +367,76 @@ def get_device(device: torch.device | str | int = DEVICE_CPU) -> torch.device: return torch.device('cpu') return device + + +def create_feature_actions(feature_, action_): + N = feature_.size(0) + # Flatten sequence of features. + # f (batch_size, (num_sequences)*feature_dim) + f = feature_[:, :-1].view(N, -1) + n_f = feature_[:, 1:].view(N, -1) + # Flatten sequence of actions. + a = action_[:, :-1].view(N, -1) + n_a = action_[:, 1:].view(N, -1) + # Concatenate feature and action. 
+ fa = torch.cat([f, a], dim=-1) + n_fa = torch.cat([n_f, n_a], dim=-1) + return fa, n_fa + + +class LazyFrames: + def __init__(self, frames) -> None: + self._frames = list(frames) + + def __array__(self, dtype): + return np.array(self._frames, dtype=dtype) + + def __len__(self) -> int: + return len(self._frames) + + +class SequenceQueue: + def __init__(self, obs_space: OmnisafeSpace, num_sequences: int = 8, device=DEVICE_CPU) -> None: + self.num_sequences = num_sequences + self._reset_episode = False + self._obs_space = obs_space + self._device = device + self.data = {} + self.data['obs'] = deque(maxlen=self.num_sequences + 1) + self.data['act'] = deque(maxlen=self.num_sequences) + self.data['reward'] = deque(maxlen=self.num_sequences) + self.data['done'] = deque(maxlen=self.num_sequences) + self.data['cost'] = deque(maxlen=self.num_sequences) + + def reset_sequence_queue(self, obs): + for k in self.data: + self.data[k].clear() + self._reset_episode = True + self.data['obs'].append(obs.detach().cpu().numpy()) + + def append(self, **data: torch.Tensor): + assert self._reset_episode, self._reset_episode + for key, value in data.items(): + self.data[key].append(value.detach().cpu().numpy()) + + def _process_get(self, key: str) -> LazyFrames | torch.Tensor: + if key == 'obs': + return np.array(LazyFrames(self.data['obs']), dtype=np.float32).swapaxes(0, 1).squeeze() + else: + return torch.tensor( + np.array(self.data[key], dtype=np.float32), + dtype=torch.float32, + device=self._device, + ) + + def get(self): + return {key: self._process_get(key) for key in self.data} + + def is_empty(self): + return len(self.data['reward']) == 0 + + def is_full(self): + return len(self.data['reward']) == self.num_sequences + + def __len__(self) -> int: + return len(self.data['reward'])
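
With this patch applied, the algorithm is registered under the name 'SafeSLAC' and the
vision tasks under IDs such as 'SafetyPointGoal1Vision-v0'. A minimal training sketch,
assuming OmniSafe's standard omnisafe.Agent entry point (the environment ID below is one
of those registered in safety_gymnasium_vision_env.py):

    import omnisafe

    # Train SafeSLAC on a vision-based Safety-Gymnasium task; hyperparameters are
    # read from omnisafe/configs/off-policy/SafeSLAC.yaml by default.
    env_id = 'SafetyPointGoal1Vision-v0'
    agent = omnisafe.Agent('SafeSLAC', env_id)
    agent.learn()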