跳到主要内容

在 Hugging Face 使用 Stable-Baselines3

stable-baselines3 是一组在 PyTorch 中实现的可靠强化学习算法。

在 Hub 上探索 Stable-Baselines3

你可以通过在模型页面左侧筛选来找到 Stable-Baselines3 模型。

Hub 上的所有模型都配备了有用的功能:

  1. 自动生成的模型卡片,包含描述、训练配置等。
  2. 有助于可发现性的元数据标签。
  3. 评估结果以与其他模型进行比较。
  4. 一个视频小部件,你可以观看智能体的表现。

安装库

要安装 stable-baselines3 库,你需要安装两个包:

  • stable-baselines3:Stable-Baselines3 库。
  • huggingface-sb3:用于从 Hub 加载和上传 Stable-baselines3 模型的附加代码。
pip install stable-baselines3
pip install huggingface-sb3

使用现有模型

你可以使用 load_from_hub 函数简单地从 Hub 下载模型

checkpoint = load_from_hub(
repo_id="sb3/demo-hf-CartPole-v1",
filename="ppo-CartPole-v1.zip",
)

你需要定义两个参数:

  • --repo-id:要下载的 Hugging Face 仓库名称。
  • --filename:要下载的文件。

分享你的模型

你可以使用两个不同的函数轻松上传模型:

  1. package_to_hub():保存模型、评估它、生成模型卡片并录制智能体的重放视频,然后将完整的仓库推送到 Hub。
package_to_hub(model=model, 
model_name="ppo-LunarLander-v2",
model_architecture="PPO",
env_id=env_id,
eval_env=eval_env,
repo_id="ThomasSimonini/ppo-LunarLander-v2",
commit_message="Test commit")

你需要定义七个参数:

  • --model:你训练的模型。
  • --model_architecture:模型架构的名称(DQN、PPO、A2C、SAC...)。
  • --env_id:环境名称。
  • --eval_env:用于评估智能体的环境。
  • --repo-id:要创建或更新的 Hugging Face 仓库名称。格式为 <your huggingface username>/<the repo name>
  • --commit-message:提交消息。
  • --filename:要推送到 Hub 的文件。
  1. push_to_hub():简单地将文件推送到 Hub
push_to_hub(
repo_id="ThomasSimonini/ppo-LunarLander-v2",
filename="ppo-LunarLander-v2.zip",
commit_message="Added LunarLander-v2 model trained with PPO",
)

你需要定义三个参数:

  • --repo-id:要创建或更新的 Hugging Face 仓库名称。格式为 <your huggingface username>/<the repo name>
  • --filename:要推送到 Hub 的文件。
  • --commit-message:提交消息。

其他资源

  • Hugging Face Stable-Baselines3 文档
  • Stable-Baselines3 文档