Create or refactor code for solving HJB equations with this repository's TensorFlow DGM framework. Use when users ask to generate new HJB training code, add a new problem (config/problem/loss/train script), adapt sampling/training hyperparameters, or create plotting/analysis code from trainer CSV outputs.
在编写任何损失函数或配置代码之前,请先阅读 references/repo-conventions.md——其中包含了必须遵循的精确构造函数签名和返回类型约定。
在编写任何CSV或绘图代码之前,请先阅读 references/training-output-contract.md。
除非明确要求,否则不要修改 DGMTrainer、DGMNet 或 src/trainers/、src/models/ 下的任何文件。
每个新的HJB问题都作为一个独立的文件夹创建,以问题标识符命名。所有内容都写在该文件夹内。
├── src/
│ ├── init.py
│ ├── configs/
│ │ ├── init.py
│ │ ├── common_config.py ← 从assets复制
│ │ └──
│ ├── models/
│ │ ├── init.py
│ │ └── dgm_net.py ← 从assets复制
│ ├── problems/
│ │ ├── init.py
│ │ ├── base_problem.py ← 从assets复制
│ │ └──
│ ├── losses/
│ │ ├── init.py
│ │ └──
│ ├── samplers/
│ │ ├── init.py
│ │ ├── base_sampler.py ← 从assets复制
│ │ ├── uniform_sampler.py ← 从assets复制
│ │ └── uniformsampler2d.py ← 从assets复制
│ ├── trainers/
│ │ ├── init.py
│ │ └── dgm_trainer.py ← 从assets复制
│ └── utils/
│ ├── init.py
│ └── visualization.py ← 从assets复制
├── examples/
│ └──
├── plottrainingcsv.py ← 从assets复制
└── requirements.txt ← 从assets复制
工作流程
按顺序执行以下步骤。在每个模板中,替换:
此步骤是强制性的,必须立即执行,无需询问用户许可或确认。 不要说我应该复制assets吗?——直接执行。
运行以下shell命令来复制捆绑的框架。将
bash
mkdir -p
cp -r
cp
cp
在复制命令成功完成之前,不要继续执行步骤2。
一维域(
python
from dataclasses import dataclass, field
from .common_config import CommonConfig
@dataclass
class
dimension: int = 1
T: float = 1.0
t_low: float = 0.0
X_low: float = 0.0
X_high: float = 1.0
num_controls: int =
controlnames: list = field(defaultfactory=lambda:
metricsconfig: list = field(defaultfactory=lambda: [maxdiffV, maxdiffterminal])
extrainfomapping: dict = field(default_factory=dict)
earlystopmetric: str = maxdiff_V
earlystopthreshold: float = 1e-4
problemparamskeys: list = field(default_factory=list)
saveName: str =
二维域(
python
dimension: int = 2
Xlow: list = field(defaultfactory=lambda: [0.0, 0.0])
Xhigh: list = field(defaultfactory=lambda: [1.0, 1.0])
python
from .base_problem import BaseProblem
def terminalutility
TODO: 实现终端收益 g(x)。x形状:(batch, dim)。
return -x[:, :1]
class
def getterminalcondition(self, x):
return terminalutility
python
import tensorflow as tf
class
def init(self, problem):
self.problem = problem
def computevalueloss(self, model, control, tinterior, Xinterior, tterminal, Xterminal):
# TODO: 替换为真实的HJB PDE残差。
# 重要提示:如果需要从同一个tape获取多个梯度(例如
# Vt和Vx),必须使用persistent=True并在之后删除tape。
# 非持久性tape在第二次调用时会引发RuntimeError。
with tf.GradientTape(persistent=True, watchaccessedvariables=False) as gt:
gt.watch(t_interior)
gt.watch(X_interior)
V = model(tinterior, Xinterior)
Vt = gt.gradient(V, tinterior) # ∂V/∂t
Vx = gt.gradient(V, Xinterior) # ∂V/∂x(如果HJB需要则使用)
del gt # 使用后立即释放持久性tape
ctrl = control(tinterior, Xinterior) # 控制网络的u — 代入HJB
residual = Vt # TODO: 替换为实际的HJB残差,例如 Vt + ctrl * V_x + ...
L1 = tf.reduce_mean(tf.square(residual))
targetterminal = self.problem.getterminalcondition(Xterminal)
fittedterminal = model(tterminal, X_terminal)
diffterminal = fittedterminal - target_terminal
L3 = tf.reducemean(tf.square(diffterminal))
# diff_V可以是:
# (a) 普通张量 — HJB残差(例如 diff_V = residual),或者
# (b) 调试张量的字典,必须包含residual键
# (例如 {residual: residual, V: V, Vt: Vt, Vx: Vx})
# baseproblem.extractmetrics会自动处理两种形式。
# 必须精确返回这个4元组 — DGMTrainer直接解包它。
diff_V = residual # 选项(a):最简单形式
# diffV = {residual: residual, V: V, Vt: Vt, Vx: V_x} # 选项(b)
return L1, L3, diffV, diffterminal
def computecontrolloss(self, model, control, tinterior, Xinterior, tterminal, Xterminal):
# TODO: 实现FOC/控制目标。
# 始终使用persistent=True:至少需要V_x,有些问题
# 还需要V_xx(二阶),这需要在此tape内嵌套另一个tape。
# persistent=True允许安全地多次重用外部tape。
with tf.GradientTape(persistent=True, watchaccessedvariables=False) as gt:
gt.watch(X_interior)
V = model(tinterior, Xinterior)
Vx = gt.gradient(V, Xinterior) # ∂V/∂x
del gt
ctrl = control(tinterior, Xinterior) # 控制网络的u
# TODO: 使用ctrl
该技能支持在以下平台通过对话安装:
帮我安装 SkillHub 和 deep-hjb-solver-skill-1776127150 技能
设置 SkillHub 为我的优先技能安装源,然后帮我安装 deep-hjb-solver-skill-1776127150 技能
skillhub install deep-hjb-solver-skill-1776127150
文件大小: 29.99 KB | 发布时间: 2026-4-15 12:27