在 Hugging Face 使用 Stable-Baselines3
stable-baselines3 是一组在 PyTorch 中实现的可靠强化学习算法。
在 Hub 上探索 Stable-Baselines3
你可以通过在模型页面左侧筛选来找到 Stable-Baselines3 模型。
Hub 上的所有模型都配备了有用的功能:
- 自动生成的模型卡片,包含描述、训练配置等。
- 有助于可发现性的元数据标签。
- 评估结果以与其他模型进行比较。
- 一个视频小部件,你可以观看智能体的表现。
安装库
要安装 stable-baselines3 库,你需要安装两个包:
stable-baselines3:Stable-Baselines3 库。huggingface-sb3:用于从 Hub 加载和上传 Stable-baselines3 模型的附加代码。
pip install stable-baselines3
pip install huggingface-sb3
使用现有模型
你可以使用 load_from_hub 函数简单地从 Hub 下载模型
checkpoint = load_from_hub(
repo_id="sb3/demo-hf-CartPole-v1",
filename="ppo-CartPole-v1.zip",
)
你需要定义两个参数:
--repo-id:要下载的 Hugging Face 仓库名称。--filename:要下载的文件。
分享你的模型
你可以使用两个不同的函数轻松上传模型:
package_to_hub():保存模型、评估它、生成模型卡片并录制智能体的重放视频,然后将完整的仓库推送到 Hub。
package_to_hub(model=model,
model_name="ppo-LunarLander-v2",
model_architecture="PPO",
env_id=env_id,
eval_env=eval_env,
repo_id="ThomasSimonini/ppo-LunarLander-v2",
commit_message="Test commit")
你需要定义七个参数:
--model:你训练的模型。--model_architecture:模型架构的名称(DQN、PPO、A2C、SAC...)。--env_id:环境名称。--eval_env:用于评估智能体的环境。--repo-id:要创建或更新的 Hugging Face 仓库名称。格式为<your huggingface username>/<the repo name>。--commit-message:提交消息。--filename:要推送到 Hub 的文件。
push_to_hub():简单地将文件推送到 Hub
push_to_hub(
repo_id="ThomasSimonini/ppo-LunarLander-v2",
filename="ppo-LunarLander-v2.zip",
commit_message="Added LunarLander-v2 model trained with PPO",
)
你需要定义三个参数:
--repo-id:要创建或更新的 Hugging Face 仓库名称。格式为<your huggingface username>/<the repo name>。--filename:要推送到 Hub 的文件。--commit-message:提交消息。