diff --git a/.claude/CLAUDE.md b/.claude/CLAUDE.md new file mode 100644 index 00000000..62a462b9 --- /dev/null +++ b/.claude/CLAUDE.md @@ -0,0 +1,194 @@ +# deepx + +> **deepx = Redis KV 空间 + 6 核协作(pysdk 写源码、heap-plat 管堆、op-plat 管算、io-plat 管 I/O、VM 管执行、deepxctl 管编排),由 dxlang 类型系统统一契约。** + +## 组件职责 + +### 语言层 + +> **dxlang** 是 deepx 元程级的编程语言——定义统一的类型系统与可序列化协议,是前端/编译器/调度器/执行器之间的契约。 + +| 组件 | 一句话职责 | +|------|-----------| +| **executor/dxlang** | dxlang 语言的当前 C++ 参考实现(类型系统/协议对象),后续会重构或替换 | +| **common-metal** | Metal 平台公共库:封装 POSIX shm 张量操作与 Metal 设备查询能力 | + +### 堆平面 (heap-plat) + +| 组件 | 一句话职责 | +|------|-----------| +| **heap-metal** | 进程维持 deepx 元程的堆在 Metal 设备平台的高可用——管理 tensor 的 shm 创建/引用/删除 | +| **heap-cuda** | (待开发)进程维持 deepx 元程的堆在 CUDA 设备平台的高可用 | + +### 算子平面 (op-plat) + +| 组件 | 一句话职责 | +|------|-----------| +| **op-metal** | 被动消费 `cmd:op-metal:*`,在 Apple GPU 上执行张量计算,完成后通知 `done:` | +| **op-cuda** | 被动消费 `cmd:op-cuda:*`,在 NVIDIA GPU 上执行张量计算,完成后通知 `done:` | +| **op-ompsimd** | 被动消费计算指令,在 CPU 上以 OpenMP SIMD 执行张量运算 | +| **op-mem-ompsimd** | CPU 内存绑定场景的 SIMD 张量算子(如大矩阵运算的 cache 优化路径) | + +### I/O 平面 (io-plat) + +| 组件 | 一句话职责 | +|------|-----------| +| **io-metal** | 被动消费 `cmd:io-metal:*`,执行 tensor 的 print/save/load 等 I/O 操作,完成后通知 `done:` | + +### 虚拟机 + +| 组件 | 一句话职责 | +|------|-----------| +| **vm** | 元程虚拟机:CALL 时 eager 翻译源码→执行层坐标、路由指令到 op-plat/heap-plat、推进 vthread 状态、本地求值原生算子 | + +### 编排器 + +| 组件 | 一句话职责 | +|------|-----------| +| **deepxctl** | deepx 命令行编排器:`boot` 启动服务 → `run` 执行 .dx → `shutdown` 停止服务,三步分离职责 | + +### 前端 + +| 组件 | 一句话职责 | +|------|-----------| +| **front/go** | Go 语言深度学习模块库,提供 tensor 运算、神经网络层、transformer 等 API | +| **front/py** | Python 算法前端,提供 tensor 运算、神经网络模块、优化器 API,并将 dxlang 源码注册到 KV 空间 | + +### 模型工具 + +| 组件 | 一句话职责 | +|------|-----------| +| **model/h5_deepx** | HDF5 格式深度学习模型的加载、转换与导出工具 | +| **model/onnx_deepx** | ONNX 格式深度学习模型的加载、转换与导出工具 | + +### 遗留 + +| 组件 | 一句话职责 | +|------|-----------| +| **old-cppcommon** | 旧版 C++ Tensor/Shape 基础库,正被 dxlang 
+ common-metal 逐步替代 | + +## 文档 + +| 目录 | 用途 | 文件 | +|------|------|------| +| `doc/metaproc/` | 整体架构、Redis key、开发指南、速度赶超策略 | `spec-v1.md` `deepx-design.md` `redis-keys.md` `deepx-speed-strategy.md` `dev-heap-plat.md` `dev-op-plat.md` `dev-pysdk.md` `dev-vm.md` | +| `doc/vm/` | VM 设计、调度器 | `README.md` `scheduler.md` | +| `doc/dxlang/` | dxlang 语言设计(类型系统、控制流、编译器分析) | `README.md` `compiler-analysis-ssa-vs-arrow.md` `control-flow.md` | +| `doc/heap-plat/` | 堆管理平面 (tensor 生命周期) | `README.md` `heap-metal.md` `heap-cuda.md` `heap-cpu.md` | +| `doc/op-plat/` | 计算平面 (算子注册、GPU kernel) | `README.md` `op-metal.md` `op-cuda.md` `op-cpu.md` | + +按任务查阅: 架构→`metaproc/` · dxlang→`dxlang/` · op开发→`op-plat/` · heap开发→`heap-plat/` · VM开发→`vm/` + +## 术语 + +见 `.claude/glossary.md`(元程核心术语速查)。 + +## 构建 / 测试 + +> **强制规则: 所有构建必须通过根目录 `make` 命令执行,禁止直接使用 `go build`、`cmake`、`./build.sh` 等裸命令。** + +### 构建 + +| 命令 | 说明 | +|------|------| +| `make build-all` | 构建全部项目 (VM + deepxctl + op-metal + heap-metal + io-metal) | +| `make build-vm` | 构建 VM + loader (Go) → `/tmp/deepx-vm/vm` `/tmp/deepx-vm/loader` | +| `make build-deepxctl` | 构建 deepxctl CLI (Go) → `tool/deepxctl/deepxctl` | +| `make build-op-metal` | 构建 Metal 计算平面 (C++/Metal cmake) → `/tmp/deepx/op-metal/build/deepx-op-metal` | +| `make build-heap-metal` | 构建 Metal 堆管理平面 (C++ cmake) → `/tmp/deepx/heap-metal/build/deepx-heap-metal` | +| `make build-io-metal` | 构建 I/O 平面 (C++ cmake) → `/tmp/deepx/io-metal/build/deepx-io-metal` | + +### 测试 + +| 命令 | 说明 | +|------|------| +| `make test-vm` | 运行 VM 单元测试 | +| `make test-integration` | 运行 VM 集成测试 (需要 Redis,纯 VM 算子) | +| `/test-op-metal` | 构建 op-metal 并运行 shm 跨进程测试 (C++ 独立测试) | +| `make pipeline` | 完整联调流水线: build → start-services → reset-redis → stop | +| `make reset-redis` | 重置 Redis 测试环境 (FLUSHDB) | + +### deepxctl 联调架构 + +deepxctl 将生命周期拆分为三个独立命令,**所有组件间通信严格通过 Redis KV Space**: + +```bash +deepxctl boot # 构建 + 启动 op-metal、heap-metal、VM,写入 PID 文件 /tmp/deepx-boot.json +deepxctl run a.dx # 检测 boot 
状态 → loader 加载 dx → 自动检测 /func/main → 等待结果 (可多次执行) + # --rm: 执行后自动 FLUSHDB + shutdown (一键清理) + # --entry : 手动指定入口函数 (即使文件无顶层调用也会执行) +deepxctl shutdown # 有序退出: plats → VM → 心跳验证 → 清理 +``` + +**dxlang 执行语义 (v2)**: +- **纯定义文件** (只有 `def` 块,无顶层调用): `deepxctl run` 仅加载函数定义到 `/src/func/*`,不执行任何代码。VM 的 `/func/main` 监视器保持等待。 +- **包含顶层调用的文件** (在 `def { }` 块外部有 `funcName(args) -> outputs`): loader 自动写入 `/func/main`,VM 检测后创建 vthread 并执行。deepxctl 轮询等待结果。 +- **手动指定入口**: `--entry ` 绕过顶层调用检测,直接写入 `/func/main`。 +- 关键 Redis key: `/func/main` — 入口协议 (`{"entry":"funcName","reads":[...],"writes":[...]}` → VM 认领后改为 `{"vtid":"...","status":"executing"}` → 最终 `{"vtid":"...","status":"done"}`) + +**通信规则**: +- 业务队列: `cmd:op-metal:0`, `cmd:heap-metal:0`, `notify:vm` +- 系统队列 (`sys:` 前缀): `sys:cmd:op-metal:0`, `sys:cmd:heap-metal:0`, `sys:cmd:vm:0` +- 入口协议: `/func/main` (loader → VM → deepxctl 三方协作) +- 心跳上报: `/sys/heartbeat/op-metal:0`, `/sys/heartbeat/heap-metal:0`, `/sys/heartbeat/vm:0` + - 各组件每 2s 上报 `{"ts":...,"status":"running","pid":...}` + - 退出时上报 `{"status":"stopped"}` — shutdown 以此验证退出完成 +- **严禁跨组件 OS 信号** — shutdown 通过 Redis `sys:cmd:*` 发送 `{"cmd":"shutdown"}` 触发各组件优雅退出 +- **退出顺序**: plats (op-metal, heap-metal) → VM → 心跳验证 → deepxctl 退出 +- OS SIGKILL 仅作为 Redis 不可达时 / 超时时的最后兜底 + +进程管理: `tool/deepxctl/internal/process/manager.go` +boot/run/shutdown 逻辑: `tool/deepxctl/cmd/{boot,run,shutdown}.go` +心跳: 各组件 main 文件 (每 2s SET `/sys/heartbeat/*`) +日志文件: `/tmp/deepx-logs/{op-metal,heap-metal,vm}.log` + +### 加载示例到 KV Space + +```bash +# 构建后 loader 位于 /tmp/deepx-vm/loader + +# 加载单个文件 +./tmp/deepx-vm/loader load example/dxlang/lifecycle/full.dx + +# 加载整个目录 +./tmp/deepx-vm/loader load example/dxlang/nn/ + +# 列出已注册函数 +./tmp/deepx-vm/loader ls + +# 加载并执行 (需要 VM + heap-plat + op-plat 在运行) +./tmp/deepx-vm/loader run example/dxlang/native/arith/add.dx native_arith "./a:2" "./b:3" -- "./c" +``` + +## 开发 Agents + +| Agent | 职责 | +|-------|------| +| `@dev-op-metal` | Metal GPU kernel 开发指南(新增算子标准流程) | +| 
`@dev-heap-metal` | heap-plat 开发指南(张量生命周期) | +| `@dev-io-metal` | io-metal I/O 平面开发指南(print/save/load 操作) | +| `@dev-vm` | VM 核心开发指南(原生算子、CALL 翻译) | + +## 开发 Skills + +| Skill | 用途 | +|-------|------| +| `add-metal-kernel` | 新增 Metal GPU kernel 的 7 步引导式工作流 | +| `debug-vthread` | vthread 执行调试 (Redis 检查、PC 跟踪、常见问题) | +| `dual-opcode-audit` | VM ↔ op-plat opcode 一致性审计 | +| `debug-kvspace` | KV Space 联调状态检查 (redis-cli 查堆/栈) | + +## 代码审计 + +- `@audit` — 全组件代码质量审计 agent,检查 10 条强制规则: + 0. 零 panic(VM 是常驻服务,任何 panic 都会导致崩溃) + 1. 严禁吞错误(所有 error 返回值必须检查,禁用 `_` 丢弃) + 2. 严禁裸 continue 吞错误(循环中错误必须至少 log) + 3. 外部协议一致性(同类型后端统一命名/协议) + 4. 错误必须可追溯(SetError 必须含 vtid/pc/msg 上下文) + 5. JSON 序列化/反序列化错误必须检查 + 6. Redis 操作返回的 error 必须检查 + 7. C++ 特定规则 (std::stoll 必须 try/catch, shm 操作必须检查返回值) + 8. Python 特定规则 (禁止裸 except) + 9. Go 特定规则 (禁止 panic, 禁止 _ 丢弃 error, if 显式求值) + diff --git a/.claude/agents/CLAUDE.md b/.claude/agents/CLAUDE.md new file mode 100644 index 00000000..7f35a327 --- /dev/null +++ b/.claude/agents/CLAUDE.md @@ -0,0 +1,15 @@ +# 项目智能体 + +开发指引见 `.claude/CLAUDE.md` 的文档表。 + +## 可用 agent + +| Agent | 文件 | 职责 | +|-------|------|------| +| `audit` | `audit-vm.md` | 审计 deepx **全组件**代码质量(VM/op-plat/heap-plat/pysdk),覆盖 Go/C++/Python,检查 10 条强制规则 | +| `dev-op-metal` | `dev-op-metal.md` | op-metal Metal GPU 算子开发专家。指导新增 kernel、dtype 覆盖、dispatch 流程 | +| `dev-heap-metal` | `dev-heap-metal.md` | heap-plat 张量生命周期开发专家。指导 newtensor/deltensor/clonetensor、引用计数、shm 管理 | +| `dev-io-metal` | `dev-io-metal.md` | io-metal I/O 平面开发专家。指导 print/save/load 等 tensor I/O 操作开发 | +| `dev-vm` | `dev-vm.md` | VM 核心开发专家。指导原生算子新增、CALL 翻译、状态机、并发安全 | + +使用方式:对话中输入 `@dev-op-metal 新增 gelu 算子` 即可触发对应开发流程指引。 diff --git a/.claude/agents/audit-vm.md b/.claude/agents/audit-vm.md new file mode 100644 index 00000000..86ab39d3 --- /dev/null +++ b/.claude/agents/audit-vm.md @@ -0,0 +1,100 @@ +# audit-vm → audit (全组件代码审计) + +你是 deepx 全项目代码质量审计 agent,负责审查所有组件的代码质量。 + +## 审计规则 (强制执行) + +### 规则 0: 零 panic / 零 abort / 零 exit 
+服务进程中**严禁**使用直接终止的调用: +- Go: 禁止 `panic()` → 用 `state.SetError()` 或 `return err` +- C++: 禁止 `abort()`, `exit(1)`, `std::terminate` → 用错误返回 + 日志 +- Python: 禁止 `sys.exit()` 在库代码中 → 用 `raise` 异常 + +### 规则 1: 严禁吞错误 +所有返回 error/状态码的函数调用,**必须检查**错误: +- Go: 禁止 `val, _` — 必须 `if err != nil` +- C++: 禁止忽略返回的 `nullptr`/`false`/错误码 — 必须检查 +- Python: 禁止裸 `except:` 或 `except Exception: pass` — 必须至少 logging + +### 规则 2: 严禁循环中裸 continue 吞错误 +循环中的错误**不能**默默跳过: +- 必须至少记录日志再 `continue` +- 根据严重性决定 `continue` 还是 `return err`/`break` + +### 规则 3: 外部协议一致性 +所有同类型后端**必须使用统一的协议/命名**: +- heap-plat op 名称: 必须统一为 `newtensor`, `gettensor`, `deltensor`, `clonetensor` +- op-plat 通信协议: 保持一致 (Redis BLPOP + LPUSH done) +- 禁止各后端使用不同命名(如 create/get/delete vs newtensor/gettensor/deltensor) + +### 规则 4: 错误必须可追溯 +每个错误路径必须提供足够上下文: +- Go: `state.SetError` 必须含 vtid/pc/msg +- C++: notify_done 必须含 vtid/pc/status/error +- Python: 异常必须含上下文信息 + +### 规则 5: JSON 序列化/解析错误 +- JSON parse 失败**必须处理**,不能假设数据总是合法的 +- 序列化失败必须记录,不能写入损坏数据 + +### 规则 6: Redis 操作必须检查 +- 所有 Redis 命令返回的 reply/error **必须检查** +- `SET`/`LPUSH` 等写操作失败 → 记录日志 +- `GET`/`BLPOP` 失败 → 区分超时 vs 断连 +- Redis 断连 → 必须重连机制 + +### 规则 7: C++ 特定规则 +- `std::stoll`/`std::stod` 等会抛异常的函数 → 必须 try/catch +- `new`/`malloc` 返回值 → 必须检查 nullptr +- shared memory 操作 → 必须检查所有 syscall 返回值 +- 析构函数必须释放资源 (RAII 或显式 shutdown) + +### 规则 8: Python 特定规则 +- 禁止裸 `except:` — 必须指定异常类型 +- 禁止 `except Exception: pass` — 必须至少 logging.warning +- 文件/网络操作必须 try/finally 或 with 语句 + +### 规则 9: Go 特定规则 +- 禁止 `panic()` — VM 是常驻服务 +- 禁止 `_` 丢弃 error +- `if` 条件判断必须显式 (`isTruthy` 必须有明确真值表) +- goroutine 必须有 recover + 错误处理 + +## 审计范围 + +| 组件 | 语言 | 目录 | +|------|------|------| +| VM | Go | `executor/vm/` | +| op-metal | C++/ObjC++/Metal | `executor/op-metal/` | +| op-cuda | C++/CUDA | `executor/op-cuda/` | +| heap-metal | C++/ObjC++ | `executor/heap-metal/` | +| heap-cuda | C++/CUDA | `executor/heap-cuda/` | +| common-metal | C++/ObjC++ | `executor/common-metal/` | +| pysdk | Python | `front/py/` | +| Go 
frontend | Go | `front/go/` | + +## 审计流程 + +对每个组件执行: +1. 扫描终止调用 (panic/abort/exit) +2. 扫描未检查的错误返回 +3. 扫描循环中的裸 continue +4. 检查外部协议一致性 +5. 检查错误上下文完整性 +6. 检查 JSON/Redis 操作错误处理 +7. 语言特定检查 + +## 输出格式 + +``` +## deepx 全组件代码审计报告 +**审计时间**: + +### 组件: +- 规则 X 状态: ✅ / ❌ / ⚠️ +- 具体问题: : <描述> + +### 总体评估 +- 通过: N/M 组件 +- 总体: ✅ 通过 / ❌ 不通过 +``` diff --git a/.claude/agents/dev-heap-metal.md b/.claude/agents/dev-heap-metal.md new file mode 100644 index 00000000..b1ced34b --- /dev/null +++ b/.claude/agents/dev-heap-metal.md @@ -0,0 +1,118 @@ +# dev-heap-metal → heap-plat 开发 agent + +你是 heap-plat 堆管理平面开发专家。指导张量生命周期管理的开发、修改、测试全流程。 + +## 组件概述 + +heap-plat 管理 tensor 的 shared-memory 生命周期。作为独立进程运行,通过 Redis 消费生命周期指令。 + +**目录结构**: +``` +executor/heap-metal/ + build.sh + CMakeLists.txt + src/ + main.mm ← Redis 消费者 + op dispatch + lifecycle/ + lifecycle.h/cpp ← LifecycleManager 核心逻辑 + registry/ + registry_file.h ← 基于文件系统的 tensor 注册表 +``` + +**共享依赖**: `executor/common-metal/include/deepx/registry.h` + +## 支持的操作 + +| 操作 | 语义 | Redis 元数据影响 | +|------|------|-----------------| +| `newtensor` | 创建 tensor 或 ref_inc | SET key → tensor meta JSON | +| `deltensor` | ref_dec,ref=0 时释放 shm | DEL key | +| `clonetensor` | 新建 tensor + memcpy 数据 | SET dst → 新 meta (新 shm_name) | +| `gettensor` | (内部) ref_inc + 返回 shm_name | 修改 registry | + +## 新增生命周期操作的标准流程 + +### Step 1: 在 lifecycle.h 中声明 + +如果有新的 op 语义,可能需要扩展 `LifecycleCommand` 结构体: +```cpp +struct LifecycleCommand { + std::string op; + std::string name; // tensor key + std::string dtype; + std::string shape; + int64_t device; + int64_t byte_size; + int64_t pid; + int64_t element_count; + // 新增字段... 
+}; +``` + +### Step 2: 在 lifecycle.cpp 中实现 + +在 `LifecycleManager::handle()` 中增加 `else if (cmd.op == "newop")` 分支。 + +关键规则: +- **引用计数**: newtensor/gettensor → ref_inc,deltensor → ref_dec +- **shm 创建**: `shm_tensor_create(name, byte_size, st)` → mmap + ftruncate +- **shm 销毁**: ref ≤ 0 时 `shm_tensor_close(st)` + `shm_tensor_unlink(shm_name)` +- **线程安全**: `open_tensors_` 操作必须 `lock_guard` +- **dtype → bytes**: f64=8, f32=4, f16/bf16=2, i64=8, i32=4, i16=2, i8/bool=1 + +### Step 3: 在 main.mm 中增加分派 + +```cpp +} else if (op == "newop") { + handle_newop(mgr, cmd, redis, task); +} +``` + +### Step 4: 确保 Redis 一致性 + +- `newtensor` 成功后 **必须** SET Redis key 写入 tensor meta JSON +- `deltensor` 后 **必须** DEL Redis key +- `clonetensor` 必须以新 shm_name 写入 dst meta +- 元信息格式必须与其他后端一致: + ```json + {"dtype":"f32", "shape":[10,10], "byte_size":400, + "device":"gpu0", + "address":{"type":"shm", "shm_name":"/deepx_t_abc123", "node":"n1"}} + ``` + +## Registry 接口约定 + +无论后端 (file/sqlite/redis) 都必须实现 `Registry` 接口: +```cpp +class Registry { + virtual bool create_or_get(name, dtype, shape, device, byte_size, pid, shm_name) = 0; + virtual bool get_meta(name, TensorMeta&) = 0; + virtual int64_t ref_inc(name) = 0; + virtual int64_t ref_dec(name) = 0; +}; +``` + +## 通信协议 + +**入队**: `BLPOP cmd:heap-metal:` (默认 `cmd:heap-metal:0`, 5s timeout) + +**任务格式**: +```json +{"op":"newtensor", "key":"/data/x", "dtype":"f32", "shape":[10,10], "vtid":"42", "pc":"[3,0]"} +``` + +**通知完成**: `LPUSH done:` → `{"pc":"...", "status":"ok"}` + +## 错误处理 + +- parse_command 失败 → notify_done error + continue (不崩溃) +- shm_tensor_create 失败 → notify_done 含具体 errno +- Redis SET 失败 → 必须 log + notify_done error +- Redis 断连 → 重连 + 重新 register_instance + 继续循环 + +## 实例注册 + +启动时写入 `/sys/heap-plat/heap-metal:`: +```json +{"program":"heap-metal","device":"gpu0","status":"running","pid":12345,...} +``` diff --git a/.claude/agents/dev-io-metal.md b/.claude/agents/dev-io-metal.md new file mode 100644 index 00000000..3bc3045a --- /dev/null 
+++ b/.claude/agents/dev-io-metal.md @@ -0,0 +1,136 @@ +# dev-io-metal → io-metal I/O 平面开发 agent + +你是 io-metal I/O 平面开发专家。指导 tensor 持久化、数据传输、格式化输出的开发、修改、测试全流程。 + +## 组件概述 + +io-metal 是 I/O 平面,以独立进程运行,通过 Redis 消费 I/O 指令。负责 tensor 与文件系统、进程管道、网络的读写。 + +**目录结构**: +``` +executor/io-metal/ + build.sh + CMakeLists.txt + CLAUDE.md + src/ + main.cpp ← Redis 消费者 + I/O dispatch +``` + +**共享依赖**: `executor/common-metal`(shm_tensor 工具类) + +## 为什么 I/O 是独立进程 + +| 维度 | op-metal (GPU 计算) | io-metal (I/O) | +|------|---------------------|----------------| +| 硬件依赖 | Metal GPU 必须 | 仅需 CPU | +| 操作延迟 | ~μs (kernel launch) | ~ms-s (disk/network) | +| 阻塞风险 | 无 (GPU 异步) | **高** (磁盘满、网络超时) | +| 故障域 | GPU OOM / Metal 错误 | 磁盘满 / 网络断开 | + +如果合并在同一个进程:磁盘 I/O 阻塞会拖死整个 GPU 计算管线。 + +## 支持的操作 + +| opcode | 参数 | 输入 | 输出 | 说明 | +|--------|------|------|------|------| +| `print` | format (可选) | tensor | — | 格式化输出 tensor 数据到 stdout | +| `save` | arg0=文件路径 | tensor | — | 持久化到文件系统 (path.shape + path.data) | +| `load` | arg0=文件路径 | — | tensor | 从文件系统读取到 shm | + +### 文件格式约定 + +**save** 产生两个文件: +- `.shape` — JSON: `{"dtype":"f32","shape":[N,M],"size":K}` +- `.data` — 原始二进制 (tensor 数据,连续内存块) + +**load** 读取这两个文件: +1. 解析 `.shape` 获取 dtype/shape/size +2. 验证 target tensor shm 容量 >= data 大小 +3. 将 `.data` 读入 shm +4. 
更新 Redis 中的 tensor 元信息 (dtype/shape) + +## 新增 I/O 操作的标准流程 + +### Step 1: 在 main.cpp 的 execute_task 中增加分支 + +```cpp +// ── new_io_op ── +else if (opcode == "new_io_op") { + // 解析参数 + std::string arg = params.value("arg0", ""); + // 获取输入 tensor 数据 (已映射到 input_ptrs[0]) + // 执行 I/O 操作 + // 设置 ok = true / error +} +``` + +### Step 2: 注册算子 + +在 `register_instance()` 中 `RPUSH /op/io-metal/list `。 + +### Step 3: 确保 shm 资源管理 + +- `shm_open` + `mmap` 成功后必须在当前函数结束时 `shm_close` +- **任何错误路径** 都必须先 close 已打开的 shm + +### Step 4: 错误处理 + +- 文件不存在 → `error = "file not found: "` +- 权限不足 → `error = "permission denied: "` +- shm 容量不足 → truncate 并记录日志 +- Redis 更新失败 → best-effort (不阻塞主流程) + +## dtype 字节换算 + +```cpp +static size_t dtype_byte_size(const std::string &dtype) { + if (dtype == "f64" || dtype == "float64" || dtype == "i64" || dtype == "int64") return 8; + if (dtype == "f32" || dtype == "float32" || dtype == "i32" || dtype == "int32") return 4; + if (dtype == "f16" || dtype == "float16" || dtype == "i16" || dtype == "int16") return 2; + if (dtype == "i8" || dtype == "int8" || dtype == "bool") return 1; + return 4; // default f32 +} +``` + +## 通信协议 + +### 入队 +- Redis Key: `cmd:io-metal:` (默认 `cmd:io-metal:0`) +- 模式: `BLPOP` 阻塞弹出 (5s timeout) +- 格式: JSON `{"vtid":"...", "pc":"...", "opcode":"print", "inputs":[{...}], "params":{...}}` + +### 通知完成 +- Redis Key: `done:` +- 模式: `LPUSH` +- 格式: `{"pc":"...", "status":"ok"}` 或 `{"pc":"...", "status":"error", "error":{"code":"IO_ERROR","message":"..."}}` + +## 系统命令队列 + +io-metal 同时监听系统命令队列 `sys:cmd:io-metal:0`: +- `shutdown` — 优雅退出 + +## 实例注册 + +启动时写入 `/sys/io-plat/io-metal:`: +```json +{"program":"io-metal","device":"cpu","status":"running","load":0.0,"pid":12345,"started_at":...} +``` + +算子列表: `/op/io-metal/list` (Redis List, RPUSH) + +## 与 heap-plat 的交互 + +io-metal **不创建/删除** tensor shm。shm 生命周期由 heap-plat 管理: +- save: heap-plat 创建 tensor → io-metal 读 shm → 写文件 +- load: heap-plat 预分配 tensor shm → io-metal 读文件 → 写 shm → 更新 Redis meta + 
+## 与 op-metal 的边界 + +| 操作 | 归属 | +|------|------| +| GPU kernel 计算 | op-metal | +| 数据格式化输出 (print) | io-metal | +| 数据持久化 (save) | io-metal | +| 数据反持久化 (load) | io-metal | +| GPU 显存管理 | op-metal | +| shm 生命周期 (create/delete) | heap-plat | diff --git a/.claude/agents/dev-op-metal.md b/.claude/agents/dev-op-metal.md new file mode 100644 index 00000000..7ea8eda4 --- /dev/null +++ b/.claude/agents/dev-op-metal.md @@ -0,0 +1,140 @@ +# dev-op-metal → op-metal 算子开发 agent + +你是 op-metal 计算平面开发专家。指导 Metal GPU kernel 的新增、修改、测试全流程。 + +> **I/O 操作 (print/save/load) 已迁移到 `io-metal` 独立进程。** 涉及 I/O 的开发请使用 `@dev-io-metal`。 + +## 组件概述 + +op-metal 是 Metal 后端的 GPU 计算平面,以独立进程运行,通过 Redis 消费**纯计算**指令。 + +**目录结构**: +``` +executor/op-metal/ + build.sh ← cmake 构建脚本 + CMakeLists.txt + src/ + client/main.cpp ← Redis 消费者主循环 + 计算指令分派 + deepx/ + metal_context.hpp/cpp ← Metal 设备/命令队列上下文 + mem/mem_metal.hpp ← 统一内存缓存 + tensorfunc/ + elementwise_miaobyte.metal ← Metal shader (GPU kernel) + elementwise_miaobyte.hpp ← host→device 桥接 (kernel 调用封装) + elementwise_common.hpp ← 共享 dispatch 模板 + changeshape_miaobyte.hpp ← reshape/transpose/concat 等 + reduce_miaobyte.hpp ← sum/prod/max/min 规约 + init_miaobyte.hpp ← constant/arange + tf/ + register_miaobyte.hpp ← TfFactory 算子注册 (调度器用) +``` + +## 新增 Metal GPU Kernel 标准流程 + +### Step 1: 在 .metal 文件中写 kernel + +文件: `src/deepx/tensorfunc/elementwise_miaobyte.metal` + +命名约定: `kernel void _(device const T* X, device T* Y, constant uint& n, uint gid)` + +规则: +- 用 `[[thread_position_in_grid]]` 一维网格 +- 显式 `if (gid < n)` 边界检查 +- 整数类型运算需显式 cast(如 `(char)(A[gid] + B[gid])`) +- 必须覆盖的 dtype: f16, f32, i8, i16, i32, i64(至少 f32 和整数类型) + +### Step 2: 在 .hpp 文件中声明封装函数 + +文件: `src/deepx/tensorfunc/elementwise_miaobyte.hpp` + +每个 kernel 对应一个 `extern bool _(const T* a, const T* b, T* c, int64_t n)` 声明。 + +封装函数内部: +1. 获取 `MetalContext::instance()` 的 device / commandQueue +2. 查找 `MTLLibrary`(从 .metal 编译的 default.metallib) +3. 创建 `MTLComputePipelineState` +4. 
创建 `MTLBuffer`(输入/输出 + 常量 n) +5. `dispatchThreads` + commit + waitUntilCompleted +6. 检查 commandBuffer.error + +### Step 3: 在 main.mm 注册算子并分派 + +**注册**: 在 `register_instance()` 中 `RPUSH /op/op-metal/list ` + +**分派**: 在 `execute_task()` 中增加分支: +```cpp +else if (input_ptrs.size() == N && + (opcode == "newop")) { + ok = dispatch_xxx(opcode, dtype, ...); +} +``` + +**dispatch 函数**: 已有的 `dispatch_binary()` / `dispatch_unary()` 可直接复用,只需在 `if (opcode == "newop")` 中增加条目。 + +### Step 4: 在 TfFactory 注册 + +文件: `src/deepx/tf/register_miaobyte.hpp` + +```cpp +factory.add_tf(std::make_shared>( + vector{ + {"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Tensor, Precision::Float}}, + vector{ + {"", DataCategory::Tensor, Precision::Float}})); +``` + +### Step 5: 构建与测试 + +```bash +/build-op-metal # cmake 构建 +/test-op-metal # 运行 shm 跨进程测试 +``` + +## dtype 覆盖检查清单 + +新增 GPU kernel 时必须逐 dtype 检查: + +| dtype | kernel 名称 | .metal | .hpp | main.mm dispatch | 测试 | +|-------|------------|--------|------|------------------|------| +| f16 | op_f16 | ☐ | ☐ | ☐ | ☐ | +| f32 | op_f32 | ☐ | ☐ | ☐ | ☐ | +| i8 | op_i8 | ☐ | ☐ | ☐ | ☐ | +| i16 | op_i16 | ☐ | ☐ | ☐ | ☐ | +| i32 | op_i32 | ☐ | ☐ | ☐ | ☐ | +| i64 | op_i64 | ☐ | ☐ | ☐ | ☐ | + +浮点专用 (sqrt/exp/log/sin/cos/tan) 只需 f16/f32。 + +## 通信协议 + +**入队**: +- Redis Key: `cmd:op-metal:` (默认 `cmd:op-metal:0`) +- 模式: `BLPOP` 阻塞弹出 (5s timeout) +- 格式: JSON `{"vtid":"...", "pc":"...", "opcode":"add", "inputs":[...], "outputs":[...]}` + +**通知完成**: +- Redis Key: `done:` +- 模式: `LPUSH` +- 格式: `{"pc":"...", "status":"ok"}` 或 `{"pc":"...", "status":"error", "error":{"code":"OP_ERROR","message":"..."}}` + +## 错误处理规范 + +每个 notify_done 必须包含: +- `vtid`: 虚拟线程 ID +- `pc`: 程序计数器坐标 +- `status`: "ok" | "error" +- `error`: (如果失败) `{"code":"OP_ERROR", "message":"..."}` + +shm 资源管理: +- `shm_open` + `mmap` 成功后必须在当前函数结束时 `shm_close` +- **任何错误路径** (early return, continue 等) 都必须先 close 已打开的 shm + +## 实例注册 + +启动时写入 `/sys/op-plat/op-metal:`: +```json 
+{"program":"op-metal","device":"gpu0","status":"running","load":0.0,"pid":12345,"started_at":...} +``` + +算子列表: `/op/op-metal/list` (Redis List, RPUSH) diff --git a/.claude/agents/dev-vm.md b/.claude/agents/dev-vm.md new file mode 100644 index 00000000..8c09e530 --- /dev/null +++ b/.claude/agents/dev-vm.md @@ -0,0 +1,135 @@ +# dev-vm → VM 核心开发 agent + +你是 deepx VM (Virtual Machine) 核心开发专家。指导 VM 解释器的新增、修改、调试全流程。 + +## 组件概述 + +VM 是 Go 实现的核心解释器。从 Redis 拾取 vthread,逐条解码执行层指令,分派到 op-plat / heap-plat 或本地求值。 + +**目录结构**: +``` +executor/vm/ + build.sh ← Go 构建脚本 + go.mod / go.sum + cmd/vm/main.go ← 入口: server 模式 / single-run 调试模式 + internal/ + engine/engine.go ← 执行循环: pick → decode → dispatch → next + state/state.go ← VThreadState{PC,Status,Error} GET/SET/SetError/WaitDone + ir/instruction.go ← Decode (MGET 批量), ParseDxlang, PC 导航 + ir/native.go ← 28 个原生算子定义 + dispatch/dispatch.go ← Compute → op-plat, Lifecycle → heap-plat, If + dispatch/native.go ← 原生求值引擎 (nativeValue + eval 函数) + translate/translate.go ← CALL eager 翻译 + RETURN 子栈清理 + route/router.go ← 算子路由 + 负载均衡 + picker/picker.go ← vthread 原子拾取 (Redis Watch) + cache/cache.go ← 本地指令缓存 (待集成) + testdata/ ← .dx 测试文件 (call/lifecycle/native/*) + testutil/dxloader.go ← 测试加载器 +``` + +## 关键架构概念 + +### 执行循环 +``` +RunWorker: + PickVthread (Redis Watch CAS 抢占 status=init) + → Execute loop: + state.Get → Decode (MGET 指令) → dispatch switch + → NextPC advance → repeat +``` + +### 三层指令分派 + +| 层 | 判断函数 | 分派目标 | 同步/异步 | +|----|---------|---------|----------| +| 控制流 | IsControlOp() | translate (call/return) + dispatch (if) | 同步 | +| 原生算子 | IsNativeOp() | dispatch.Native → evalNative | 同步 (VM 内) | +| 生命周期 | IsLifecycleOp() | dispatch.Lifecycle → heap-plat | 异步 (BLPOP wait) | +| 函数调用 | isFunctionCall() | translate.HandleCall → 子栈 | 同步 | +| 计算算子 | IsComputeOp() | dispatch.Compute → op-plat | 异步 (BLPOP wait) | + +### 执行层坐标系统 +``` +/vthread//[addr0, 0] = "opcode" ← 操作码 +/vthread//[addr0,-1] = "param1" ← 读参数 +/vthread//[addr0, 1] = "output1" ← 
写参数 +/vthread//[2,0]/[1,0] = 子栈 ← CALL 翻译产生 +``` + +### CALL Eager 翻译流程 +1. 读取编译层 `/op//func//N` 或源码层 `/src/func//N` +2. 解析签名 → 形参列表 `(Reads: [A,B], Writes: [C])` +3. 形参→实参映射 (by position) +4. MGET 所有编译层指令 → 逐条 ParseDxlang → 形参替换 +5. Pipeline 批量 SET 到 `/vthread///[i,j]` 子栈 +6. 追加隐式 `return ` 指令 +7. PC 跳转到子栈 `[0,0]` + +## 新增原生算子的标准流程 + +### Step 1: 在 ir/native.go 注册 + +```go +var nativeOps = map[string]bool{ + // ...existing... + "newop": true, // ← 新增 +} +``` + +### Step 2: 在 dispatch/native.go 添加 eval 函数 + +```go +func evalNewOp(inputs []nativeValue) (nativeValue, error) { + if err := requireUnary(inputs); err != nil { // 或 requireBinary + return nativeValue{}, err + } + // ... 实现逻辑 ... + return nativeValue{kind: "float", f: result}, nil +} +``` + +### Step 3: 在 evalNative switch 中注册 + +```go +case "newop": + return evalNewOp(inputs) +``` + +### Step 4: 更新 IsUnaryNativeOp (如果是单目) + +```go +func IsUnaryNativeOp(opcode string) bool { + switch opcode { + case "!", "-", "abs", ..., "newop": // ← 新增 + return true + } + return false +} +``` + +### Step 5: 测试 + +在 `testdata/native/` 下创建 `.dx` 测试文件,运行 `/test-vm`。 + +## 状态机 + +``` +init → running → wait → running → ... + ↓ + error / done +``` + +## 并发安全 + +- Worker 之间通过 Redis WATCH + TxPipelined 原子抢占 vthread (CAS) +- 同一 vthread 只能被一个 worker 执行 (status=init→running 原子操作保证) +- `state.SetError` 调用者必须检查 err,状态标记失败也需 log + +## 关键开发约束 + +1. **零 panic**: VM 是常驻服务。用 `state.SetError()` 代替 panic +2. **严禁吞 error**: Go `_` 丢弃 error 违反审计规则 1 +3. **IF 显式求值**: `isTruthy()` 只能接受 `"true"/"1"/"yes"`(其余为 false) +4. **错误含上下文**: SetError 必须含 vtid/pc/msg +5. **JSON 解析检查**: 所有 `json.Unmarshal` 错误必须处理 +6. 
**Redis 返回检查**: 所有 GET/SET/BLPOP 错误必须处理 diff --git a/.claude/commands/CLAUDE.md b/.claude/commands/CLAUDE.md new file mode 100644 index 00000000..647f93fe --- /dev/null +++ b/.claude/commands/CLAUDE.md @@ -0,0 +1,118 @@ +# 项目命令 + +> **所有构建命令均通过根目录 `Makefile` 执行。禁止使用裸 `go build`、`cmake`、`./build.sh`。** + +## 构建命令 + +| 命令 | 组件 | 语言 | 说明 | +|------|------|------|------| +| `/build-op-cuda` | op-plat (CUDA) | C++/CUDA | 构建 CUDA 计算平面 | +| `/build-op-metal` | op-plat (Metal) | C++/ObjC++/Metal | → `make build-op-metal` | +| `/build-heap-metal` | heap-plat (Metal) | C++/ObjC++ | → `make build-heap-metal` | +| `/build-io-metal` | io-plat | C++ | → `make build-io-metal` | +| `/build-vm` | VM | Go | → `make build-vm` | +| `/build-all` | 全部 | C++/Go | → `make build-all` | + +## 测试命令 + +| 命令 | 说明 | +|------|------| +| `/test-vm` | 运行 VM 单元测试 → `make test-vm` | +| `/test-integration` | 运行 VM 集成测试 → `make test-integration` | +| `/test-op-metal` | 构建 op-metal 并运行 shm 跨进程测试 | + +## 服务生命周期 (联调) + +| 命令 | 说明 | +|------|------| +| `/boot` | 构建并启动所有服务 (deepxctl boot) | +| `/run ` | 加载 .dx 文件 (如有顶层调用则自动执行) → deepxctl run | +| `/run --entry ` | 手动指定入口函数执行 | +| `/shutdown` | 停止所有 booted 服务 (deepxctl shutdown) | +| `/status` | 查看所有服务状态 → `make status` | +| `/pipeline` | 完整联调流水线 → `make pipeline` | + +## 环境命令 + +| 命令 | 说明 | +|------|------| +| `/reset-redis` | 重置 Redis → `make reset-redis` | +| `/kvspace ` | 检查 KV Space 联调状态 (堆 + 栈) → skill `debug-kvspace` | + +## 联调命令详解 + +### 典型联调流程 (deepxctl) + +```bash +# 1. 构建并启动所有服务 (一次性) +deepxctl boot + +# 2a. 加载纯定义文件 (只定义函数,不执行) +deepxctl run example/dxlang/call/add_test.dx +# → "Loaded 1 function(s) into KV Space." +# → 没有顶层调用,VM 的 /func/main 监视器保持等待 + +# 2b. 加载含顶层调用的文件 (自动执行) +deepxctl run my_file_with_call.dx +# → loader 检测到顶层调用 → 写入 /func/main → VM 自动创建 vthread → 执行 + +# 2c. 手动指定入口函数 +deepxctl run --entry native_arith example/dxlang/native/arith/add.dx +# → 绕过顶层调用检测,直接以 native_arith 为入口执行 + +# 3. 
停止所有服务 +deepxctl shutdown +``` + +### 分步联调 (make) + +```bash +# 构建 +make build-all + +# 启动服务 (后台,通过 PID 文件管理) +make start-services REDIS_ADDR=127.0.0.1:16379 +make status REDIS_ADDR=127.0.0.1:16379 + +# 重置 Redis +make reset-redis REDIS_ADDR=127.0.0.1:16379 + +# 加载 & 执行 +/tmp/deepx-vm/loader load example/dxlang/lifecycle/full.dx +/tmp/deepx-vm/vm 127.0.0.1:16379 # 手动启动 VM + +# 停止服务 +make stop-services + +# 查看日志 +tail -f /tmp/deepx-logs/op-metal.log +tail -f /tmp/deepx-logs/heap-metal.log +tail -f /tmp/deepx-logs/io-metal.log +tail -f /tmp/deepx-logs/vm.log +``` + +### plats 管理 + +plats (op-metal / heap-metal / io-metal / VM) 由 deepxctl 通过 subprocess 管理生命周期,**所有通信严格通过 Redis**: +- `deepxctl boot` 启动所有进程,写入 PID 到 `/tmp/deepx-boot.json` +- `deepxctl run` 检测 boot 状态后加载并执行 .dx 文件 +- `deepxctl shutdown` **有序退出**: plats → VM → 心跳验证 → deepxctl 退出 + +**Redis 队列分工**: +| 队列 | 类型 | 用途 | +|------|------|------| +| `cmd:op-metal:0` | 业务 | GPU 计算指令 | +| `cmd:heap-metal:0` | 业务 | 堆内存管理指令 | +| `cmd:io-metal:0` | 业务 | I/O 指令 (print/save/load) | +| `notify:vm` | 业务 | VThread 调度通知 | +| `/func/main` | **入口** | loader→VM→deepxctl 三方协作: `{"entry":"f","reads":[...],"writes":[...]}` → `{"vtid":"...","status":"executing"}` → `{"vtid":"...","status":"done"}` | +| `sys:cmd:op-metal:0` | 系统 | op-metal shutdown | +| `sys:cmd:heap-metal:0` | 系统 | heap-metal shutdown | +| `sys:cmd:io-metal:0` | 系统 | io-metal shutdown | +| `sys:cmd:vm:0` | 系统 | VM shutdown | +| `/sys/heartbeat/*` | **心跳** | 各组件每 2s 上报 `{"ts":...,"status":"running/stopped"}` | + +**退出顺序**: plats (op-metal, heap-metal, io-metal) 先退出 → VM 退出 → deepxctl 检查心跳确认全部 stopped → 清理 PID 文件 +OS SIGKILL 仅作为 Redis 不可达时 / 超时时的最后兜底。 + +参见 `tool/deepxctl/internal/process/manager.go` (进程管理)、`tool/deepxctl/cmd/shutdown.go` (有序 shutdown + 心跳验证) 和各组件 main 文件 (心跳上报 + 系统指令监听)。 \ No newline at end of file diff --git a/.claude/commands/build-heap-metal.sh b/.claude/commands/build-heap-metal.sh new file mode 100644 index 00000000..c7021132 --- /dev/null +++ 
b/.claude/commands/build-heap-metal.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" +exec make -C "$PROJECT_ROOT" build-heap-metal diff --git a/.claude/commands/build-io-metal.sh b/.claude/commands/build-io-metal.sh new file mode 100644 index 00000000..197cbd47 --- /dev/null +++ b/.claude/commands/build-io-metal.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" +exec make -C "$PROJECT_ROOT" build-io-metal diff --git a/.claude/commands/build-op-cuda.sh b/.claude/commands/build-op-cuda.sh new file mode 100644 index 00000000..18b9451a --- /dev/null +++ b/.claude/commands/build-op-cuda.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash +set -euo pipefail +exec ./executor/op-cuda/build.sh "$@" diff --git a/.claude/commands/build-op-metal.sh b/.claude/commands/build-op-metal.sh new file mode 100644 index 00000000..c3898980 --- /dev/null +++ b/.claude/commands/build-op-metal.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" +exec make -C "$PROJECT_ROOT" build-op-metal diff --git a/.claude/commands/build-vm.sh b/.claude/commands/build-vm.sh new file mode 100644 index 00000000..eb3ffbc9 --- /dev/null +++ b/.claude/commands/build-vm.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" +exec make -C "$PROJECT_ROOT" build-vm diff --git a/.claude/commands/reset-redis.sh b/.claude/commands/reset-redis.sh new file mode 100644 index 00000000..47a5eafe --- /dev/null +++ b/.claude/commands/reset-redis.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." 
&& pwd)" +exec make -C "$PROJECT_ROOT" reset-redis REDIS_ADDR="${1:-${REDIS_ADDR:-127.0.0.1:16379}}" diff --git a/.claude/commands/test-op-metal.sh b/.claude/commands/test-op-metal.sh new file mode 100644 index 00000000..06b32814 --- /dev/null +++ b/.claude/commands/test-op-metal.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# Build op-metal and run tests (shm cross-process tests) +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +OP_DIR="$SCRIPT_DIR/../../executor/op-metal" +BUILD_DIR="/tmp/deepx/op-metal/build" + +echo "=== Build op-metal ===" +mkdir -p "$BUILD_DIR" +cd "$BUILD_DIR" +cmake "$OP_DIR" +cmake --build . -j$(sysctl -n hw.ncpu 2>/dev/null || nproc) + +echo "" +echo "=== Run SHM Cross-Process Test ===" +if [ -f "$BUILD_DIR/test/shm/test_cross_process" ]; then + "$BUILD_DIR/test/shm/test_cross_process" + echo "SHM test passed." +else + echo "Test binary not found." +fi diff --git a/.claude/commands/test-vm.sh b/.claude/commands/test-vm.sh new file mode 100644 index 00000000..de5af940 --- /dev/null +++ b/.claude/commands/test-vm.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." 
&& pwd)" +exec make -C "$PROJECT_ROOT" test-vm diff --git a/.claude/glossary.md b/.claude/glossary.md new file mode 100644 index 00000000..ccc38e02 --- /dev/null +++ b/.claude/glossary.md @@ -0,0 +1,125 @@ +# 元程术语表 (Metaproc Glossary) + +> 完整术语定义见 `doc/metaproc/spec-v1.md` 附录 C。 +> 本文档为 Claude Code 上下文快速参考。 + +## 核心概念 + +| 术语 | 英文 | 定义 | +|------|------|------| +| 元程 | Metaproc | 一个 KV 空间实例,分布式计算的边界。类比 OS 进程。 | +| 元线程 | Vthread | 元程内的执行流。私有调用栈,共享堆数据。类比 OS 线程。 | +| KV 空间 | KV Space | Redis 实现的全局 key-value 存储。元程的运行载体。 | +| 路径空间 | Path Space | KV 空间的 key 命名规范。`/src/func/`、`/vthread/` 等。 | +| dxlang | DX Language | deepx 元程级编程语言(概念);`executor/dxlang/` 是其当前 C++ 参考实现(代码)。 | + +## 5 个核心 (Five Cores) + +| 核心 | 英文 | 角色 | +|------|------|------| +| Redis | — | KV 空间:全局状态存储、命令队列 (List)、锁、通知 | +| pysdk | Python SDK | 算法前端:注册 dxlang 源码到 `/src/func/`,创建 vthread | +| op-plat | Operator Platform | 计算平面:被动消费指令,执行 GPU/CPU 张量运算 | +| heap-plat | Heap Platform | 堆管理平面:tensor 对象生命周期 (shm 创建/删除/克隆) | +| VM | Virtual Machine | 解释执行:CALL eager 翻译、指令路由到 op-plat/heap-plat、状态推进 | + +## 三层 IR + +| 层 | 路径 | 格式 | 角色 | +|----|------|------|------| +| 源码层 | `/src/func//` | dxlang 人类可读文本 | pysdk 写入 | +| 编译层 | `/op//func//` | 编译器优化后 dxlang | 编译器写入,VM CALL 时读取 | +| 执行层 | `/vthread//` | `[addr0, addr1]` 二维坐标 | VM CALL 时 eager 翻译 | + +## 路径空间 + +| 路径 | 说明 | +|------|------| +| `/src/func/` | 函数签名 (dxlang) | +| `/src/func//N` | 第 N 条指令 | +| `/op//func//N` | 编译后指令 (可能融合/拆分) | +| `/op//list` | 算子列表 (程序级, 所有实例共享) | +| `/op//` | 算子元数据 (category, dtype, max_shape...) 
| +| `/vthread/` | vthread 自身: `{pc, status}` | +| `/vthread//[addr0,0]` | 操作码 (opcode) | +| `/vthread//[addr0,-N]` | 第 N 个读取参数 | +| `/vthread//[addr0,+N]` | 第 N 个写入参数 | +| `/vthread//` | 命名槽位 (局部变量, 与指令坐标平级) | +| `/vthread//[n,0]/[0,0]` | 子栈 (CALL 产生) | +| `/sys/op-plat/` | op-plat 进程注册 | +| `/sys/heap-plat/` | heap-plat 进程注册 | +| `/sys/vtid_counter` | vthread ID 自增计数器 | +| `cmd:op-:` | op-plat 命令队列 (Redis List) | +| `cmd:heap-:` | heap-plat 命令队列 | +| `done:` | vthread 完成通知队列 | +| `notify:vm` | VM 唤醒通知队列 | +| 其他非保留路径 | 堆变量 (tensor 元信息) | + +## 指令格式 + +**dxlang (源码层/编译层):** +``` +opcode(read_p1, read_p2, ...) -> write_p1, write_p2 +``` + +**执行层 (二维寻址):** +``` +/vthread//[addr0, 0] = "opcode" ← addr1=0 → 操作码 +/vthread//[addr0,-1] = "param1" ← addr1<0 → 读取参数 +/vthread//[addr0, 1] = "output1" ← addr1>0 → 写入参数 +``` + +`.` 相对路径 (`./mm`) 解析为 `/vthread//mm` (命名槽位)。 + +## Vthread 状态 + +| 状态 | 含义 | +|------|------| +| `init` | 已创建,待 VM 拾取 | +| `running` | VM 正在调度执行 | +| `wait` | 等待异步操作 (op-plat / heap-plat) | +| `error` | 执行出错 | +| `done` | 执行完毕,可 GC | + +## CALL 语义 + +1. VM 读取 `/op//func//` 编译层 +2. 建立形参→实参映射 +3. 逐条解析 dxlang → 形参替换 → 展开为 `[i,j]` 坐标 +4. Pipeline 批量写入 `/vthread//[n,0]/` 子栈 +5. PC 进入子栈首条指令 + +## RETURN 语义 + +1. 返回值写入父 CALL 指令的写参数槽位 +2. 递归 DELETE 子栈 KV 路径 +3. 
PC 恢复到父栈 CALL 的下一条 + +## 算子融合与拆分 (编译器) + +| 操作 | 方向 | 说明 | +|------|------|------| +| 融合 (Fusion) | N→1 | 编译器将连续匹配指令替换为 fused 算子 | +| 拆分 (Split) | 1→N | Tensor 超过单卡上限时拆分 + 标注设备 | + +两者都是编译器在 `/src/func/` → `/op//func/` 层完成。 + +## Tensor 元信息 + +```json +{"dtype":"f32","shape":[1024,512],"byte_size":2097152,"device":"gpu0","address":{"type":"shm","shm_name":"/deepx_t_abc123"},"ctime":1714000000,"version":5} +``` + +## 与 OS 进程对照 + +| OS 概念 | 元程对应 | +|---------|---------| +| 虚拟地址空间 | KV 空间 | +| 进程 | 一个 KV 空间实例 | +| 线程 | Vthread | +| 代码段 (.text) | /src/func/ + /op//func/ | +| 堆段 (.data/.bss) | 非保留路径 (堆变量) | +| 栈段 | /vthread// | +| PC | /vthread/ 的 pc 字段 | +| CALL/RET | CALL 翻译 → 子栈 / RETURN → DELETE 子栈 | +| 系统调用 | heap-plat / op-plat 命令 | diff --git a/.claude/hooks/CLAUDE.md b/.claude/hooks/CLAUDE.md new file mode 100644 index 00000000..da5b1ae6 --- /dev/null +++ b/.claude/hooks/CLAUDE.md @@ -0,0 +1,16 @@ +# 项目钩子 + +deepx 事件拦截与自动化。 + +## 可用钩子 + +| 钩子 | 触发事件 | 动作 | +|------|---------|------| +| `post-build-op-metal` | `/build-op-metal` 完成后 | 验证 default.metallib 存在 + 运行 shm 测试 | +| `post-build-vm` | `/build-vm` 完成后 | go vet 检查 + 单元测试结果确认 | + +## 实现方式 + +钩子逻辑嵌入在对应命令脚本中(如 `test-op-metal.sh` 同时做构建+测试)。 + +如需扩展为独立钩子文件,在此目录下创建 `on-.sh` 并在 settings 中注册。 diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 00000000..7a028144 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,16 @@ +{ + "permissions": { + "allow": [ + "Bash(printf '#!/usr/bin/env bash\\\\nset -euo pipefail\\\\nexec ./executor/op-cuda/build.sh \"$@\"\\\\n')", + "Bash(chmod +x /Users/lipeng/miaobyte-1252231640/git.array2d.com/ai/deepx/.claude/commands/build-op-cuda.sh)", + "Bash(printf '#!/usr/bin/env bash\\\\nset -euo pipefail\\\\nexec ./executor/op-metal/build.sh \"$@\"\\\\n')", + "Bash(chmod +x /Users/lipeng/miaobyte-1252231640/git.array2d.com/ai/deepx/.claude/commands/build-op-metal.sh)", + "Bash(printf '#!/usr/bin/env bash\\\\nset -euo pipefail\\\\nexec 
./executor/heap-metal/build.sh \"$@\"\\\\n')", + "Bash(chmod +x /Users/lipeng/miaobyte-1252231640/git.array2d.com/ai/deepx/.claude/commands/build-heap-metal.sh)", + "Bash(cat)", + "Bash(chmod +x .claude/commands/*.sh)", + "Bash(rm -rf /Users/lipeng/miaobyte-1252231640/git.array2d.com/ai/deepx/executor/build/*)", + "Bash(cmake ../op-metal -DCMAKE_BUILD_TYPE=Debug)" + ] + } +} diff --git a/.claude/skills/CLAUDE.md b/.claude/skills/CLAUDE.md new file mode 100644 index 00000000..f0ca8e05 --- /dev/null +++ b/.claude/skills/CLAUDE.md @@ -0,0 +1,13 @@ +# 项目技能 + +deepx 专用技能模块。提供引导式工作流,覆盖常见开发与调试任务。 + +## 可用技能 + +| 技能 | 文件 | 用途 | +|------|------|------| +| `add-metal-kernel` | `add-metal-kernel.md` | 新增 Metal GPU kernel 的 7 步引导式工作流 (shader→host→dispatch→注册→构建→测试) | +| `debug-vthread` | `debug-vthread.md` | vthread 执行调试指南 (Redis key 检查、PC 跟踪、异步追踪、单步运行、常见问题) | +| `dual-opcode-audit` | `dual-opcode-audit.md` | VM ↔ op-plat opcode 一致性审计 (交叉比对、问题分类、自动检查脚本) | + +使用方式:对话中输入对应 skill 名称即可触发引导式工作流。例如:"帮我新增一个 gelu Metal kernel,用 add-metal-kernel"。 diff --git a/.claude/skills/add-metal-kernel.md b/.claude/skills/add-metal-kernel.md new file mode 100644 index 00000000..c64f977f --- /dev/null +++ b/.claude/skills/add-metal-kernel.md @@ -0,0 +1,126 @@ +# skill: add-metal-kernel → 新增 Metal GPU kernel + +引导式工作流,按步骤新增一个完整的 Metal GPU 算子。 + +## 前置条件 + +- 本地有 Xcode + Metal 工具链 +- 已运行过 `/build-op-metal` 确认构建环境正常 +- 了解 op-metal 代码结构 (参考 `@dev-op-metal`) + +## 工作流步骤 + +### 1. 确认算子范围 + +回答以下问题: +- 算子名称? (e.g., "gelu", "softmax") +- 几元操作? 一元 / 二元 / 多元? +- 支持哪些 dtype? f32 必须,扩展 f16/i8/i16/i32/i64? +- 存在 CPU fallback 需求吗? + +### 2. 
写 Metal Shader + +编辑: `executor/op-metal/src/deepx/tensorfunc/elementwise_miaobyte.metal` + +模板 (一元): +```metal +kernel void newop_f32(device const float* X [[buffer(0)]], + device float* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = /* 实现 */; } +} +``` + +模板 (二元): +```metal +kernel void newop_f32(device const float* A [[buffer(0)]], + device const float* B [[buffer(1)]], + device float* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = /* 实现 */; } +} +``` + +生成所有 dtype 变体: f16, f32, i8, i16, i32, i64 +整数 kernel 注意显式 cast: `(char)(...)` + +### 3. 写 Host 桥接函数 + +编辑: `executor/op-metal/src/deepx/tensorfunc/elementwise_miaobyte.hpp` + +参考已有函数模式 (如 `add_f32` → `relu_f32`): +- `extern bool newop_f32(const float* x, float* y, int64_t n);` +- 实现: 获取 MetalContext → 加载 library → 创建 pipeline → 分配 buffer → dispatch → 等待完成 +- 检查 commandBuffer.error → 返回成功/失败 + +### 4. 注册算子 + 分派 + +编辑: `executor/op-metal/src/client/main.mm` + +**注册**: 在 `register_instance()` 中: +```cpp +redis_cmd(c, "RPUSH %s %s", "/op/op-metal/list", "newop"); +``` + +**分派**: 在 `execute_task()` 中增加分支: +```cpp +// 如果是已有一元分派函数 → 加到 dispatch_unary 的 if 条件 +else if (opcode == "newop" && input_ptrs.size() == 1) { + ok = dispatch_unary(opcode, dtype, input_ptrs[0], out_shm.addr, n); + if (!ok) error = "Metal kernel dispatch failed for newop:" + dtype; +} +``` + +在 `dispatch_unary()` 中增加: +```cpp +if (opcode == "newop") { + if (dtype == "f32" || dtype == "float32") return newop_f32(...); + // ... 其他 dtype +} +``` + +### 5. TfFactory 注册 (可选,调度器用) + +编辑: `executor/op-metal/src/deepx/tf/register_miaobyte.hpp` + +```cpp +factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); +``` + +### 6. 
构建 + 验证 + +```bash +/build-op-metal +``` + +检查输出: +- `[op-metal] registered all ops` 包含 newop +- 构建无 Metal shader 编译错误 + +### 7. 测试 + +在 VM 侧编写 `.dx` 测试文件: +``` +# testdata/native/arith/newop.dx +newtensor("./a") -> ./a ← 创建输入 +./a <- constant([3,3], 2.0) +newop(./a) -> ./b ← 测试新算子 +print(./b) +deltensor(./a) +deltensor(./b) +``` + +## 完成检查清单 + +- [ ] Metal shader 编译通过 (cmake 构建无错误) +- [ ] 所有声明 dtype 都有 kernel +- [ ] `/op/op-metal/list` 包含新算子 +- [ ] main.mm dispatch 路径可达(无 unreachable code) +- [ ] 错误路径正确释放 shm + notify_done +- [ ] TfFactory 注册 (如适用) diff --git a/.claude/skills/debug-kvspace.md b/.claude/skills/debug-kvspace.md new file mode 100644 index 00000000..05bf9bda --- /dev/null +++ b/.claude/skills/debug-kvspace.md @@ -0,0 +1,172 @@ +# skill: debug-kvspace → KV Space 联调状态检查 + +用 `redis-cli` 在联调时快速检查 deepx 元程的堆 (tensor 元信息 / heap-plat) 和栈 (vthread 执行状态)。 + +## 前置条件 + +- Redis 运行中,默认地址 `127.0.0.1:16379` (由 `REDIS_ADDR` 环境变量控制) +- `redis-cli` 已安装 (本项目环境: `/opt/homebrew/bin/redis-cli` 8.6.2) + +## 速连 + +```bash +# Claude 可直接执行 redis-cli,无需 sandbox 放行 +redis-cli -p 16379 PING +# → PONG +``` + +## 栈检查 (vthread — VM 管理的元程执行栈) + +vthread key 树结构见 `doc/metaproc/redis-keys.md` §4。 + +### 概览 + +```bash +# 列出全部 vthread +redis-cli -p 16379 KEYS "/vthread/*" | grep -v '\[.*\]' + +# 看某个 vthread 的完整状态 (pc + status + error) +redis-cli -p 16379 GET /vthread/ | jq . + +# 看 vthread 全部子 key (指令坐标 + 命名槽位 + 子栈) +redis-cli -p 16379 KEYS "/vthread//*" | sort +``` + +### PC & 指令 + +```bash +# 当前执行到哪条指令 +redis-cli -p 16379 GET /vthread/ | jq '.pc' + +# 根栈第 N 条指令的操作码 + 读写参数 +redis-cli -p 16379 MGET \ + /vthread//[N,0] \ + /vthread//[N,-1] \ + /vthread//[N,-2] \ + /vthread//[N,1] + +# 扫描子栈 (CALL 后多层嵌套) +redis-cli -p 16379 KEYS "/vthread//\[*" | sort +``` + +### 命名槽位 (局部变量) + +```bash +# 看所有非坐标子 key (局部变量 / tensor 引用) +redis-cli -p 16379 KEYS "/vthread//*" | grep -v '\[.*\]' + +# 看某个槽位的值 (基础类型直存, tensor 存 JSON 元信息) +redis-cli -p 16379 GET /vthread// | jq . 
+``` + +### 状态解读 + +| status | 含义 | 排查方向 | +|--------|------|---------| +| `init` | 已创建,待 VM 拾取 | 检查 VM 是否运行 | +| `running` | 正在执行 | 正常 | +| `wait` | 等待异步操作 | 检查 `cmd:op-metal:*` / `cmd:heap-metal:*` 队列 | +| `error` | 执行出错 | 查看 `error` 字段详情 | +| `done` | 执行完毕 | 正常,可 GC | + +## 堆检查 (tensor 元信息 & heap-plat) + +deepx 堆由 **heap-plat 进程** 管理,tensor 元信息存 Redis,实际数据存 POSIX shm。 + +### 堆变量 (tensor 元信息) + +```bash +# 列出所有堆变量 (排除保留路径) +redis-cli -p 16379 KEYS "*" | grep -v -E '^(/src|/vthread|/sys|/cmd:|/done:|/notify:|/lock:|/op/)' | head -50 + +# 看某个 tensor 的完整元信息 +redis-cli -p 16379 GET /models/ | jq . +# → {"dtype":"f32","shape":[1024,512],"byte_size":2097152,"device":"gpu0","address":{"type":"shm","shm_name":"/deepx_t_abc123"}} + +# 批量看关键字段 +redis-cli -p 16379 MGET /models/A /models/B | while read line; do echo "$line" | jq -c '{dtype,shape,device}'; done +``` + +### heap-plat 进程状态 + +```bash +# 查看 heap-plat 实例注册 +redis-cli -p 16379 GET /sys/heap-plat/metal:0 | jq . + +# 列出所有 heap 实例 +redis-cli -p 16379 KEYS "/sys/heap-plat/*" +``` + +### heap 命令队列 + +```bash +# 队列长度 (堆积 >0 说明 heap-plat 未响应) +redis-cli -p 16379 LLEN cmd:heap-metal:0 + redis-cli -p 16379 LLEN cmd:io-metal:0 + +# 查看全部待处理命令 +redis-cli -p 16379 LRANGE cmd:heap-metal:0 0 -1 + +# 查看一条命令详情 (newtensor / deltensor / clonetensor) +redis-cli -p 16379 LINDEX cmd:heap-metal:0 0 | jq . +``` + +### io-plat 进程状态 + +```bash +# 查看 io-plat 实例注册 +redis-cli -p 16379 GET /sys/io-plat/io-metal:0 | jq . + +# io 命令队列 +redis-cli -p 16379 LLEN cmd:io-metal:0 +redis-cli -p 16379 LRANGE cmd:io-metal:0 0 -1 +``` + +### shm 存在性验证 + +```bash +# 从 Redis 拿到 shm_name 后验证 shm 是否真实存在 +ls -la /tmp/deepx_t_* 2>/dev/null +``` + +## 联调联合检查 (堆 + 栈一站式) + +联调时最常见的快速检查路径: + +```bash +# 1. 确认全部平台进程在线 +redis-cli -p 16379 KEYS "/sys/*" | sort + +# 2. 确认 vthread 状态 +redis-cli -p 16379 GET /vthread/1 | jq '{pc,status}' + +# 3. 
| `wait` | 等待 op/heap-plat | 检查 `cmd:op-metal:0` / `cmd:heap-metal:0` 队列 |
查看编译层 vs 执行层 + +对比编译层指令和翻译后的执行层: +```bash +# 编译层 (源码) +redis-cli GET "/src/func/" +redis-cli KEYS "/src/func//*" | sort -t/ -k5 -n | xargs redis-cli GET + +# 执行层 (翻译后) +redis-cli KEYS "/vthread///*" | while read k; do echo "$k: $(redis-cli GET "$k")"; done +``` + +### 4. 异步任务追踪 + +op-plat / heap-plat 的异步通信: +```bash +# 查看 op-plat 命令队列长度 +redis-cli LLEN "cmd:op-metal:0" + +# 查看 done 队列 +redis-cli KEYS "done:*" + +# 查看算子系统注册 +redis-cli LRANGE "/op/op-metal/list" 0 -1 +redis-cli GET "/sys/op-plat/op-metal:0" +``` + +### 5. 单步调试 + +使用 VM single-run 模式逐 vthread 执行: +```bash +./executor/vm/build/vm run [redis_addr] +``` + +输出包含最终 PC、Status、Error 信息。 + +### 6. 常见问题模式 + +| 症状 | 常见原因 | 解决 | +|------|---------|------| +| vthread 永远 `wait` | op-plat / heap-plat 未启动 | 启动对应进程 | +| CALL 失败 "func not found" | 编译层或源码层缺少函数定义 | SET `/src/func/` 或 `/op//func/` | +| "route: no op-plat supports" | 算子未在任何 op-plat 注册 | 检查 `/op/*/list` | +| "BLPOP timeout" | op-plat 响应超时 (30s) | 增大 timeout 或优化 kernel | +| PC 全部为 `[0,0]` | pysdk 写入指令时未使用递增编号 | 检查 `/src/func//N` 的 N 是否连续递增 | + +### 7. 重置测试环境 + +```bash +/reset-redis +``` +清空所有 vthread / src / op / sys / done / cmd key 后重新测试。 diff --git a/.claude/skills/dual-opcode-audit.md b/.claude/skills/dual-opcode-audit.md new file mode 100644 index 00000000..ceb531cd --- /dev/null +++ b/.claude/skills/dual-opcode-audit.md @@ -0,0 +1,81 @@ +# skill: dual-opcode-audit → 双端 opcode 一致性审计 + +检查 VM 分派的 opcode 与 op-plat 注册的算子列表是否匹配。 + +## 审计目标 + +确保三层之间 opcode 命名一致: + +1. **VM 分派层**: `ir` 包中识别的 opcode (IsComputeOp / IsLifecycleOp / IsControlOp / IsNativeOp) +2. **op-plat 注册层**: `/op/op-metal/list` + `/op/op-cuda/list` 中注册的算子 +3. 
**测试层**: `testdata/` 中 `.dx` 文件使用的 opcode + +## 审计步骤 + +### Step 1: 收集 VM 端 opcode + +列出所有 VM 会分发到 op-plat 的 opcode (不在 native/lifecycle/control 中的): + +```bash +# 从 VM 测试文件统计实际使用的 opcode +grep -rhoP '^\s*\w+\(' executor/vm/testdata/ | sort -u | sed 's/($//' +``` + +### Step 2: 收集 op-plat 端 opcode + +```bash +# 从 Redis 读取 (需 Redis 运行 + op-plat 已注册) +redis-cli LRANGE "/op/op-metal/list" 0 -1 +redis-cli LRANGE "/op/op-cuda/list" 0 -1 +``` + +或从源码静态收集: +```bash +# op-metal main.mm 中 RPUSH 的算子 +grep -A1 'RPUSH.*list' executor/op-metal/src/client/main.mm | grep '"' | grep -oP '"\K[^"]+' | grep -v '/op/' | grep -v 'RPUSH' +``` + +### Step 3: 交叉比对 + +| opcode | VM 使用 | op-metal 注册 | op-cuda 注册 | CPU fallback | 状态 | +|--------|---------|---------------|--------------|--------------|------| +| add | ✅ | ✅ | ? | — | ✅ | +| newop | ✅ | ❌ | ❌ | — | ❌ | + +### Step 4: 识别问题 + +**问题类型 A**: VM 分派但无 op-plat 注册 → `route.Select()` 返回 "no op-plat supports" +**问题类型 B**: op-plat 注册但 VM 无测试覆盖 → 死代码 +**问题类型 C**: opcode 拼写不一致 (e.g., "equalscalar" vs "eq_scalar") +**问题类型 D**: VM native 与 op-plat 都处理同一 opcode → 优先级冲突 (VM native 优先) + +### Step 5: 自动检查脚本 + +```bash +#!/bin/bash +# 快速 opcode 一致性检查 +echo "=== VM native ops ===" +grep ':"' executor/vm/internal/ir/native.go | grep -oP '"\K[^"]+' | sort > /tmp/vm_native.txt + +echo "=== VM test ops ===" +grep -rhoP '^\s*\w+\(' executor/vm/testdata/ | sed 's/($//' | sort -u > /tmp/vm_test.txt + +echo "=== op-metal registered ===" +# (需要从 Redis 或 main.mm 中提取) +grep -A1 'RPUSH.*list' executor/op-metal/src/client/main.mm | \ + grep '"[a-z]' | grep -oP '"\K[^"]+' | grep -v 'list' | sort > /tmp/op_metal.txt + +echo "=== Diff: VM test vs op-metal ===" +comm -23 /tmp/vm_test.txt /tmp/op_metal.txt +echo "(VM test ops NOT in op-metal)" + +comm -13 /tmp/vm_test.txt /tmp/op_metal.txt +echo "(op-metal ops NOT in VM tests)" +``` + +## 合规标准 + +每个 compute opcode 必须满足: +- 至少一个 op-plat 后端注册 (metal 或 cuda 或 cpu) +- VM 有对应的 dispatch 路径 (Compute 函数) +- 测试覆盖 (testdata 
中有 .dx 文件) diff --git a/.github/workflows/executor-heapmemcuda.yml b/.github/workflows/executor-heap-cuda.yml similarity index 94% rename from .github/workflows/executor-heapmemcuda.yml rename to .github/workflows/executor-heap-cuda.yml index fbfd2c40..49ff69bc 100644 --- a/.github/workflows/executor-heapmemcuda.yml +++ b/.github/workflows/executor-heap-cuda.yml @@ -1,11 +1,11 @@ -name: executor/heapmem-cuda Build +name: executor/heap-cuda Build on: push: paths: - - 'executor/heapmem-cuda/**' + - 'executor/heap-cuda/**' pull_request: paths: - - 'executor/heapmem-cuda/**' + - 'executor/heap-cuda/**' env: CUDA_VERSION: "12.9.1" CUDA_MAJOR_VERSION: "12" @@ -69,7 +69,7 @@ jobs: # 构建 CUDA 执行器 apt install -y libhiredis-dev && \ - cd ../../heapmem-cuda && \ + cd ../../heap-cuda && \ mkdir -p build && cd build && \ cmake -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ diff --git a/.gitignore b/.gitignore index 071a90ef..8e950fc6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,11 @@ .vscode **/build/ +**/build-*/ .idea **/.idea **/__pycache__/ **/dist/ **/*.egg-info/ -*.pdf \ No newline at end of file +*.pdf.claude/ +.claude/settings.local.json.tmp.* +dump.rdb diff --git a/Agents.md b/Agents.md index ef5709a2..e217edd6 100644 --- a/Agents.md +++ b/Agents.md @@ -1,20 +1,9 @@ ## agent 规则 -+ This is the only AGENTS.md, there are no recursive AGENTS.md -+ When you are working on a bug, first create a standalone file that reproduces the bug and verify it fails in the expected way. Use this to test if your changes work. Once the change is passing, find an appropriate test file to add the test to and make sure to follow local conventions on the test file. -+ Always respond in 中文,不要回答重复的内容(如我提问中的代码) ++ 这是唯一的Agents.md,无递归的Agents.md ++ Always respond in 中文 ++ 请节省代码行数,代码行数越少,质量越高 ++ 不要在代码中提供注释,除非我明确要求 -## deepx的架构 +## deepx重要设计文件列表 -项目分为3部分 -1. 前端。python库的接口风格参考pytorch -2. 编译,调度器,待设计 -3. 
doc/deepxIR/deepxir.md //非常重要
build-all Build all projects" + @echo " make build-vm Build VM + loader (Go) → $(VM_OUT)/vm, $(VM_OUT)/loader" + @echo " make build-deepxctl Build deepxctl (Go) → $(DEEPXCTL_DIR)/deepxctl" + @echo " make build-op-metal Build Metal compute plane (C++) → $(OP_METAL_OUT)/deepx-op-metal" + @echo " make build-heap-metal Build Metal heap plane (C++) → $(HEAP_METAL_OUT)/deepx-heap-metal" + @echo " make build-io-metal Build I/O plane (C++) → $(IO_METAL_OUT)/deepx-io-metal" + @echo "" + @echo "TEST:" + @echo " make test-vm Run VM unit tests" + @echo " make test-integration Run VM integration tests (needs Redis)" + @echo "" + @echo "SERVICES (daemon):" + @echo " make start-services Start op-metal + heap-metal in background" + @echo " make stop-services Stop all background services" + @echo " make status Check service/Redis status" + @echo "" + @echo "PIPELINE:" + @echo " make pipeline Full cycle: build → start → reset → stop" + @echo "" + @echo "UTILS:" + @echo " make reset-redis Reset Redis (FLUSHDB)" + @echo " make clean Remove build artifacts" + @echo " make clean-all Clean all including temp output dirs" + @echo "" + @echo "Config via env:" + @echo " REDIS_ADDR=$(REDIS_ADDR) GOROOT=$(GOROOT) GOPROXY=$(GOPROXY)" + +# ═══════════════════════════════════════════════════════════════ +# Build — All +# ═══════════════════════════════════════════════════════════════ + +build-all: build-vm build-deepxctl build-op-metal build-heap-metal build-io-metal + @echo "=== build-all complete ===" + +# ═══════════════════════════════════════════════════════════════ +# Build — Go Projects +# ═══════════════════════════════════════════════════════════════ + +build-vm: + @echo "=== Building VM ===" + @command -v go >/dev/null 2>&1 || (echo "ERROR: go not found in PATH (GOROOT=$(GOROOT))" && exit 1) + @echo "Go version: $$(go version)" + mkdir -p $(VM_OUT) + cd $(VM_DIR) && go mod tidy + cd $(VM_DIR) && go build -ldflags="-s -w" -o $(VM_OUT)/vm ./cmd/vm/ + cd $(VM_DIR) && go build -ldflags="-s 
-w" -o $(VM_OUT)/loader ./cmd/loader/ + @echo " → $(VM_OUT)/vm" + @echo " → $(VM_OUT)/loader" + +build-deepxctl: + @echo "=== Building deepxctl ===" + @command -v go >/dev/null 2>&1 || (echo "ERROR: go not found in PATH (GOROOT=$(GOROOT))" && exit 1) + @echo "Go version: $$(go version)" + cd $(DEEPXCTL_DIR) && go mod tidy + cd $(DEEPXCTL_DIR) && go build -ldflags="-s -w" -o deepxctl . + @echo " → $(DEEPXCTL_DIR)/deepxctl" + +# ═══════════════════════════════════════════════════════════════ +# Build — C++ Projects (delegate to executor/Makefile) +# ═══════════════════════════════════════════════════════════════ + +build-op-metal: + @echo "=== Building op-metal ===" + cd executor && $(MAKE) build-op + +build-heap-metal: + @echo "=== Building heap-metal ===" + cd executor && $(MAKE) build-heap + +build-io-metal: + @echo "=== Building io-metal ===" + cd executor && $(MAKE) build-io + +# ═══════════════════════════════════════════════════════════════ +# Test +# ═══════════════════════════════════════════════════════════════ + +test-vm: + cd $(VM_DIR) && go test ./... 
heap管理的tensor,通常是持久的权重,可能被很多个不同进程访问。
相对应的,随着函数执行完毕自动回收的中间变量tensor,可以被称之为stacktensor,这些tensor交给op进程自行管理。 * op-cuda:实现了nv平台的常用基础算子。[cuda](docs/executor/op-mem-cuda/list.md) @@ -55,7 +55,7 @@ deepx可以划分为前端与统一存算面(中端与后端),分别是为 |数据区、代码区|kv存储管理| |上层程序设:func和struct| deepxIR和tensor| |cpu执行底层机器码/字节码| deepx执行器执行deepxIR| -|存储-堆 |kv存储tensor元信息,heapmem管理gpu、内存上的tensordata| +|存储-堆 |kv存储tensor元信息,heap管理gpu、内存上的tensordata| |线程栈|计算进程自行管理| diff --git a/docs/.gitignore b/doc/.gitignore similarity index 100% rename from docs/.gitignore rename to doc/.gitignore diff --git a/docs/README.md b/doc/README.md similarity index 100% rename from docs/README.md rename to doc/README.md diff --git a/docs/conf.py b/doc/conf.py similarity index 100% rename from docs/conf.py rename to doc/conf.py diff --git a/doc/deepxctl/.gitignore b/doc/deepxctl/.gitignore new file mode 100644 index 00000000..006f0497 --- /dev/null +++ b/doc/deepxctl/.gitignore @@ -0,0 +1,4 @@ +_build/ +*.pyc +__pycache__ +.DS_Store diff --git a/doc/deepxctl/CLAUDE.md b/doc/deepxctl/CLAUDE.md new file mode 100644 index 00000000..5c7b7329 --- /dev/null +++ b/doc/deepxctl/CLAUDE.md @@ -0,0 +1,280 @@ +# deepxctl 开发约束 + +> deepxctl 的职责边界。哪些逻辑 deepxctl 能做,哪些**绝对不能碰**。 + +--- + +## 1. deepxctl 是什么 + +deepxctl 是 DeepX 元程系统的**进程编排工具**。它不实现任何计算/存储/调度逻辑, +只负责**子进程生命周期管理**和**Redis 状态的检查性读写**。 + +deepxctl = 启动脚本的 Go 化替代品。 + +--- + +## 2. 
允许做的事(白名单) + +### 2.1 子进程管理 + +| 操作 | 允许 | 说明 | +|------|------|------| +| `exec.Command` 启动 op-plat | ✅ | 调用已编译的二进制 | +| `exec.Command` 启动 heap-plat | ✅ | 同上 | +| `exec.Command` 启动 VM | ✅ | 同上 | +| `exec.Command` 调用 loader | ✅ | 加载 dx 源码到 `/src/func/` | +| 给子进程传递参数(redis addr 等) | ✅ | | +| 捕获子进程 stdout/stderr | ✅ | 仅用于日志展示 | +| SIGTERM / SIGKILL 子进程 | ✅ | 进程清理 | +| 检测子进程退出状态 | ✅ | | + +### 2.2 Redis 操作(仅限读取状态 + vthread 管理) + +| 操作 | 允许 | 说明 | +|------|------|------| +| `PING` | ✅ | 检测 Redis 可达 | +| `FLUSHDB` | ✅ | 重置测试环境(仅限开发端口 16379) | +| `DBSIZE` | ✅ | 验证重置结果 | +| `GET /sys/op-plat/*` | ✅ | 检查 op-plat 就绪状态 | +| `GET /sys/heap-plat/*` | ✅ | 检查 heap-plat 就绪状态 | +| `GET /sys/vm/*` | ✅ | 检查 VM 就绪状态 | +| `GET /vthread/` | ✅ | 轮询 vthread 执行状态 | +| `SET /vthread/` | ✅ | 创建 vthread + 设置 init 状态 | +| `SET /vthread//[0,0]` | ✅ | 设置 vthread 入口 CALL 指令 | +| `INCR /sys/vtid_counter` | ✅ | 分配 vthread ID | +| `LPUSH notify:vm` | ✅ | 唤醒 VM 拾取新 vthread | +| `GET /src/func/` | ✅ | 验证 dx 加载是否成功 | +| `KEYS /src/func/*` | ✅ | 列出已注册函数(status 命令用) | + +### 2.3 构建 + +| 操作 | 允许 | 说明 | +|------|------|------| +| 调用 `executor/*/build.sh` | ✅ | 直接 exec 现有脚本 | +| 检测二进制是否存在 | ✅ | 跳过已构建的组件 | + +### 2.4 平台检测 + +| 操作 | 允许 | 说明 | +|------|------|------| +| 检测 GOOS | ✅ | 判断 macOS vs Linux | +| 检测 Metal 可用 | ✅ | Objective-C 小工具或 go 绑定 | +| 检测 nvidia-smi | ✅ | 判断 CUDA 可用 | + +--- + +## 3. 
禁止做的事(黑名单) + +### 3.1 绝对禁止:实现组件内部逻辑 + +> deepxctl 不是 VM,不是 op-plat,不是 heap-plat。 + +| 操作 | 禁止 | 原因 | +|------|------|------| +| 实现任何 tensor 计算 | ❌ | 这是 op-plat 的职责 | +| 分配/释放 shm | ❌ | 这是 heap-plat 的职责 | +| 翻译 dxlang 指令 | ❌ | 这是 VM 的职责 | +| 执行 CALL/RETURN | ❌ | 这是 VM 的职责 | +| 路由算子到 op-plat | ❌ | 这是 VM route 包的职责 | +| 解析 dxlang 语法 | ❌ | 这是 loader + VM ir 包的职责 | +| 注册算子列表 | ❌ | 这是 op-plat 自行注册的 | +| 生产 Redis 命令队列消息 | ❌ | 这是 VM 的职责(push 到 cmd:*) | +| 消费 Redis 命令队列 | ❌ | 这是 op-plat/heap-plat 的职责 | + +### 3.2 禁止:修改其他组件 + +| 操作 | 禁止 | 原因 | +|------|------|------| +| 修改 op-plat 源码 | ❌ | 跨组件边界 | +| 修改 heap-plat 源码 | ❌ | 跨组件边界 | +| 修改 VM 源码 | ❌ | 跨组件边界 | +| 修改 build.sh 脚本 | ❌ | 它们是独立的构建原语 | +| 修改 CMakeLists.txt | ❌ | 构建系统属于各组件 | + +### 3.3 禁止:引入新的通信协议 + +| 操作 | 禁止 | 原因 | +|------|------|------| +| deepxctl 与 VM 直接通信 | ❌ | 所有通信通过 Redis KV 空间 | +| deepxctl 与 op-plat 直接通信 | ❌ | 同上 | +| 自定义 socket/gRPC/HTTP API | ❌ | 不引入额外协议 | +| 新增 Redis key 模式 | ❌ | 只能使用已定义的 key 模式 | + +### 3.4 禁止:在生产环境自动 FLUSHDB + +| 操作 | 禁止 | 原因 | +|------|------|------| +| 在非 16379 端口自动 FLUSHDB | ❌ | 可能误删生产数据 | +| 无确认直接清除非本地 Redis | ❌ | 同上 | + +--- + +## 4. deepxctl 的 Redis key 使用边界 + +deepxctl 只能**读**以下 key(状态检查): + +``` +读取: + /sys/op-plat/metal:0 → 检查 op-plat 是否就绪 + /sys/heap-plat/metal:0 → 检查 heap-plat 是否就绪 + /sys/vm/0 → 检查 VM 是否就绪 + /src/func/ → 验证 dx 加载成功 + /vthread/ → 轮询执行状态 (pc, status) + /op//list → (可选) 验证算子注册 +``` + +deepxctl 只能**写**以下 key(vthread 生命周期): + +``` +写入: + /sys/vtid_counter → INCR 分配 vthread ID + /vthread/ → SET 创建 vthread (pc + status) + /vthread//[0,0] → SET 入口指令 + notify:vm → LPUSH 唤醒 VM +``` + +deepxctl **绝对不能**读写的 key: + +``` +禁止: + /vthread//[*,*] → 指令坐标 (这是 VM 的私有格式) + /op/*/func/* → 编译层 (这是 VM + 编译器的) + cmd:op-* → 命令队列 (这是 VM 生产、op-plat 消费) + cmd:heap-* → 命令队列 (这是 VM 生产、heap-plat 消费) + done:* → 完成通知 (这是 VM 消费) + /lock/* → 锁 (这是 VM 管理) + 堆变量 (任意非保留路径) → tensor 元信息 (这是 pysdk + heap-plat 管理) +``` + +--- + +## 5. 
命令架构 + +deepxctl 将生命周期拆分为三个独立命令: + +``` +deepxctl boot → 构建 + 启动 op-metal、heap-metal、VM,写入 PID 文件 +deepxctl run a.dx → 检测 boot 状态 → 加载 dx → 创建 vthread → 轮询等待结果 +deepxctl shutdown → 有序退出: plats → VM → 心跳验证 → 清理 PID 文件 +``` + +### `deepxctl run` 执行流程 + +``` +deepxctl run xxx.dx [--rm] +│ +│ (deepxctl 负责的部分) +│ +├─ [1/3] Check services ─────── 检查 boot PID 文件 + Redis 服务就绪 +├─ [2/3] Load dx ────────────── exec loader 二进制 (子进程加载 .dx 到 /src/func/) +├─ [3/3] Execute ────────────── +│ ├─ create vthread ───────── SET /vthread/ (初始状态) +│ ├─ wake VM ──────────────── LPUSH notify:vm +│ │ │ +│ │ ▼ (VM 接手 — deepxctl 不参与) +│ │ VM 拾取 → CALL 翻译 → dispatch → op/heap → PC++ +│ │ +│ └─ poll status ──────────── GET /vthread/ (轮询 status) +│ ├─ done → print result ✓ +│ └─ error → print error ✗ +│ +└─ [--rm] Cleanup (可选) + ├─ FLUSHDB ───────────────── 重置 Redis KV 空间 + └─ ExecShutdown ──────────── 复用 shutdown 逻辑: plats → VM → 清理 +``` + +**关键分界线**:deepxctl 在 `notify:vm` 之后就不再参与执行——后续所有步骤(CALL 翻译、 +指令 dispatch、op 执行、done 通知)都是 VM/op-plat/heap-plat 之间通过 Redis 的协作。 +deepxctl 只是**旁观**:轮询 `/vthread/` 的 status 字段,直到 `done` 或 `error`。 + +### `--rm` 一键清理 + +`deepxctl run a.dx --rm` 在 dx 代码执行成功后自动: +1. **FLUSHDB** — 重置 Redis KV 空间 +2. **shutdown** — 复用 `deepxctl shutdown` 的完整退出逻辑(Redis sys:shutdown 命令 → 心跳验证 → 清理 PID 文件 → OS 信号兜底) + +等价于手动执行: +```bash +deepxctl run a.dx && make reset-redis && deepxctl shutdown +``` + +--- + +## 6. 不允许的"快捷方式" + +以下是一些看似方便但**绝不能做**的事情: + +| 快捷方式 | 为什么不行 | +|---------|-----------| +| 直接在 deepxctl 里解析 dxlang 找入口函数 | loader + VM 已有解析逻辑,deepxctl 不应重复实现语法解析 | +| 直接 SET `/vthread//[0,1]` 等详细指令 | VM 的 CALL eager 翻译负责展开指令坐标,deepxctl 只负责最顶层的一个 CALL | +| 直接 LPUSH `cmd:op-metal:0` | 命令队列由 VM 生产,deepxctl 绕过 VM 会破坏调度逻辑 | +| 通过 `done:` 轮询完成 | VM 消费 done 队列后更新 `/vthread/` status,deepxctl 应该读 status 而非直接消费 done 队列 | +| 读取 `/vthread//[*,*]` 指令 | 这是 VM 的内部数据格式,外部不应依赖 | +| 给 VM 发自定义信号 | 所有通信必须走 Redis KV 空间 | + +--- + +## 7. 
入口函数约定 + +deepxctl 创建 vthread 时,写入的指令是**一个顶层 CALL**: + +``` +/vthread/ = {"pc":"[0,0]","status":"init"} +/vthread//[0,0] = "" ← CALL 指令的操作码 +/vthread//[0,1] = "./ret" ← 返回值槽位 +``` + +VM 执行到 `[0,0]` 时: +1. 识别 `` 不是内置关键字 +2. 检查 `/src/func/` 存在 +3. 触发 CALL eager 翻译,展开函数体到子栈 `[0,0]/[0,0]`, `[0,0]/[1,0]`... +4. 继续逐条执行 + +**入口函数名确定规则**: +1. 文件中有 `def main` → 用 `main` +2. 只有一个 `def` → 用那个名字 +3. 多个 `def` 且无 `main` → 报错,要求 `--entry` 指定 + +> 入口函数名从文件名推断(loader 已实现命名逻辑),deepxctl 调用 loader 后 +> 可以 GET `/src/func/*` 的 KEYS 结果确定有哪些函数可用。 + +--- + +## 8. 实现语言和依赖 + +| 项 | 选择 | 约束 | +|----|------|------| +| 语言 | Go | 与 VM 一致,单二进制 | +| Redis 客户端 | `go-redis/v9` | 与 VM 相同依赖 | +| 进程管理 | `os/exec` | Go 标准库 | +| CLI 框架 | 能跑就行(flag 包即可) | MVP 不引入重型框架 | +| 配置文件 | 硬编码先(后续 YAML) | MVP 阶段不引入 viper | + +--- + +## 9. 当前文件清单 + +``` +tool/deepxctl/ +├── main.go +├── go.mod +├── cmd/ +│ ├── boot.go ← boot 子命令 (构建 + 启动服务) +│ ├── run.go ← run 子命令 (加载 dx + 创建 vthread + 轮询) +│ ├── shutdown.go ← shutdown 子命令 (有序退出服务) +│ └── common.go ← 共享打印/辅助函数 +├── internal/ +│ ├── redis/redis.go ← 连接 + FLUSHDB + 状态检查 + vthread 管理 +│ ├── builder/builder.go ← exec build.sh +│ ├── process/manager.go ← 子进程生命周期 +│ └── executor/executor.go ← vthread 创建 + 轮询 +└── tensor/ ← tensor 文件操作 (print/save/load) +``` + +不做的: +- YAML 配置文件解析 +- JSON 输出格式(结构化) +- 多平台自动检测(先只支持 metal) +- 守护进程模式 +- 远程 Redis TLS diff --git a/doc/deepxctl/README.md b/doc/deepxctl/README.md new file mode 100644 index 00000000..24b56c8c --- /dev/null +++ b/doc/deepxctl/README.md @@ -0,0 +1,211 @@ +# deepxctl run — 一键运行 dx 代码 + +> `deepxctl run xxx.dx` 一条命令完成 Redis 准备、组件启动、代码加载、执行、收尾。 +> 职责边界见 [CLAUDE.md](CLAUDE.md)。 + +--- + +## 1. 
解决的问题 + +当前调试 .dx 代码需手动 8 步: + +``` +redis-server --port 16379 & +redis-cli -p 16379 FLUSHDB +./executor/op-metal/build.sh +./executor/heap-metal/build.sh +./executor/vm/build.sh +/tmp/deepx/op-metal/build/deepx-op-metal 127.0.0.1 16379 & +/tmp/deepx/heap-metal/build/deepx-heap-metal 127.0.0.1 16379 & +VM_ID=0 /tmp/deepx-vm/vm 127.0.0.1:16379 & +/tmp/deepx-vm/loader example/dxlang/lifecycle/full.dx 127.0.0.1:16379 +# 还需要手动创建 vthread... +``` + +deepxctl 的目标: + +``` +deepxctl run example/dxlang/lifecycle/full.dx +``` + +--- + +## 2. 执行流程 + +``` +deepxctl run full.dx +│ +├─ [1/6] Redis +│ PING → 不可达则报错退出 +│ FLUSHDB → 重置 KV 空间 +│ +├─ [2/6] Build (按需) +│ 检测二进制是否存在: +│ /tmp/deepx/op-metal/build/deepx-op-metal +│ /tmp/deepx/heap-metal/build/deepx-heap-metal +│ /tmp/deepx-vm/vm +│ /tmp/deepx-vm/loader +│ 缺失 → exec build.sh 构建 +│ +├─ [3/6] 启动子进程(按依赖顺序) +│ ① op-plat → GET /sys/op-plat/metal:0 等待 status=running +│ ② heap-plat → GET /sys/heap-plat/metal:0 等待 status=running +│ ③ VM → GET /sys/vm/0 等待 status=running +│ +├─ [4/6] 加载 dx 代码 +│ exec loader 二进制 → 写入 /src/func/ +│ 验证: GET /src/func/ 非空 +│ +├─ [5/6] 创建 vthread + 执行 +│ INCR /sys/vtid_counter → vtid +│ SET /vthread/ = {"pc":"[0,0]","status":"init"} +│ SET /vthread//[0,0] = "" +│ SET /vthread//[0,1] = "./ret" +│ LPUSH notify:vm {"event":"new_vthread","vtid":""} +│ ┌─ 轮询 GET /vthread/ +│ │ status=done → 成功 +│ │ status=error → 打印错误 +│ └─ 超时 → TIMEOUT_ERROR +│ +└─ [6/6] 清理 + SIGTERM → 等待 2s → SIGKILL 所有子进程 +``` + +--- + +## 3. 入口函数确定规则 + +deepxctl 创建 vthread 只写**一条顶层 CALL 指令**: + +``` +/vthread//[0,0] = "" # CALL 操作码 +/vthread//[0,1] = "./ret" # 返回值槽位 +``` + +VM 拾取后自动识别 `` 非内置关键字 → 查找 `/src/func/` → CALL eager 翻译 → 展开函数体到子栈。 + +**入口函数名规则**: +1. 文件中有 `def main` → 用 `main` +2. 只有一个 `def` → 用那个名字 +3. 多个 `def` 且无 `main` → 报错,要求 `--entry` 指定 + +入口函数名从 loader 的输出获取(loader 加载时打印 `→ /src/func/`),或 GET `/src/func/*` KEYS 推断。 + +--- + +## 4. 
CLI 用法 + +``` +deepxctl run [flags] + +flags: + -r, --redis string Redis 地址 (默认: 127.0.0.1:16379) + -b, --build 强制重新构建 + --no-reset 跳过 Redis FLUSHDB + --keep-alive 执行后保持进程运行 (调试) + -v, --verbose 输出子进程 stdout/stderr + --entry string 指定入口函数名 (多 def 且无 main 时必须) + --timeout int 执行超时秒数 (默认: 60, 0=无限制) + +示例: + deepxctl run example/dxlang/lifecycle/full.dx + deepxctl run -v example/dxlang/call/tensor_pipeline.dx + deepxctl run --entry stage1 --timeout 30 example/dxlang/call/tensor_pipeline.dx +``` + +--- + +## 5. 输出格式 + +``` +$ deepxctl run example/dxlang/lifecycle/full.dx + + deepxctl | redis: 127.0.0.1:16379 +───────────────────────────────────────── + +[1/6] Redis ........................ ✓ +[2/6] Build ........................ ✓ (up-to-date) +[3/6] op-plat ...................... ✓ (pid=12345) + heap-plat .................... ✓ (pid=12346) + VM ........................... ✓ (pid=12347) +[4/6] Load: full.dx ................ ✓ (/src/func/lifecycle_full) +[5/6] Execute: lifecycle_full ...... ✓ (vtid=1, done, 0.042s) +[6/6] Cleanup ...................... ✓ + +───────────────────────────────────────── +SUCCESS vtid=1 status=done 42ms +───────────────────────────────────────── +``` + +错误输出: + +``` +$ deepxctl run bad.dx + + deepxctl | redis: 127.0.0.1:16379 +───────────────────────────────────────── + +[1/6] Redis ........................ ✓ +[2/6] Build ........................ ✓ +[3/6] op-plat ...................... ✓ + heap-plat .................... ✓ + VM ........................... ✓ +[4/6] Load: bad.dx ................. ✗ + +───────────────────────────────────────── +ERROR loader exit code 1 + /src/func/ not found after loading bad.dx +───────────────────────────────────────── +``` + +--- + +## 6. 错误码 + +| 退出码 | 含义 | +|--------|------| +| 0 | 成功 | +| 10 | Redis 连接失败 | +| 20 | 组件构建失败 | +| 30 | 子进程启动失败/超时 | +| 40 | dx 加载失败 | +| 50 | vthread 执行失败 (status=error) | +| 60 | vthread 执行超时 | +| 99 | 内部错误 | + +--- + +## 7. 
实现结构 + +``` +tool/deepxctl/ # 复用现有目录 +├── main.go # 入口 +├── cmd/run.go # run 子命令 +├── internal/ +│ ├── redis/redis.go # 连接 + FLUSHDB +│ ├── build/builder.go # exec build.sh 子进程 +│ ├── process/manager.go # 子进程启动/停止/就绪等待 +│ ├── loader/loader.go # exec loader 子进程 +│ └── executor/executor.go # vthread 创建 + 轮询 +└── tensor/ # 已有,不动 +``` + +--- + +## 8. 与组件的关系 + +deepxctl 只做**进程编排**,不实现任何计算逻辑: + +``` +deepxctl 组件 +──────── ───── +✓ 启动/停止子进程 ✗ 不实现 tensor 计算 (op-plat) +✓ 调用 build.sh ✗ 不分配 shm (heap-plat) +✓ 调用 loader ✗ 不翻译 dxlang (VM) +✓ FLUSHDB ✗ 不生产 cmd:* 队列消息 (VM) +✓ 写 /vthread/ init ✗ 不解析 dxlang 语法 (loader/VM) +✓ 轮询 status ✗ 不消费 done:* 队列 (VM) +✓ 结束清理子进程 ✗ 不注册算子 (op-plat 自行注册) +``` + +详细约束见 [CLAUDE.md](CLAUDE.md)。 diff --git a/docs/design.md b/doc/design.md similarity index 100% rename from docs/design.md rename to doc/design.md diff --git a/doc/dxlang/README.md b/doc/dxlang/README.md new file mode 100644 index 00000000..901d76ba --- /dev/null +++ b/doc/dxlang/README.md @@ -0,0 +1,291 @@ +# dxlang + +> dxlang 是 deepx 元程级的编程语言。**deepxir 即 dxlang 的指令设计部分**——两者不是两层,而是同一语言的不同视角。 + +## 0. 统一语法哲学 + +LLVM 将表示划分为多层:高级语言 → 中层 IR → 低级机器码,各层语法互不相通,编译器在其中逐层翻译。 + +dxlang 不这样做。**dxlang 以同一种语法,同时承担四种职能**: + +| 职能 | 说明 | +|------|------| +| **VM 指令执行** | `add(A, B) -> C` 即 VM 可逐条解码执行的操作码 | +| **高级语言定义** | `def gemm(...) -> (...) { ... 
}` 即程序员编写的函数 | +| **编译器分析** | `->`/`<-` 数据流箭头天然可做 use-def 链分析,无需构造 SSA | +| **人类可阅读** | 纯文本、无编号寄存器、无 phi 节点,源码即 IR | + +**与 LLVM 的本质区别**: + +``` +LLVM: C++/Rust → LLVM IR → Machine Code + 三层语法,逐层翻译,互不兼容 + +dxlang: def → op(A,B)->C → /vthread/[i,j] + 同一语法,同一表示,视角不同而已 +``` + +**deepxir** 即这条指令 `op(A, B) -> C` 的设计——它不是独立的一层,它就是 dxlang 本身,是 dxlang 在"指令"视角下的名字。 + +### 执行模型定位 + +**底座是分布式的,语言是单线程的。** + +deepx 的底层是一个分布式系统:Redis KV 空间、heap-plat 跨进程管理 shm、op-plat 常驻消费计算指令、VM 多 worker 并行调度 vthread——但 dxlang **使用起来就像 SQL 一样**,程序员看到的是极其简单的单线程顺序逻辑: + +```dxlang +# 程序员视角:就几条顺序语句,像写 SQL 一样简单 +def hadamard3() -> ("/data/result") { + newtensor("f32", "[128]") -> "/data/a" + newtensor("f32", "[128]") -> "/data/b" + mul("/data/a", "/data/b") -> "/data/temp" # GPU 执行,程序员无感 + deltensor("/data/a") +} +``` + +| 底座(分布式) | 语言(单线程) | +|----------------|----------------| +| Redis KV 空间跨进程共享 | 程序员只写 `A + B -> C` | +| heap-plat 管理 shm 生命周期 | `newtensor` / `deltensor` 就像 `malloc` / `free` | +| op-plat 被动消费 GPU 指令 | `mul` / `add` 不感知 Metal / CUDA | +| VM 多 worker 并行调度 | `def` 函数定义,调用即 `CALL` 翻译 | + +这就是 dxlang 的核心设计意图:**为 AI 任务服务的声明式语言**——数据流用箭头 (`->`) 表达,计算交给平台代理,使用者只需关心"什么算子、什么数据、什么顺序",无需关心 GPU 型号、内存布局、进程间通信。 + +与 C/CUDA 的本质差异: + +``` +C/CUDA: 程序员管理一切 ── 内存分配、设备选择、数据传输、kernel launch +dxlang: 程序员只写数据流 ── 底座自动调度、分配、执行、回收 +``` + +dxlang 不接触物理地址,不管理内存分配,不感知设备拓扑——这些全部由 heap-plat 和 op-plat 的常驻进程代理。 + +## 1. 类型系统 + +### 基础数据类型 +``` +type f16, f32, f64, bf16, bf8 +type i8, i16, i32, i64, u8 +type bool +type string +type tensor +``` + +### 类型约束 +``` +f32|f64 +``` + +### Tensor 类型模板 +``` +type tensor +``` +- shape 格式:dim1xdim2x...xdimN,或使用 `?` 表示动态维度 +- 示例:`tensor<10x20xf32>`, `tensor` + +tensor 也可以没有 shape 和 dtype 约束: +``` +func addscalar(A:tensor, b:i8|i16|i32|i64) -> (c:tensor) { ... 
} +``` + +### 动态维度变量 +- `?` 任意数字 +- `?1` 动态维度变量 1 +- `?2` 动态维度变量 2 +- 示例:`tensor` + + +### 数组类型 +``` +type[] +``` +list 可以与基础类型与 tensor 组合 +## 2.控制流 + +> **v1 整合设计**:[spec-control-flow-v1.md](spec-control-flow-v1.md) — 6 套重构方案全光谱对比 (渐进 → SSA → 编译) + SSA vs `->`/`<-` 语义对比。 +> 详细方案见 [control-flow.md](./control-flow.md)、[frontend-control-flow.md](./frontend-control-flow.md)。 +> 编译器分析:[compiler-analysis-ssa-vs-arrow.md](./compiler-analysis-ssa-vs-arrow.md) — `->`/`<-` 能否替代 SSA 做编译分析 (§1-8) + 变量版本号方案 (§9) + **`resolve` 创新方案填补 φ 缺口 (§10)**。 + +dxlang 支持分支与循环,控制流以“语义块”表达,执行时由解释器按块索引跳转。 + +### 分支 +``` +if (cond:bool) { + op(a)-> b + op2(b)->c +} else { + op3(a)->b + op4(a)->d +} +``` + +### 循环(迭代器) +``` +var list <-[1,2,3,4,5,6,7] +for (i:i32 in range ) { + op(i, a)-> b +} +``` + +## 3.函数语法 + +### 函数定义 +``` +func ir_name(ro_p1:type1, ro_p2:type2, ...) -> (w_p1:type3, w_p2:type4, ...) +{ + operation_name(ro_p1, ro_p2) -> w_p1 + operation_name(ro_p2, ro_p2) -> w_p2 +} +``` + +### 只读/写入参数标记 +dxlang 支持`<-` 与 `->`(或者`<=` 与 `=>`),用于显式区分只读与写入参数,箭头指向写入参数列表; + +示例(标准格式): +``` +func gemm(A:tensor, B:tensor, alpha:f32, beta:f32, C:tensor) -> (Y:tensor) { + matmul(A, B) -> Y + mul(Y, alpha) -> Y + mul(C, beta) -> C + add(Y, C) -> Y +} +``` + +示例: +``` +func ffn(A:tensor, W1:tensor, b1:tensor, W2:tensor, b2:tensor) -> (Y:tensor) { + matmul(A, W1) => Y + add(Y, b1) => Y + gelu(Y) => Y + matmul(Y, W2) => Y + add(Y, b2) => Y +} +``` + +### 创建变量 +``` +var a<=false +var b,c<=1,2 +``` +- `var` 定义新对象 +- 自动类型推断:`b,c` 推断为 `int` + +## 4. 
KV空间与执行模型 + +### KV空间组织 + +dxlang采用kv地址系统,从而实现跨节点的统一地址,而非传统的单机进程内存模型 + + +简单的串行函数与语句可由 python-sdk 直接写入 KV(如 Redis),示例结构: +``` +/func/gemm = (A:tensor, B:tensor, alpha:f32, beta:f32, C:tensor) -> (Y:tensor) +/func/gemm/0 = matmul(A, B) -> Y +/func/gemm/1 = mul(Y, alpha) -> Y +/func/gemm/2 = mul(C, beta) -> C +/func/gemm/3 = add(Y, C) -> Y +``` + +控制流 if: +``` +.../0 = v +.../1 = if(cond) +.../1/true/0 = add(A, B) -> Y +.../1/false/0 = sub(A, B) -> Y +``` + + +控制流for + + +### 函数体的复杂控制流索引 +对于包含控制流的函数体 + + + + +## 5. pysdk 代码生成 + +> 详细设计方案见 [frontend-control-flow.md](./frontend-control-flow.md) —— 当前 front/py 架构分析 + 三种方案对比 + Hybrid Eager+Defer 推荐方案 + 实现路线图。 + +核心结论:保留现有即刻发射模式不变;通过 `@deepxir.compile` 装饰器新增编译模式,支持 `if`/`for`/`while` 动态控制流的 deepxir 代码生成。 + +## 6. 设计思考 +dxlang 采用简洁文本格式表达类型约束、运算定义与运算体,便于阅读与解析。 +dxlang 不是 SSA,调用时遵循一侧读、另一侧写的规则,参数列表支持多个。 +dxlang 作为调度语义与协议载体,不负责算子实现与存储生命周期。 + +## 7. 具体示例 + +### 示例 1:融合 Linear + 归一化 +``` +func fused_linear_norm( + A: tensor, + W: tensor, + b: tensor, + axis: i32, + keepdims: bool +) -> (out: tensor) { + newtensor(?1x?3, f32)->(mm) + matmul(A, W)-> (mm) + newtensor(?1x?3, f32)-> bias + add(mm, b)-> bias + deltensor(mm)-> mm + newtensor(?1, f32)-> mean + sum(bias, axis, keepdims)-> mean + newtensor(?1x?3, f32)-> centered + sub(bias, mean)-> centered + deltensor(bias)-> bias + deltensor(mean)-> mean + newtensor(?1x?3, f32)-> sq + mul(centered, centered)-> sq + deltensor(centered)-> centered + newtensor(?1, f32)-> var + sum(sq, axis, keepdims)-> var + deltensor(sq)-> sq + constant(1e-5)-> eps + newtensor(?1, f32)-> var_eps + add(var, eps)-> var_eps + deltensor(var)-> var + deltensor(eps)-> eps + newtensor(?1, f32)-> std + sqrt(var_eps)-> std + deltensor(var_eps)-> var_eps + div(std, std)-> std + deltensor(std)-> std + div(centered, std)-> out +} +``` + +``` +func example_use_fused_linear_norm() -> (out: tensor<2x3xf32>) { + newtensor([2,4], f32)-> A + newtensor([4,3], f32)-> W + newtensor([3], f32)-> b + fused_linear_norm(A, W, b, 
1, false) -> out +} +``` + +### 示例 2:融合 Attention score + Softmax +``` +func fused_attention_scores( + Q: tensor, + K: tensor, + axis: list, + keepdims: bool, + shape_scores: list, + shape_sum: list +) -> (out: tensor) { + newtensor(shape_scores, f32)-> scores_tmp + matmul(Q, K)-> scores_tmp + newtensor(shape_scores, f32)-> exp_tmp + exp(scores_tmp)-> exp_tmp + deltensor(scores_tmp)-> scores_tmp + newtensor(shape_sum, f32)-> sum_tmp + sum(exp_tmp, axis, keepdims)-> sum_tmp + div(exp_tmp, sum_tmp)-> out + deltensor(exp_tmp)-> exp_tmp + deltensor(sum_tmp)-> sum_tmp +} +``` \ No newline at end of file diff --git a/doc/dxlang/compiler-analysis-ssa-vs-arrow.md b/doc/dxlang/compiler-analysis-ssa-vs-arrow.md new file mode 100644 index 00000000..24f04cf8 --- /dev/null +++ b/doc/dxlang/compiler-analysis-ssa-vs-arrow.md @@ -0,0 +1,434 @@ +# `->`/`<-` 与 SSA 编译器分析能力对比 + +> **核心问题**:deepxir 的 `->`/`<-` 读写分离模型,能否在不引入完整 SSA 的前提下,支持经典编译器的全部分析和优化 pass? +> **结论**:是。通过**变量版本号编码控制流合流**,以零语法增量获得等同 SSA 的完整分析能力。 + +--- + +## 1. 问题定义 + +``` +给定一个 ->/<- 形式的 deepxir 函数: + + fn example(flag: bool, a: f32, b: f32) -> y: f32 { + a + b -> t + if flag { + t * 2.0 -> x + } else { + t + 3.0 -> x + } + x + 1.0 -> y # 写入输出参数 y 即隐式返回 + } + +目标: 在不转换为完整 SSA (不引入 block argument、φ 节点、CFG 基本块) 的前提下: + Q1: 死代码消除 — 跨分支的未使用值能否识别并删除? + Q2: 公共子表达式消除 — 相同表达式能否跨分支识别? + Q3: 常量传播 — 条件常量能否跨分支传播? + Q4: 全局值编号 — 跨分支的值能否分配唯一编号? + Q5: 循环不变量外提 — 循环内不变计算能否移到循环外? + Q6: 寄存器分配 — 能否构建精确的干涉图? +``` + +## 2. 
SSA 与 `->`/`<-` 的本质差异 + +``` +┌────────────────────────────────────────────────────────────┐ +│ │ +│ SSA = 值身份模型 (Value Identity) │ +│ 每个变量有唯一静态定义点 │ +│ 追踪"哪个值"被使用 │ +│ use-def 链: 从使用点 O(1) 回溯到唯一定义点 │ +│ 控制流合流: block argument / φ 节点 │ +│ 典型系统: LLVM IR, MLIR, GCC GIMPLE │ +│ │ +│ ->/<- = 存储效应模型 (Storage Effect) │ +│ 变量是可复用的存储槽位 │ +│ 追踪"哪个存储位置"被读写 │ +│ 数据流: reads[] / writes[] 数组显式标注 │ +│ 控制流合流: slot 复用 (同名字覆盖) │ +│ 典型系统: deepxir │ +│ │ +│ 核心分歧: SSA 关心"值的身份",->/<- 关心"存储的效应"。 │ +│ 两者服务不同层次,不是替代关系,是互补关系。 │ +│ │ +└────────────────────────────────────────────────────────────┘ +``` + +| 维度 | SSA | `->`/`<-` | +|------|-----|----------| +| **变量赋值次数** | 严格 1 次 | 无限次(可变) | +| **值标识** | 虚拟寄存器编号 (`%0`, `%1`) | KV 路径 (`./x`, `/data/W`) | +| **数据流表达** | 操作数引用 = 隐式 use-def 链 | 显式 reads[] / writes[] 数组 | +| **控制流合流** | block argument / φ 节点 | slot 复用 (同名覆盖) | +| **副作用建模** | 需隔离到 dialect (memref) | 天然 (tensor.new/del 即指令) | +| **定义-使用关系** | 支配树保证: 定义支配所有使用 | 需额外分析 (到达定义) | + +### 线性段的等价性 + +在无控制流分支的线性代码中,两者严格等价。每个 SSA 虚拟寄存器 `%i` 与首次出现在写位置的变量名构成双射映射。 + +### 控制流合流 — 两种模型的分岔点 + +``` +场景: if/else 分支后使用结果 + +SSA (MLIR 风格): ->/<- (slot 复用): + br flag ? @then : @else if flag { + a + 1.0 -> x +@then: } else { + %t = add %a, 1.0 a - 1.0 -> x + br @merge(%t) } + x * 2.0 -> y +@else: + %e = sub %a, 1.0 + br @merge(%e) + +@merge(%y: f32): ← block arg + %r = mul %y, 2.0 + ret %r + +机制: block argument 参数化 机制: slot x 被两个分支复用 +优势: SSA 贯穿始终 优势: 零额外语法,VM 直接读 slot +代价: 需 CFG + block 分割 代价: 失去值身份,编译器分析困难 +``` + +**核心矛盾**: SSA 为编译器分析而生,但要求 CFG + block 结构。 +`->`/`<-` 为执行透明而生,但丢失值身份信息。 +**版本号方案同时在两种模型上取长补短。** + +## 3. 方案:变量版本号编码合流 + +### 3.1 版本号格式 + +``` +版本号以 @ 与变量名分隔(区别于 KV 路径的 / 分隔符)。 +@ 之后用 / 分层,每级使用语义单词标记控制流来源, +合流点以 merge 显式标注。 + +格式: + slot "@" base ("/" branch)* + + base = INT # 根版本号 (0, 1, 2...) + branch = "then" | "else" # if/else 分支来源 + | "loop" | "body" | "step" # for/while 循环来源 + | "merge" # 合流点 (隐式 φ) +``` + +### 3.2 设计原则 + +``` +规则: + +1. 每个 slot 维护独立的 base 版本计数器,从 0 开始 +2. 
线性写入 (-> 右侧): 产生新 base 版本 slot@N
+   读取 (-> 左侧): 使用当前可见版本
+3. 进入控制流分支: 在当前版本上追加分支标签 (/then /else /loop 等)
+   嵌套: slot@0/then/then (outer then → inner then)
+4. 控制流合流: 自动产生合并版本 slot@N/merge = φ(分支来源...)
+   合流后 base 计数器 +1 (合流结果视为新 base)
+
+解析规则:
+  最后一个 @ 之前 = 变量 KV 路径
+  @ 之后 = 版本路径 (/ 分层)
+
+示例:
+  /models/W@0/then        ← 堆变量 W 的 base 0 → then 分支
+  /vthread/vt1/x@0/then   ← 栈变量 x 的 base 0 → then 分支
+  ./tmp@0/then/else       ← 局部变量 tmp 的 base 0 → then → else
+```
+
+**VM 行为**: VM 忽略 `@` 及之后全部内容。只看 slot 名,slot 复用语义不变。
+
+### 3.3 语法示例
+
+```
+线性代码:
+  constant(3.0) -> a@0        # a 首次写入 → base 0
+  a@0 + 2.0 -> b@0            # 读 a@0, 写 b@0
+  b@0 * 4.0 -> c@0            # 读 b@0, 写 c@0
+
+带分支:
+  a@0 + 1.0 -> x@0            # 合流前写入
+  if flag {
+    x@0 * 2.0 -> x@0/then     # "版本 0 的 then 结果"
+  } else {
+    x@0 * 3.0 -> x@0/else     # "版本 0 的 else 结果"
+  }
+  # 合流: x@0/merge = φ(x@0/then, x@0/else)
+  x@0/merge + 4.0 -> y@0      # 读合流结果 → 隐式返回
+
+嵌套分支:
+  a@0 + 1.0 -> x@0
+  if outer {
+    if inner {
+      x@0 * 2.0 -> x@0/then/then
+    } else {
+      x@0 * 3.0 -> x@0/then/else
+    }
+    # inner merge: x@0/then/merge = φ(x@0/then/then, x@0/then/else)
+    x@0/then/merge + 1.0 -> x@0/then/body
+  } else {
+    x@0 * 4.0 -> x@0/else
+  }
+  # outer merge: x@0/merge = φ(x@0/then/body, x@0/else)
+  x@0/merge / 2.0 -> y@0
+
+循环:
+  0.0 -> acc@0
+  for i in 0..100 {
+    # 循环头合流: acc@0/loop = φ(acc@0, acc@0/loop/body)
+    data[i] + acc@0/loop -> acc@0/loop/body
+  }
+  # 循环退出: acc@0/merge = φ(acc@0, acc@0/loop/body)
+  acc@0/merge -> total@0
+```
+
+### 3.4 编译器合流识别
+
+```
+merge 标签使合流识别退化为前缀匹配:
+
+判定规则:
+  1. /merge 后缀 → 直接识别为合流版本
+  2. 合流来源 = 去掉末尾 /merge 后, 以前缀匹配的所有分支版本
+     例: x@0/then/merge → 前缀 x@0/then/ → 来源 = x@0/then/then, x@0/then/else
+     例: x@0/merge → 前缀 x@0/ → 来源 = x@0/then, x@0/else, x@0/then/body
+  3. 
若某分支无写入 → 来源为 base 版本 + +算法: + + func resolve_merge(version_str): + slot, ver_path = parse(version_str) // "x@0/then/merge" → slot="x", path=[0,then,merge] + if not ver_path.ends_with("merge"): + return // 非合流版本 + + // 提取前缀: "x@0/then/merge" → "x@0/then/" + prefix = version_str.strip_suffix("/merge") + "/" + + // 收集所有匹配前缀的版本 → φ 来源集合 + sources = find_all_versions_with_prefix(prefix) + + // 补全: 某分支无写入 → 来源 = base 版本 + for branch in get_branches_from_cfg(): + if no version matches prefix + branch: + sources.append(base_version(slot, prefix)) + + return create_phi(version_str, sources) +``` + +## 4. 编译器分析能力验证 + +### 4.1 死代码消除 (DCE) + +``` +a@0 + 1.0 -> x@0 +if flag { + x@0 * 2.0 -> x@0/then + b@0 * 3.0 -> dead@0/then ← 仅在 then 分支被定义 +} else { + x@0 * 4.0 -> x@0/else +} +# 合流: x@0/merge = φ(x@0/then, x@0/else) +x@0/merge + 3.0 -> y@0 + +分析: dead@0/then 从未被读 → 死代码 → 删除 ✅ +``` + +### 4.2 公共子表达式消除 (CSE) + +``` +if flag { + a@0 + b@0 -> x@0/then ← add(a@0, b@0) +} else { + a@0 + b@0 -> x@0/else ← 相同操作数和操作码! +} +# 合流: x@0/merge = φ(x@0/then, x@0/else) +x@0/merge * 2.0 -> y@0 + +分析: x@0/then 和 x@0/else 的定义完全一致 → 同一值 + φ 退化为单一来源 → 消除冗余定义 ✅ +``` + +### 4.3 全局值编号 (GVN) + +``` +a@0 + b@0 -> t@0 +c@0 + d@0 -> u@0 +if flag { + a@0 + b@0 -> x@0/then ← VN = 42 (同 t@0) +} else { + c@0 + d@0 -> x@0/else ← VN = 17 (同 u@0) +} +# 合流: x@0/merge = φ(x@0/then, x@0/else) +x@0/merge + 1.0 -> y@0 + +分析: t@0=42, u@0=17, x@0/then=42, x@0/else=17 + x@0/merge = φ(42, 17) → VN = hash(φ, 42, 17) ✅ +``` + +### 4.4 稀疏条件常量传播 (SCCP) + +``` +constant(true) -> flag@0 +if flag@0 { + constant(5.0) -> x@0/then +} else { + constant(3.0) -> x@0/else ← 不可达! 
flag=true +} +# 合流: x@0/merge = φ(x@0/then, x@0/else) +x@0/merge + 1.0 -> y@0 + +分析: flag@0=true → then 可达到, else 不可达 + x@0/then=5.0, x@0/else=⊥ + x@0/merge = φ(5.0, ⊥) = 5.0 → y@0 = 6.0 ✅ +``` + +### 4.5 循环不变量外提 (LICM) + +``` +a@0 * b@0 -> inv@0 ← 定义在循环外 +0.0 -> acc@0 +for i in 0..n { + # 循环头合流: acc@0/loop = φ(acc@0, acc@0/loop/body) + inv@0 + acc@0/loop -> acc@0/loop/body +} + +分析: inv@0 定义在循环外, 循环内无新版本 → 循环不变量 ✅ +``` + +### 4.6 寄存器分配 / 干涉图 + +``` +a@0 + b@0 -> t1@0 +t1@0 * 2 -> t2@0 +if flag { + t1@0 + 5 -> t3@0/then +} else { + t2@0 + 3 -> t3@0/else +} +# 合流: t3@0/merge = φ(t3@0/then, t3@0/else) +t3@0/merge * t1@0 -> z@0 + +分析: 每个版本有唯一活性区间 + t3@0/then 与 t3@0/else 不干涉 (不同分支) + t1@0 与 t3@0/merge 干涉 → 图着色精确 ✅ +``` + +## 5. 与纯 SSA 的架构对比 + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ │ +│ 纯 SSA (MLIR cf dialect): │ +│ │ +│ @0: │ +│ %0 = add %a, %b │ +│ br %flag ? @1 : @2 ← 基本块分割 │ +│ @1: │ +│ %1 = add %a, 2.0 │ +│ br @3(%1) ← block argument 传递 │ +│ @2: │ +│ %2 = sub %a, 2.0 │ +│ br @3(%2) ← block argument 传递 │ +│ @3(%x: f32): ← %x = φ(%1, %2) │ +│ %3 = mul %x, 3.0 │ +│ ret %3 │ +│ │ +│ 版本号编码合流 (deepxir): │ +│ │ +│ fn example(flag: bool, a: f32) -> y: f32 { │ +│ a@0 + b@0 -> x@0 │ +│ if flag { │ +│ a@0 + 2.0 -> x@0/then ← base 0 → then 分支 │ +│ } else { │ +│ a@0 - 2.0 -> x@0/else ← base 0 → else 分支 │ +│ } │ +│ x@0/merge * 3.0 -> y@0 ← x@0/merge = φ(x@0/then, x@0/else) │ +│ } ↑ 写入 y@0 = 隐式返回 │ +│ │ +└─────────────────────────────────────────────────────────────────┘ + +┌────────────────────────┬──────────────────────┬──────────────────────┐ +│ 维度 │ 纯 SSA │ 版本号编码合流 │ +├────────────────────────┼──────────────────────┼──────────────────────┤ +│ IR 结构变化 │ 函数→基本块→指令 │ 函数→指令(一层) │ +│ 合流机制 │ block argument │ 版本号隐式编码 │ +│ 前端生成复杂度 │ 高 (CFG 构造 + φ) │ 低 (版本计数器) │ +│ VM 执行 │ 需 block-loop │ slot 复用,零变化 │ +│ 编译分析能力 │ 原生 SSA │ 内部 SSA (自动构建) │ +│ 向后兼容 │ ❌ 不兼容现有 IR │ ✅ 不加版本号照常运行 │ +│ 人可读性 │ 编号 %0 %1 无意义 │ 名字有意义 + 版本可溯源│ +│ 与执行协议对齐 │ 需 lower 到 slot │ 直接 = 
slot 格式 │
+└────────────────────────┴──────────────────────┴──────────────────────┘
+```
+
+## 6. 实现路线
+
+```
+Phase 1: 前端版本号生成
+  ┌─────────────────────────────────────────────┐
+  │ 每个 slot 维护: slot → current_base         │
+  │ 线性写入: 输出 "slot@N"                     │
+  │ 进入分支: 追加分支标签 "slot@N/then" 等     │
+  │ 嵌套分支: 逐级追加 "/then" "/else"          │
+  │ 合流点: 自动产生 "slot@N/merge"             │
+  │                                             │
+  │ 新增代码量: ~100-200 行                     │
+  │ VM: 0 行 (忽略 @ 及之后全部内容)            │
+  └─────────────────────────────────────────────┘
+
+Phase 2: 编译器合流识别
+  ┌─────────────────────────────────────────────┐
+  │ 解析 slot@path 格式 (@ 分隔 slot 与版本)    │
+  │ /merge 后缀直接识别为合流版本               │
+  │ 前缀匹配查找 φ 来源集合 (无需回溯 CFG)      │
+  │                                             │
+  │ 新增代码量: ~150-250 行                     │
+  └─────────────────────────────────────────────┘
+
+Phase 3: 优化 Pass 全线启用
+  ┌─────────────────────────────────────────────┐
+  │ 内部 SSA 已就绪                             │
+  │ 复用标准优化 pass (DCE, CSE, GVN, SCCP...)  │
+  │ 优化结果映射回 slot → 输出优化后 IR         │
+  └─────────────────────────────────────────────┘
+```
+
+## 7. 结论
+
+```
+┌──────────────────────────────────────────────────────────────┐
+│                                                              │
+│ Q: ->/<- 能否替代 SSA 做编译分析的全部需求? 
│ +│ │ +│ A: 能。通过变量版本号编码控制流合流。 │ +│ │ +│ 纯 slot 复用: │ +│ ❌ 缺少值身份 → 全局分析不可行 │ +│ ✅ 执行层高效 (VM 直接读写 slot) │ +│ │ +│ 版本号编码合流: │ +│ ✅ 线性段: slot@ver 直接提供 use-def 链, O(1) │ +│ ✅ 控制流合流: /merge 触发自动 φ 构建 │ +│ ✅ 全部 pass: DCE, CSE, GVN, SCCP, LICM, 寄存器分配 │ +│ ✅ 零新语法: 版本号即后缀, 无关键字, 无 block arg │ +│ ✅ VM 零变化: slot 复用保持不变, @ 之后全部忽略 │ +│ ✅ 向后兼容: 不加版本号照常运行 (无优化) │ +│ │ +│ 这不是"用 ->/<- 替代 SSA"。 │ +│ 这是"把 SSA 的编译分析能力折叠进 ->/<- 的版本号系统中"。 │ +│ │ +│ 版本号 = SSA 的值身份 │ +│ slot = 存储实体 │ +│ /merge = φ 的显式编码 │ +│ 结构化控制流 = CFG 的确定性描述 │ +│ │ +└──────────────────────────────────────────────────────────────┘ +``` + +--- + +> **关联文档**: +> - [spec-control-flow-v1.md](spec-control-flow-v1.md) — `->`/`<-` 与 SSA 语义层面对比 + C1~C5 方案架构 +> - [control-flow.md](control-flow.md) — 控制流 IR 设计(基本块模型) +> - [frontend-control-flow.md](frontend-control-flow.md) — 前端代码生成方案 diff --git a/doc/dxlang/control-flow.md b/doc/dxlang/control-flow.md new file mode 100644 index 00000000..214402ac --- /dev/null +++ b/doc/dxlang/control-flow.md @@ -0,0 +1,395 @@ +# deepxir 控制流设计 + +> **方案分析**:[spec-control-flow-v1.md](spec-control-flow-v1.md) — 5 种 C 子方案架构对比。 +> 本文提供 MLIR 对比 + 基本块模型 + C1 关键字 IR 具体设计。 + +## 1. 
四层对比:Assembly IR、deepxir、MLIR、C 语言 + +> **MLIR** (Multi-Level Intermediate Representation) 与 deepxir 的设计目标高度重合: +> 都面向深度学习计算场景,都需要在计算图和指令级之间提供可组合的中间表示。 +> deepxir 可以视为 MLIR 在 **KV 存储原生、分布式调度、多后端异构** 方向上的一个垂直领域方言实现。 + +### 1.1 函数抽象 + +| 维度 | Assembly IR | deepxir | MLIR | C 语言 | +|------|-----------|---------|------|--------| +| 函数单元 | label + 指令序列 | `def name(params) -> (rets) { body }` | `func.func @name(%arg: T) -> T { body }` | `ret_type name(params) { body }` | +| 调用约定 | 手动 push/pop 寄存器、手动跳转 | CALL 指令 + VM 自动管理子栈帧 | `func.call @callee(%args)` + 显式 SSA 结果 | 函数调用表达式,编译器管理栈帧 | +| 参数传递 | 寄存器 / 栈偏移 | Redis KV 路径绑定(形参→实参) | Block arguments (SSA 值), 支持 memref 传递 | 按值 / 按引用,编译器分配 | +| 返回值 | rax/eax 寄存器 或 栈 | 隐式 RETURN 将输出形参值回传父栈 | `func.return %val : T` 显式返回 SSA 值 | `return` 表达式 | +| 栈帧 | push rbp; sub rsp, N | `/vthread///` 子键空间 | Region + block hierarchy, alloc 可下沉 | 连续栈内存 | +| 类型系统 | 无 (原始字节) | 签名标注型别,运行时类型感知求值 | 完整类型体系 (tensor/memref/vector/...) | 编译期静态类型 | + +**deepxir 的定位**:在汇编之上抽象了 KV 空间——函数帧是 Redis 子树,参数通过路径绑定传递。与 MLIR 共享"深度学习 IR"的目标,但 MLIR 侧重编译优化(SSA、dialect 混用、pass pipeline),deepxir 侧重运行时调度(Redis 原生、多后端路由、vthread 并发)。比 C 少一层编译复杂度,比 MLIR 少一层类型和 SSA 的抽象约束。 + +### 1.2 控制流原语 + +| 原语 | Assembly IR | deepxir (当前) | deepxir (目标) | MLIR | C 语言 | +|------|-----------|--------------|--------------|------|--------| +| 顺序执行 | PC++ | `[i,0] → [i+1,0]` | 同左 | block 内顺序 op | `;` 分隔 | +| 无条件跳转 | `jmp label` | — (缺失) | `jump ` | `cf.br ^block` | `goto label` | +| 条件分支 | `cmp; je/jne label` | `if cond → true/false 子树` (部分) | `if (cond) block1 else block2` | `scf.if` / `cf.cond_br ^t, ^f` | `if/else` | +| 循环 | `jmp` 回跳 | — (缺失) | `for` / `while` 展开为基本块 | `scf.for` / `scf.while` / `affine.for` | `for` / `while` | +| 函数调用 | `call func` | `CALL func → eager inline 翻译` | 同左 | `func.call @callee(%args)` | `func(args)` | +| 函数返回 | `ret` | 隐式 RETURN + 子栈清除 | 同左 | `func.return %val` | `return expr` | +| 多路分支 | `jmp [table]` | — | `switch` → 跳转表 | `cf.switch` / `scf.index_switch` | 
`switch/case` | +| 结构化抽象 | — | — | if/for 语法糖翻译为基本块 | `scf.` dialect (高层) ↔ `cf.` dialect (底层) | 原生结构化 | + +### 1.3 状态管理 + +| 维度 | Assembly IR | deepxir | MLIR | C 语言 | +|------|-----------|---------|------|--------| +| 变量存储 | 寄存器 / 内存地址 | Redis key (`/vthread//`) | SSA Value (虚拟寄存器编号) → 可下沉到 memref | 栈 / 堆 内存 | +| 作用域 | 全局(label 可见) | vthread 子树内全局 | Block/Region 隔离, SSA 支配规则 | 词法作用域 `{}` | +| 活跃分析 | 程序员 / 编译器负责 | VM 不追踪(调用方清理子栈) | SSA use-def chain 自动推导, 编译器负责 | 编译器 RA + 析构 | +| 并发安全 | 原子指令 + 内存屏障 | Redis WATCH/MULTI/EXEC (picker) | 显式 async/await dialect, gpu.async 等 | mutex / atomic | +| 数据流 | 隐式 (side effect) | 显式 (reads/writes 路径数组) | 显式 SSA (操作数 ↔ 结果 编号链接) | 隐式 + 指针别名 | + +**MLIR 的 SSA 模型**是其核心设计:每个值有唯一编号(`%0`, `%1`),use-def 链天然形成数据流图。deepxir 的数据流则通过 Redis key 路径(`reads`/`writes` 数组)显式编码,适合分布式场景,但缺少编译期 use-def 验证的严格性。 + +### 1.4 MLIR 核心创新对 deepxir 的启发 + +MLIR 的三个核心设计对 deepxir 有直接参考价值: + +#### Dialect 方言体系 vs deepxir 算子分类 + +``` +MLIR: deepxir (类比): + + func dialect (函数定义) def ... -> () { } ← 函数层 + scf dialect (结构化控制流) if / for / while ← 控制流层 (目标) + cf dialect (底层控制流) br / jump / switch ← CFG 层 (目标) + linalg dialect (线性代数) matmul, add, conv ← 计算层 + gpu dialect (GPU 抽象) op-metal / op-cuda ← 后端调度 (隐性) + memref dialect (内存抽象) newtensor/deltensor ← 生命周期层 + arith dialect (算术) + - * / % + 原生算子 ← 求值层 +``` + +MLIR 的 **dialect 混用** 允许在同一函数中混合不同抽象层的操作(如在 `scf.for` 循环内调用 `linalg.matmul`),deepxir 当前已经隐式实现了类似能力——生命周期 (`newtensor`)、计算 (`add`)、控制 (`if`) 可以在同一函数体中混用。**但 deepxir 缺少 dialect 的显式声明和校验机制**。 + +#### Region / Block 层级模型 + +``` +MLIR: deepxir (目标): + + func @main { def main(...) -> (...) { + ^bb0(%arg0: f32): @0: ← entry block + %0 = arith.addf %arg0, %cst newtensor(...) -> /data + cf.br ^bb1(%0) br ... → @1/@2 + + ^bb1(%1: f32): @1: ← loop header + scf.for %i = %lb to %ub { ... (block body) + %2 = linalg.matmul ... + ... 
+ } + } } +``` + +MLIR 的 **Region** 是其关键抽象——`scf.for` 的循环体是一个嵌套的 Region,与父函数的 value 作用域隔离。deepxir 当前用 **PC 路径嵌套** (`/vthread///`) 实现了类似隔离,但控制流跳转(jump)跨越嵌套时需要更明确的 frame/region 边界语义。 + +#### Pass Pipeline 与 deepxir 的翻译阶段 + +``` +MLIR 编译流程: deepxir 当前 + 目标流程: + + 前端 (C++/PyTorch) → pysdk 写入 /src/func/ (dxlang) + │ │ + dialect lowering → ParseDxlang → ir.Instruction (语法→结构) + │ │ + canonicalization → eager inline translate (形参绑定, 坐标化) + │ │ + loop/affine optimization → 基本块合并 / 死代码消除 (Phase 3 目标) + │ │ + gpu mapping → route.Select(backend) → op-plat + │ │ + LLVM lowering → heap-plat 生命周期指令 + │ │ + 机器码 → GPU kernel 执行 +``` + +deepxir 的当前流程已经是一条隐式的 "pass pipeline",但缺少 MLIR 的 **可组合性** 和 **可验证性**:每个 pass 的输入输出格式没有形式化定义,变换的正确性依赖人工保证而非结构化约束。 + +--- + +## 2. deepxir 当前控制流状态 + +### 2.1 已实现 + +``` +控制流 opcode: call, return, if + +call → translate.HandleCall() + - eager inline: 一次性将编译层 dxlang 翻译为执行层 [i,j] 坐标 + - 形参/实参绑定: replaceParams(parsed.Reads/Writes, bindings) + - 子栈根: /vthread/// + - 隐式 return 指令: 追加在子栈末尾 + +return → translate.HandleReturn() + - 返回值回传: 读取 retRef, 写入父栈 retSlot + - 子栈清除: KEYS + DEL + - PC 恢复: NextPC(parentPC) + +if → dispatch.If() + - 条件求值: isTruthy(condVal) → true / false + - 分支 PC: pc+"/true/0" 或 pc+"/false/0" + - 无合并点: 分支末尾直接进入 nextPC 或 done +``` + +### 2.2 缺失 + +| 缺失项 | 影响 | +|------|------| +| **无条件跳转 (`jump`)** | 无法实现循环回边、无法实现 goto-like 控制流 | +| **循环 (`for`/`while`)** | 循环只能通过前端展开为线性指令(如文档示例中的 python-sdk 写入) | +| **分支合并 (join/phi)** | if/else 无合并基本块,两个分支各自独立结束 | +| **结构化块 (block)** | 控制流没有明确的"基本块"边界,靠子树路径区分 | +| **switch/case** | 多路分支无支持 | + +--- + +## 3. 
设计方案:基本块模型(C 方案通用基础) + +> C1~C5 五种方案共享核心的基本块模型,区别在于 lowering 位置和 IR 层数。 +> 详见 [spec-control-flow-v1.md](spec-control-flow-v1.md)。 + +### 3.1 C 方案核心思想 + +> **控制流从"用户意图"到"机器执行"是一个语义等价的格式变换过程。** + +不同 C 子方案对"几层 IR"和"lowering 在哪里"有不同选择: + +| 子方案 | IR 层数 | Lowering 位置 | +|--------|--------|-------------| +| C1 关键字 | 1 层 | 无(VM 直接解释) | +| C2 二层+Scheduler | 2 层 | Scheduler 服务 | +| C3 单层基本块 | 1 层 | 前端负责 | +| C4 Region | 1 层 | VM 原生执行 region | +| C5 关键字+VM内 | 1 层(外)/2 层(内) | VM 加载时内部 lowering | + +**C1 和 C2 是互补的**:C1 可作为 C2 的结构化 IR 层格式——先 C1 快速可用,需要全局优化时加 Scheduler 切换到 C2。 + +### 3.2 基本块模型(所有 C 方案的执行层基础) + +每个基本块: +- **入口标签**:唯一的 block id(如 `@0`, `@1`, `@2`) +- **指令序列**:0 条或多条顺序指令 +- **终止指令**:恰好 1 条(`br` 条件跳转 / `jump` 无条件跳转 / `return` / `call`) + +### 3.3 执行层存储格式(Redis) + +``` +/vthread///@0 → "br" # block 0 的终止指令 +/vthread///@0/-1 → "cond" # br 的条件变量 +/vthread///@0/-2 → "@1" # true 目标 +/vthread///@0/-3 → "@2" # false 目标 + +/vthread///@0/0 → "newtensor" # block 0 的指令 +/vthread///@0/0,1 → "/data/a" +... 
+ +/vthread///@1 → "jump" # block 1 的终止指令 +/vthread///@1/-1 → "@3" # 目标 block +``` + +#### 对比当前格式 + +``` +当前: /vthread//[0,0]/[0,0] → "newtensor" (嵌套路径 = 子树) + ^^^^^^^^ + PC 路径嵌套表示子栈 + +方案: /vthread//@0 → block 元数据 + /vthread//@0/0 → block 0 第 0 条指令 + /vthread//@0/0,1 → block 0 第 0 条指令第 1 个 write +``` + +**优点**: +- block id 是平面索引(`@0`, `@1`),不是嵌套路径(`[0,0]/[1,0]`) +- PC 不再需要 `/` 分隔符解析层级——层级由 CALL/RETURN 隐式管理 +- 控制流图更清晰:block 有明确的入边和出边 + +### 3.4 控制流 IR:两种表达方式 + +#### 3.4.1 C1/C5 风格:关键字 IR(结构化) + +控制流以关键字形式嵌入指令流,VM 直接解释(C1)或加载时 lowering(C5)。 + +``` +def example(x) -> (y) { + newtensor("f32", "[4]") -> ./a + sum(./x) -> ./s + + if (greater(./s, 0)) { ← 关键字 + add(./x, 1.0) -> ./y + } else { + mul(./x, -1.0) -> ./y + } + + for (var i = 0; less(i, 10); add(i, 1) -> i) { + add(./y, i) -> ./y + } + + while (greater(./y, 100.0)) { + mul(./y, 0.5) -> ./y + } + + deltensor(./a) + return(./y) +} +``` + +**Redis 存储**(嵌套 key 天然表达控制流层次): + +``` +/vthread//[1,0] = "if" +/vthread//[1,0]/cond = "./cond" +/vthread//[1,0]/then/0 = "add" +/vthread//[1,0]/else/0 = "mul" + +/vthread//[2,0] = "for" +/vthread//[2,0]/init/0 = "var" +/vthread//[2,0]/cond = "less(./i, 10)" +/vthread//[2,0]/step/0 = "add(./i, 1)" +/vthread//[2,0]/body/0 = "add(./y, ./i)" +``` + +#### 3.4.2 C2/C3 风格:基本块 IR(平铺) + +结构化控制流 lowering 为基本块 + 终止指令。 + +``` +@0: + newtensor("f32", "[4]") -> ./a + sum(./x) -> ./s + greater(./s, 0) -> ./cond + br ./cond, @1, @2 + +@1: + add(./x, 1.0) -> ./y + jump @3 + +@2: + mul(./x, -1.0) -> ./y + jump @3 + +@3: + deltensor(./a) + return(./y) +``` + +### 3.5 控制流关键字定义(C1 专用) + +| 关键字 | 语义 | VM 行为 | +|--------|------|--------| +| `if (cond) { ... } [else { ... }]` | 条件分支 | 求值 cond → 进入 then 或 else 子作用域 | +| `for (init; cond; step) { ... }` | 计数循环 | 执行 init → 循环: 求值 cond → 执行 body → 执行 step | +| `while (cond) { ... }` | 条件循环 | 循环: 求值 cond → 执行 body | +| `loop { ... 
}` | 无限循环 | 循环 body,直到内部 break | +| `break` | 跳出最近循环 | 跳出当前循环作用域 | +| `continue` | 跳过本次迭代 | 跳到循环条件判断 | +| `switch (val) { case v1: ... default: ... }` | 多路分支 | 匹配 val 到 case | + +### 3.6 终止指令定义(C2/C3 专用) + +| Opcode | 含义 | Reads | 语义 | +|--------|------|-------|------| +| `br` | 条件分支 | `[cond, true_block, false_block]` | if cond → PC=true_block else PC=false_block | +| `jump` | 无条件跳转 | `[target_block]` | PC = target_block | +| `call` | 函数调用 | `[func_name, args...]` | 创建子栈帧, PC 入子栈 | +| `return` | 函数返回 | `[ret_val]` | 清除当前栈帧, PC 回父栈 | + +### 3.7 VM 执行循环 + +**C1/C4 模式**(VM 直接解释关键字/region): + +```go +func Execute(vtid) { + inst := decode(pc) + switch inst.Opcode { + case "if": + pushScope(vtid, eval(inst.Cond) ? inst.Then : inst.Else) + case "for": + pushLoopScope(vtid, inst.Init, inst.Cond, inst.Step, inst.Body) + case "while": + pushLoopScope(vtid, nil, inst.Cond, nil, inst.Body) + case "break": + popLoopScope(vtid) + case "return": + popFrame(vtid) + default: + dispatch(inst) // op-plat / heap-plat / VM 求值 + } +} +``` + +**C2/C3 模式**(block-loop,最低 VM 复杂度): + +```go +func Execute(vtid) { + block := entryBlock + for block != nil { + for _, inst := range block.Instructions { + dispatch(inst) + } + switch block.Terminator.Opcode { + case "br": + block = eval(terminator.Cond) ? loadBlock(terminator.True) : loadBlock(terminator.False) + case "jump": + block = loadBlock(terminator.Target) + case "return": + popFrame(); block = parent.AfterCall + case "call": + pushFrame(terminator.Func); block = newFrame.Entry + } + } +} +``` + +### 3.8 栈帧模型 + +``` +/vthread/vt1/ +├── frame:0 # 根栈帧 +│ ├── @0 = entry block +│ ├── @1 ... +│ └── call myfunc → frame:1 +│ +├── frame:1 # 子栈帧 (myfunc) +│ ├── @0 = entry block +│ ├── @1 ... +│ └── return → pop +│ +平铺路径 + frame 层次分离 +block id (@0, @1) 是平面索引,不与调用层级耦合 +``` + +--- + +## 4. 实现路线图 + +### Phase 1: C1 关键字 IR(快速可用) + +1. 引入控制流关键字 opcode:`if`, `for`, `while`, `break`, `continue`, `switch` +2. VM 实现嵌套作用域栈(scope stack),直接解释关键字 +3. 
前端 `@compile` 生成关键字 IR +4. Redis 存储同构(嵌套 key) + +### Phase 2: 基本块执行(为 C2/C5 做准备) + +1. 引入 `jump`/`br` 终止指令 +2. VM 实现 block-loop 执行模式(与关键字模式并存) +3. 关键字 → 基本块 lowering(在 VM 内或 Scheduler 内) +4. 前端可选择生成关键字 IR 或基本块 IR + +### Phase 3: Scheduler 服务(C2) + +1. lowering 逻辑从 VM 中抽出,放入独立 Scheduler +2. Scheduler 实现:Region Flattening → 块内 SSA → 优化 pass +3. 多前端(Python/Go)通过 Scheduler 统一 lowering + +### Phase 4: 优化与高级控制流 + +1. 死代码消除、基本块合并、循环不变量外提 +2. 短路求值 (`&&`/`||` 展开为 br 链) +3. 尾调用优化 (tail call → jump) +4. Dialect 命名空间预留 (`op:linalg:`, `op:scf:`) diff --git a/doc/dxlang/frontend-control-flow.md b/doc/dxlang/frontend-control-flow.md new file mode 100644 index 00000000..5998aea1 --- /dev/null +++ b/doc/dxlang/frontend-control-flow.md @@ -0,0 +1,591 @@ +# pysdk 控制流代码生成方案 + +> **方案分析**:[spec-control-flow-v1.md](spec-control-flow-v1.md) — C1~C5 架构对比。 +> 本文分析 `front/py` 实现 + C1 关键字 IR 的具体生成设计。 + +## 1. 当前 front/py 架构分析 + +### 1.1 整体数据流 + +``` +用户 Python 代码 (torch-like API) + │ + ▼ + Tensor.__add__ / __matmul__ / .relu() ... 
+ │ + ▼ + nn.functional.leaffunc_* → create_A_B_tf_C 工厂函数 + │ │ + │ newtensor(shape) 创建输出 Tensor + │ rtf_mod.rtf_op(a, b, out) 发射 IR + │ │ + ▼ ▼ + DeepxIR("add", [a, b], [out]) 序列化 + │ + ▼ + scheduler.send(ir) → UDPConn.sendto("localhost:9090") + │ + ▼ + [外部调度器] → Redis /src/func/ → VM 拾取执行 +``` + +**关键特征:即刻发射 (Eager Emission)** + +每一条 Python 层的张量操作,立即生成一条 `DeepxIR` 字符串并通过 UDP 发送。 +**没有图构建阶段、没有延迟发射、没有函数边界。** + +### 1.2 代码生成核心:rtf 模块 + +`front/py/deepx/nn/functional/rtf.py` — 所有 IR 发射的汇聚点: + +```python +# rtf.py — 5 种发射模板,覆盖全部当前算子 + +def A_B_op_C(op, a, b, out, author): # add(A, B) -> C + ir = DeepxIR(op, [tensor(a), tensor(b)], [tensor(out)]) + send(ir) + +def A_op_C(op, a, out, author): # relu(A) -> C + ir = DeepxIR(op, [tensor(a)], [tensor(out)]) + send(ir) + +def A_scalar_op_C(op, a, b, out, author): # addscalar(A, 2.0) -> C + ir = DeepxIR(op, [tensor(a), varnum(b)], [tensor(out)]) + send(ir) + +def A_B_c_op_D(op, a, b, c, out, author): # equal(A, B, ε) -> D + ir = DeepxIR(op, [tensor(a), tensor(b), varnum(c)], [tensor(out)]) + send(ir) + +def A_b1_b2_op_C(op, a, b1, b2, out, author): # reduce(A, dims, keepdim) -> C + ir = DeepxIR(op, [tensor(a), vector(b1), varbool(b2)], [tensor(out)]) + send(ir) +``` + +每种发射模板生成一条格式为 `opname (arg1, arg2) -> (ret1) // metadata` 的字符串。 + +### 1.3 操作符重载层 + +`front/py/deepx/tensor/elementwise.py` — Tensor 方法定义: + +```python +@tensor_method +def add(self, other, out='') -> Tensor: + from deepx.nn.functional import add as add_func + return add_func(self, other, out) # → leaffunc → rtf → send +``` + +`front/py/deepx/tensor/tensor.py` — Python 操作符重载: + +```python +def __add__(self, other): return self.add(other) +def __matmul__(self, other): return self.matmul(other) +def __mul__(self, other): return self.mul(other) +``` + +### 1.4 模块系统 + +`front/py/deepx/nn/modules/module.py` — Module 基类: + +```python +class Module: + def register_parameter(name, param): + # 注册参数时触发 rnewtensor(param) + from deepx.nn.functional.leaffunc_life import 
rnewtensor + rnewtensor(param) + + def __call__(self, *args): + return self.forward(*args, **kwargs) +``` + +`front/py/deepx/nn/modules/linear.py` — Linear 的 forward: + +```python +def forward(self, input): + y = input @ self.weight.mT # → matmul IR emit + if self.bias is not None: # ← Python if,图构建期求值! + y = y + self.bias # → add IR emit + return y +``` + +**关键发现**:`self.bias is not None` 是 Python 级判断,在**图构建期**(第一次调用 forward)就已经求值。这是"结构化控制流"——取决于模型结构而非运行时数据。 + +### 1.5 参数协议 + +`front/py/deepx/nn/deepxir.py` — `Param` 类型标记: + +| Param 类别 | Python 类型 | DeepxIR 字符串形式 | 用途 | +|-----------|-----------|-------------------|------| +| `tensor:` | Tensor | `tensor:X` | 张量操作数和返回值 | +| `var:` | int/float | `var:42` 或 `var:0.5` | 立即数标量 | +| `var:bool` | bool | `var:true` | 布尔值 | +| `var:string` | str | `var:f32` | 字符串参数 | +| `vector:` | tuple | `vector:[3 4 5]` | 形状/维度参数 | +| `listtensor:` | tuple[Tensor] | `listtensor:[A B]` | 张量列表参数 | + +### 1.6 架构特点总结 + +``` +┌──────────────────────────────────────────────────────┐ +│ 即刻发射 (Eager) │ +│ │ +│ Python op → rtf → DeepxIR → UDP → 调度器 → Redis │ +│ │ +│ ❌ 无计算图缓存 ❌ 无延迟发射 │ +│ ❌ 无函数边界感知 ❌ 无控制流抽象 │ +│ ❌ 无 tracer / JIT ❌ 无 block 概念 │ +└──────────────────────────────────────────────────────┘ +``` + +--- + +## 2. 控制流引入后的核心矛盾 + +### 2.1 矛盾的根源 + +有了 `if`/`for`/`while` 之后,pysdk 面临一个根本问题: + +```python +# 用户想表达:运行时根据 Tensor 值决定执行哪条分支 +def forward(self, x): + cond = x > 0 # ← 这是一个 Tensor,运行时才能求值 + if cond: # ← Python if 在图构建期就求值了! 
+ y = self.branch_a(x) + else: + y = self.branch_b(x) + return y +``` + +**Python 的 `if cond` 要求 `cond` 是 `bool`,但 `cond` 是 `Tensor`,Python 的 `if` 无法根据 Tensor 的运行时值分支。** + +PyTorch 也面临同样问题,解决方案是 `torch.jit.script` / `torch.compile`。 + +### 2.2 两种控制流 + +| | 结构化控制流 | 动态控制流 | +|------|-----------|---------| +| 决定时机 | 图构建期(Python 运行时) | deepxir 执行期(VM 运行时) | +| 依赖数据 | Python 值(`None`/`int`/`bool`) | Tensor 值(在 Redis/gpu 上) | +| 当前支持 | ✅ 天然支持(如 `if self.bias is not None`) | ❌ 不支持 | +| 需要 IR | 不需要(Python 本身处理) | 需要 deepxir `if`/`for`/`while` | +| 示例 | 可选的 bias/activation | 循环直到收敛、动态路由 | + +**设计的重点是支持动态控制流。** + +--- + +## 3. 前端方案:Hybrid Eager+Defer + C1 关键字 IR + +> 前端生成 **C1 关键字 IR**(`if`/`for`/`while` 作为一等 opcode)。 +> 后端可选:VM 直接解释(C1)或 Scheduler lowering(C2)。 + +### 3.1 两种模式的边界 + +```python +# 模式 1:即刻发射(现有,不变) +x = a + b # 立刻生成 add(a,b) -> tmp // send via UDP + +# 模式 2:编译模式 → 生成 C1 关键字 IR +@deepxir.compile +def my_func(x: Tensor) -> Tensor: + if x.sum() > 0: # 控制流 → 编译到关键字 IR + y = self.branch_a(x) + else: + y = self.branch_b(x) + return y +``` + +| 特性 | 即刻发射 (Eager) | 编译模式 (@compile) | +|------|---------|---------| +| 触发方式 | 默认(任何 Tensor op) | 装饰器显式标注 | +| 发射时机 | 每个 op 立即 send | 函数调用时一次性生成 | +| 输出格式 | 单条 DeepxIR → UDP | C1 关键字 IR → Redis `/src/func/` | +| 控制流 | 仅结构化(Python if) | 完整动态(if/for/while/break/continue) | +| 函数签名 | 无 | 有(`def name(params) -> (rets)`) | + +## 4. 
C1 关键字 IR 生成设计 + +### 4.1 新增模块布局 + +``` +front/py/deepx/ +├── nn/ +│ ├── deepxir.py ← 现有 IR 类(保留) +│ ├── compiler.py ← 新增:编译器入口 @compile +│ ├── ir_builder.py ← 新增:IR 构建器 (DeepxFunc, Block) +│ ├── tracer.py ← 新增:操作拦截器 (symbolic tensor) +│ └── control_flow.py ← 新增:控制流捕获 (If/For/While context) +└── scheduler/ + └── client/ + ├── udpconn.py ← 现有 UDP(保留,即刻模式用) + └── redisconn.py ← 新增:Redis 直连(编译模式用) +``` + +### 4.2 编译器工作流程 + +``` +@deepxir.compile +def my_func(x: Tensor, threshold: float) -> Tensor: + cond = x.sum() > threshold ← ① SymbolicTensor 拦截 + if cond: ← ② IfContext 捕获 + y = layer1(x) + else: + y = layer2(x) + return y + +↓ 装饰器编译流程: + +1. 提取签名: "def my_func(x:Tensor, threshold:float) -> (y:Tensor)" +2. 运行函数体 → SymbolicTensor 拦截所有操作 +3. IfContext/WhileContext 捕获控制流为嵌套结构 +4. IRBuilder 收集为 C1 关键字 IR 树 +5. 序列化写入 /src/func/my_func +6. LPUSH notify:vm + +生成的 C1 关键字 IR: +---------------------------------------------- +def my_func(x, threshold) -> (y) { + sum(./x) -> ./s + greater(./s, threshold) -> ./cond + if (./cond) { + call layer1(./x) -> ./y + } else { + call layer2(./x) -> ./y + } + return(./y) +} +---------------------------------------------- +``` + +### 4.3 核心组件设计 + +#### 4.3.1 Symbolic Tensor (tracer.py) + +```python +class SymbolicTensor: + """ + 编译模式下的"占位 Tensor"。 + 不持有实际数据,只记录在 IR 中的符号名。 + """ + def __init__(self, name: str, shape: tuple, dtype: str): + self._name = name # e.g., "./x", "./sum_tmp" + self._shape = shape + self._dtype = dtype + self._ir_builder = get_current_ir_builder() # 线程局部 + + def __add__(self, other) -> 'SymbolicTensor': + # 不发射 UDP!而是在当前 block 追加指令 + out = self._ir_builder.alloc_temp() # 分配临时变量名 + self._ir_builder.emit("add", [self, other], [out]) + return out + + def __gt__(self, other) -> 'SymbolicTensor': + out = self._ir_builder.alloc_temp() + self._ir_builder.emit("greater", [self, other], [out]) + return out # ← 返回 SymbolicTensor(dtype='bool') +``` + +#### 4.3.2 IR Builder (ir_builder.py) + +```python +class IRBuilder: + 
"""编译模式下收集 IR 指令的构建器""" + def __init__(self, func_name: str): + self.func_name = func_name + self.signature = None + self.blocks: list[Block] = [] + self.current_block: Block = None + self._temp_counter = 0 + + def alloc_temp(self) -> str: + self._temp_counter += 1 + return f"./t{self._temp_counter}" + + def emit(self, opcode: str, reads: list, writes: list): + """在当前 block 追加指令""" + self.current_block.add_instruction(opcode, reads, writes) + + def new_block(self) -> Block: + block = Block(len(self.blocks)) + self.blocks.append(block) + return block + + def finalize(self) -> str: + """生成最终的 dxlang 字符串""" + lines = [f"def {self.func_name}{self.signature} {{"] + for block in self.blocks: + lines.append(f" @{block.id}:\n") + for inst in block.instructions: + lines.append(f" {inst.to_dxlang()}\n") + lines.append("}") + return ''.join(lines) +``` + +#### 4.3.3 控制流捕获 (control_flow.py) + +```python +class IfContext: + """捕获 if/else 分支为基本块""" + def __init__(self, cond: SymbolicTensor): + self.ir = get_current_ir_builder() + self.cond = cond + self.entry_block = self.ir.current_block + + def __enter__(self): + # 创建 then block,设置结束指令为条件分支 + self.then_block = self.ir.new_block() + self.else_block = self.ir.new_block() + self.merge_block = self.ir.new_block() + + # 在 entry block 追加终止指令 + self.entry_block.add_terminator("br", [self.cond, self.then_block, self.else_block]) + + # 切换到 then block + self.ir.current_block = self.then_block + return self + + def __exit__(self, ...): + # then block 结束时追加 jump 到 merge + self.then_block.add_terminator("jump", [self.merge_block]) + + # 切换到 else block (支持 with IfContext(cond): ... else: ... 语法) + # 否则 else block 为空 (直接 jump merge) + + # 切换到 merge block + self.ir.current_block = self.merge_block +``` + +#### 4.3.4 编译器装饰器 (compiler.py) + +```python +def compile(func): + """ + 将 Python 函数编译为 deepxir 函数定义,写入 Redis /src/func/。 + """ + @functools.wraps(func) + def wrapper(*args, **kwargs): + # 1. 提取签名 + sig = extract_signature(func) + + # 2. 
创建 IR builder + builder = IRBuilder(func.__name__) + builder.signature = sig + set_current_ir_builder(builder) + + # 3. 创建入口 block + 形参 SymbolicTensor + builder.current_block = builder.new_block() + params = create_symbolic_params(sig, args) + + # 4. 执行函数体(操作被拦截) + result = func(*params, **kwargs) + + # 5. 在末尾 block 追加隐式 return + builder.current_block.add_terminator("return", [result]) + + # 6. 生成 dxlang + 写入 Redis + dxlang = builder.finalize() + rdb.set(f"/src/func/{func.__name__}", sig) + for i, block in enumerate(builder.blocks): + for j, inst in enumerate(block.instructions): + rdb.set(f"/src/func/{func.__name__}/{i}/{j}", inst.to_dxlang()) + + # 7. 通知 VM + rdb.lpush("notify:vm", f"func:{func.__name__}") + + # 8. 清理线程局部 + clear_current_ir_builder() + + return wrapper +``` + +### 4.4 支持的控制流语法 + +```python +@deepxir.compile +def example(x: Tensor, n: int) -> Tensor: + # === if/else === + if x.sum() > 0: # 条件为 SymbolicTensor + y = layer_a(x) + else: + y = layer_b(x) + + # === for (固定次数) === + for i in range(10): # Python range → 编译期展开为顺序指令序列 + y = y + x # (不需要 deepxir for 控制流) + + # === for (动态边界) === + for _ in deepxir.loop(cond=lambda: y.sum() < threshold): + y = y * 0.9 + x * 0.1 + + # === while === + while y.sum() > 1e-6: # 条件为 SymbolicTensor + y = y * 0.5 + + return y +``` + +### 4.5 与现有即刻模式的共存 + +``` +┌──────────────────────────────────────────────────────────────┐ +│ pysdk 入口 │ +│ │ +│ user_code.py │ +│ │ │ +│ ├── x = a + b → Tensor.__add__ │ +│ │ → rtf_add() → DeepxIR → UDP │ +│ │ ✅ 即刻发射(现有,不变) │ +│ │ │ +│ └── @compile def f(x): → compiler.py │ +│ if x > 0: → IfContext 捕获 │ +│ ... → SymbolicTensor 拦截 │ +│ return y → finalize → dxlang → Redis │ +│ ✅ 编译模式(新增) │ +└──────────────────────────────────────────────────────────────┘ +``` + +--- + +## 5. 
Redis 存储格式(C1 关键字 IR)
+
+编译器输出的 C1 关键字 IR 在 Redis 中以嵌套 key 表达:
+
+```
+C1 关键字 IR:
+----------------------------------------------
+def example(A:int, B:int) -> (C:int) {
+ newtensor("f32", "[4]") -> ./x
+ sum(./x) -> ./s
+ greater(./s, 0) -> ./cond
+
+ if (./cond) {
+ add(./x, 1) -> ./y
+ } else {
+ mul(./x, -1) -> ./y
+ }
+
+ deltensor(./x)
+ return(./y)
+}
+----------------------------------------------
+
+对应的 Redis key(嵌套同构):
+
+/src/func/example = "def example(A:int, B:int) -> (C:int)"
+/src/func/example/0 = "newtensor(\"f32\", \"[4]\") -> ./x"
+/src/func/example/1 = "sum(./x) -> ./s"
+/src/func/example/2 = "greater(./s, 0) -> ./cond"
+/src/func/example/3 = "if"
+/src/func/example/3/cond = "./cond"
+/src/func/example/3/then/0 = "add(./x, 1) -> ./y"
+/src/func/example/3/else/0 = "mul(./x, -1) -> ./y"
+/src/func/example/4 = "deltensor(./x)"
+/src/func/example/5 = "return(./y)"
+```
+
+**关键属性**:IR 的嵌套结构与 Redis key 的嵌套结构完全同构。
+VM 在执行 `if`/`for`/`while` 时自然地按照 key 子树进入子作用域。
+
+---
+
+## 6. 与 VM 翻译层的配合
+
+当前 VM 的 `translate.go` 处理 CALL 指令时:
+```
+CALL func_name(args...) →
+ 1. 读取签名: GET /src/func/<func_name>
+ 2. MGET 函数体 (按数字后缀排序)
+ 3. 逐条翻译为 vthread 执行层坐标
+ 4. 追加隐式 return
+```
+
+引入 block 格式后,`translate.go` 需要升级:
+
+```
+CALL func_name(args...) →
+ 1. 读取签名: GET /src/func/<func_name>
+ 2. 读取 blocks: KEYS /src/func/<func_name>/@*
+ 3. 为每个 block 读取指令序列 (按数字后缀排序)
+ 4. 构建 CFG: 解析每个 block 的终止指令 (br/jump/return)
+ 5. Eager inline: 形参替换 + 写入 vthread 子树
+ 6. Block 终止指令保留为控制流 opcode (br/jump/return)
+```
+
+**关键变化**:之前的翻译是线性的(指令列表),现在是图结构的(block 列表 + 跳转边)。
+
+---
+
+## 7. 
与 Go 前端的对比 + +Go 前端 (`front/go/deepx/`) 采用 **图模式 (Graph Mode)**: + +```go +func (m *Transformer) Forward(x *Tensor) *Tensor { + for _, layer := range m.layers { // ← Go for, 图构建期展开 + x = layer.Forward(x) + } + return x +} +``` + +Go 的 `for` 循环在编译期展开(图构建期),生成的是**静态计算图**。Go 前端未来也需要支持动态控制流——两种前端可以共享同一套编译期抽象(IR Builder / Block / CFG)。 + +| 维度 | Python 即刻模式 | Python 编译模式 | Go 图模式 | +|------|---------|---------|--------| +| 发射时机 | 每个 op 立即 | 函数调用时批量 | 图构建完成后 | +| 控制流 | Python if/for(构建期) | deepxir if/for/while(执行期) | Go if/for(构建期) | +| 输出格式 | 单条 IR → UDP | 完整 dxlang → Redis | DOT 图 → 文件 | +| 嵌套调用 | 无 | 有(装饰器标注) | 有(直接调用) | +| 动态控制流 | ❌ | ✅ | ❌(未来需要) | + +--- + +## 8. 实现路线图 + +### Phase 1: IR Builder + SymbolicTensor(线性函数) + +1. 实现 `SymbolicTensor`(`__add__`/`__mul__` 等拦截,生成 C1 IR) +2. 实现 `IRBuilder`(单 scope 指令收集) +3. 实现 `@deepxir.compile` 装饰器(无分支,纯线性) +4. Redis 写入 (`/src/func/`) +5. VM `translate.go` 支持 C1 嵌套 key 格式 + +### Phase 2: 控制流捕获(C1 关键字) + +1. 实现 `IfContext` → 生成 `if (cond) { ... } else { ... }` 关键字 IR +2. 实现 `WhileContext` → 生成 `while (cond) { ... }` +3. 实现 `ForContext` → 生成 `for (init; cond; step) { ... }` +4. 实现 `break`/`continue` 上下文感知 +5. VM 实现嵌套作用域栈,直接解释关键字 + +### Phase 3: 基本块执行(为 C2/C5 准备) + +1. VM 新增 block-loop 执行模式 +2. 关键字 → 基本块 lowering 模块(先放在 VM 内) +3. `jump`/`br` 终止指令 + +### Phase 4: 优化与生产化 + +1. Scheduler 服务(lowering 从 VM 抽出) +2. 基本块合并、死代码消除 +3. Eager 与 @compile 混用 + +--- + +## 9. 关键设计决策 + +### 为什么不用 TorchScript 的 Script 模式 + +TorchScript 需要维护一个完整的 Python 子集编译器(类型推断、控制流分析、字节码拦截),复杂度极高。deepxir 只需要支持张量运算 + 基本控制流,拦截 `Tensor` 操作 + 控制流关键字就足够了。 + +### 为什么保留即刻模式 + +即刻模式(`rtf → DeepxIR → UDP`)是调试、探索式开发、单步测试的基础。编译模式是性能优化和复杂控制流的手段。两种模式互补而非替代。 + +### 为什么函数是编译单元 + +以**函数**为编译单元而不是整个程序,因为: +1. deepxir 的函数就是 VM 的调用/调度单元 +2. 编译一个函数 → 写入 `/src/func/` → VM 的 `CALL` 可直接使用 +3. 与现有的 `testdata/*.dx` 文件格式兼容 +4. 
与 Go 前端的模块/层概念对齐 diff --git a/doc/dxlang/spec-control-flow-v1.md b/doc/dxlang/spec-control-flow-v1.md new file mode 100644 index 00000000..1f681f43 --- /dev/null +++ b/doc/dxlang/spec-control-flow-v1.md @@ -0,0 +1,813 @@ +# DeepX 控制流与前端代码生成 — 架构分析 v1 + +> 方案 C 确定为最优架构方向。本文展开 5 种子方案 + deepxir 语法完整设计。 + +--- + +## 1. 评判准则 + +| 准则 | 含义 | +|------|------| +| **关注点分离** | 每层只做一件事,层间接口明确 | +| **单一抽象级别** | 同一层内的概念在同一抽象高度 | +| **可验证性** | 层间转换有形式化保证 | +| **可扩展性** | 新能力不触动核心 | +| **概念完整性** | 一个核心思想贯穿始终 | + +--- + +## 2. 方案 C 的本质 + +> **控制流从"用户意图"到"机器执行"是一个语义等价的格式变换过程。 +> 这个变换应该发生在明确的边界上,每步变换可独立验证。** + +``` + IR 层数 Lowering 位置 +C1 单层关键字 1 层 无(VM 直接解释) +C2 二层分离 2 层 Scheduler 服务 +C3 单层基本块 1 层 前端负责 +C4 单层 Region 1 层 VM 原生执行 region +C5 关键字+VM内lower 1 层(外部) VM 加载时内部 lowering +``` + +--- + +## 3. deepxir 语法设计 + +> 语法需同时满足三者:**人**(可扫读)、**VM**(确定可解析)、**Agent**(模式固定可生成)。 + +### 3.1 核心约束 + +| 约束 | 原因 | +|------|------| +| **保留 `->` 和 `<-`** | 显式区分只读参数和写入参数方向,是 deepxir 的协议级语义 | +| **算子统一命名** | `tensor.new` / `tensor.del` 形成命名空间;栈变量 VM 自动推导 | +| **中缀运算符需 op-plat 注册** | 非标量类型的 `+`/`-`/`*`/`>` 等符号运算,后端必须声明支持 | + +### 3.2 语法规则 + +``` +program = func* +func = "fn" name "(" params ")" ("->" returns)? body +params = param ("," param)* +param = name (":" type)? +returns = name (":" type)? + | "(" name (":" type)? ("," name (":" type)?)* ")" +type = "f16" | "f32" | "f64" | "bf16" | "i8" | "i16" | "i32" | "i64" | "bool" | "string" + | type "[" shape "]" # f32[2,4] 或 f32[?,?] +shape = dim ("," dim)* +dim = int | "?" 
+
+body = "{" stmt* "}"
+stmt = assign | ctrl_if | ctrl_for | ctrl_while | ctrl_loop
+ | ctrl_break | ctrl_continue | ctrl_switch | ctrl_return
+ | lifecycle | bare_expr
+
+# === 赋值:两种箭头,语义等价 ===
+assign = prefix_op "->" name # 传统: 表达式在左, 结果在右
+ | name "<-" prefix_op # C风格: 结果在左, 表达式在右
+ | infix_expr "->" name # 中缀: x + y -> z
+ | name "<-" infix_expr # C中缀: z <- x + y
+
+# === 前缀调用(总是合法)===
+prefix_op = name "(" args ")" # add(x, y), matmul(A, B)
+ | unop operand # !flag, -x
+
+# === 中缀表达式(仅当 op-plat 注册了该符号)===
+infix_expr = operand binop operand # x + y, s > 0
+binop = "+" | "-" | "*" | "/" | "%"
+ | "==" | "!=" | "<" | ">" | "<=" | ">="
+ | "&&" | "||"
+unop = "-" | "!"
+
+# === 操作数 ===
+operand = name # 局部变量: x, tmp
+ | "/" name ("/" name)* # 堆路径: /models/W
+ | literal # 1.0, true, "f32"
+args = operand ("," operand)*
+literal = int | float | "true" | "false" | string
+
+# === 生命周期 ===
+lifecycle = "tensor.new" "(" shape "," type ")" "->" name # 堆分配
+ | "tensor.del" "(" name ")" # 堆释放
+ | "tensor.clone" "(" name ")" "->" name # 堆克隆
+
+# 栈变量: VM 遇到新 name 在写位置首次出现时自动创建, 无需显式 tensor.new
+
+# === 控制流 ===
+ctrl_if = "if" operand body ("else" (ctrl_if | body))?
+ctrl_for = "for" name "in" range body
+range = operand ".." operand
+ctrl_while = "while" operand body
+ctrl_loop = "loop" body
+ctrl_break = "break"
+ctrl_continue = "continue"
+ctrl_switch = "switch" operand "{" case* default? "}"
+case = "case" literal ":" body
+default = "default" ":" body
+ctrl_return = "ret" operand? 
+``` + +### 3.3 语法决策说明 + +#### 为什么必须保留 `->` 和 `<-` + +deepxir 不是通用编程语言。它的指令格式(`reads` 数组 + `writes` 数组)天然有方向性。 +箭头直接表达这个方向: + +``` +add(x, y) -> z reads=[x, y] writes=[z] 一目了然 +z <- add(x, y) reads=[x, y] writes=[z] C 风格等价写法 + +# 如果用 =,方向信息丢失: +z = add(x, y) ← 失去了 reads/writes 的结构区分 +``` + +`->` 和 `<-` 是**语法级的多写入支持**: + +``` +split(x) -> (a, b) # 一目了然:x 读, a 和 b 写 +(a, b) <- split(x) # 等价 C 风格 + +# 如果用 = 则需要特殊语法: +(a, b) = split(x) # 混淆了 assignment 和 destructure +``` + +#### 为什么中缀需要 op-plat 注册 + +```deepxir +# 场景 1: VM 原生求值(标量) +1 + 2 -> x # VM 直接算, 不需要 op-plat + +# 场景 2: op-plat 求值(张量) +x + y -> z # 仅当 op-plat 注册了 add 算子支持 "+" 符号时合法 + # 否则编译/解析时报错, 强制使用显式前缀: +add(x, y) -> z + +# 场景 3: 混合 +x + 1.0 -> y # addscalar 算子, 需要 op-plat 注册 +``` + +op-plat 注册格式: + +```json +// /op/op-cuda/add +{ + "symbols": ["+"], + "dtype": ["f32", "f16", "bf16"], + ... +} +``` + +VM 在解析中缀表达式时: +1. 检查操作数类型 → 是否全是标量?→ VM 直接求值 +2. 操作数含张量 → 查 `/op//` 的 `symbols` 字段 +3. 符号已注册 → 展开为对应 opcode,允许 +4. 符号未注册 → 报错:"op-plat 不支持 `+` 运算, 请使用显式前缀" + +#### 为什么 `tensor.new` / `tensor.del` 而非 `alloc` / `free` + +命名空间统一: + +``` +tensor.new 堆分配 — 与 opcode 名一致(heap-plat 的协议名) +tensor.del 堆释放 — 同上 +tensor.clone 堆克隆 — 同上 + +matmul 计算 — op-plat 协议名 +add 计算 — op-plat 协议名 + +栈变量无显式命名 — VM 遇到新变量名自动推导 +``` + +好处:IR 中的名字就是 Redis 中的 opcode,零映射。 + +#### 为什么栈变量无需显式 new + +``` +# VM 执行: +sum(x) -> s ← s 首次出现在写位置, VM 自动创建栈变量 +s + 1.0 -> s ← s 已存在, 复用 + +# 等价于 VM 内部: +# 第一次遇到 s: SET /vthread//s = {dtype: f32, value: ...} +# 后续遇到 s: GET /vthread//s → 读写 +``` + +### 3.4 完整示例 + +```deepxir +# 线性函数 — 仅前缀调用 +fn add_test(A: f32[?,?], B: f32[?,?]) -> C: f32[?,?] { + matmul(A, B) -> C +} + +# 带控制流 + 堆生命周期 +fn dynamic_clamp(x: f32[?], min: f32, max: f32) -> y: f32[?] 
{ + tensor.new([?], f32) -> mask_low + x < min -> mask_low # 中缀, 需 op-plat 注册 < + + if any(mask_low) { + where(mask_low, min, x) -> y + } else { + y <- x # C风格 + } + tensor.del(mask_low) + + tensor.new([?], f32) -> mask_high + y > max -> mask_high # 中缀, 需 op-plat 注册 > + + if any(mask_high) { + where(mask_high, max, y) -> y + } + tensor.del(mask_high) + + ret y +} + +# 循环 + 混合前缀/中缀 +fn training_step(data: f32[?,?]) -> loss: f32 { + 0.0 -> total # 标量直接用前缀 + + for i in 0..100 { + matmul(data, /models/W) -> pred # 前缀, 堆引用 + pred - data -> err # 中缀, 需注册 - + err * err -> sq # 中缀, 需注册 * + total + sum(sq) -> total # 中缀 + 前缀 + } + total / 100.0 -> loss # 中缀, 需注册 / + ret loss +} + +# C风格写法(等价) +fn cstyle_example(x: f32[4]) -> y: f32[4] { + s <- sum(x) # C风格前缀 + s > 0 -> cond # 中缀 + if cond { + y <- x + 1.0 # C风格中缀 + } else { + y <- x * -1.0 + } + ret y +} +``` + +### 3.5 中缀→前缀的展开规则(VM 解析层) + +``` +x + y -> z → add(x, y) -> z (二元) +x - y -> z → sub(x, y) -> z +x * y -> z → mul(x, y) -> z +x / y -> z → div(x, y) -> z +x % y -> z → mod(x, y) -> z +x == y -> z → equal(x, y) -> z +x != y -> z → notequal(x, y) -> z +x < y -> z → less(x, y) -> z +x > y -> z → greater(x, y) -> z +x <= y -> z → lessequal(x, y) -> z +x >= y -> z → greaterequal(x, y) -> z +x && y -> z → and(x, y) -> z +x || y -> z → or(x, y) -> z +!x -> y → not(x) -> y (单元) +-x -> y → neg(x) -> y (单元) +``` + +每个目标 opcode 必须被对应后端注册了 `symbols` 字段才允许中缀形式。 + +### 3.6 Redis 存储格式(嵌套同构) + +``` +fn example(x: f32[4]) -> y: f32[4] { + tensor.new([4], f32) -> a + sum(x) -> s + if s > 0 { + x + 1.0 -> y + } else { + y <- x * -1.0 + } + tensor.del(a) + ret y +} + +↓ Redis: +/src/func/example = "fn example(x: f32[4]) -> y: f32[4]" +/src/func/example/0 = "tensor.new([4], f32) -> a" +/src/func/example/1 = "sum(x) -> s" +/src/func/example/2 = "if" +/src/func/example/2/cond = "s > 0" +/src/func/example/2/then/0 = "x + 1.0 -> y" +/src/func/example/2/else/0 = "y <- x * -1.0" +/src/func/example/3 = "tensor.del(a)" +/src/func/example/4 = 
"ret y" +``` + +--- + +## 4. `->`/`<-` 与 SSA 对比分析 + +> **核心问题**:`->`/`<-` 方向语义和 SSA 单赋值模型,是变种等价关系还是根本不同的设计选择? + +### 4.1 本质差异:两种不同的语义模型 + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ 两种模型的核心分歧 │ +│ │ +│ SSA = 值身份模型 (Value Identity) │ +│ 追踪"哪个值"被使用 │ +│ 同一存储位置的不同时刻值 → 不同的 SSA 名字 │ +│ 问题域: 这个计算依赖哪个具体的值? │ +│ │ +│ ->/<- = 存储效应模型 (Storage Effect) │ +│ 追踪"哪个存储位置"被读写 │ +│ 同一名字的不同时刻值 → 同一个变量在不同时刻 │ +│ 问题域: 这个操作读写哪些 KV 路径? │ +└─────────────────────────────────────────────────────────────────┘ +``` + +**它们不是等价变种。** 两种模型回答的是不同层次的问题,服务于不同的系统角色。 + +### 4.2 形式化对比 + +| 维度 | SSA | `->`/`<-` | +|------|-----|----------| +| **变量赋值次数** | 严格 1 次 | 无限次(可变) | +| **值标识** | 虚拟寄存器编号 (`%0`, `%1`) | KV 路径 (`./x`, `/data/W`) | +| **数据流表达** | 操作数引用 = 隐式 use-def 链 | 显式 reads[] / writes[] 数组 | +| **控制流汇合** | block arguments (替代 φ 节点) | 变量自然复用,无需特殊机制 | +| **支配关系** | 严格支配树 (定义支配所有使用) | 无形式支配约束 | +| **副作用建模** | 困难 (需要特殊 dialect/op) | 天然 (tensor.new/del 就是指令) | +| **理论基础** | 编译器优化理论 (Cytron et al. 1991) | 分布式 KV 状态机 | +| **典型系统** | LLVM IR, MLIR, GCC GIMPLE | deepxir, Redis Lua, SQL stored procedures | + +### 4.3 各自优势 + +#### SSA 优势 + +``` +1. 编译器优化使能器 + %1 = add(%a, %b) ← CSE: 同一操作数 → 同一结果, 可消除重复 + %2 = add(%a, %b) ← 因为 %a, %b 的 SSA 值身份不变, CSE 直接判定 %1 == %2 + + 对比 ->/<-: + add(a, b) -> c ← CSE 困难: a, b 的值可能在指令间被改写 (可变变量) + add(a, b) -> d ← VM 无法仅从名字判定 c == d, 需要追踪 a, b 的最近写入 + +2. use-def 链零成本 + %result = mul(%x, %y) ← %x 引用直接指向定义 %x 的唯一指令 + 无需额外数据结构即可遍历数据流图 + +3. 寄存器分配友好 + SSA → 干涉图 → 着色 → 寄存器 + 变量版本号天然就是活性区间标识 + +4. 形式化验证 + 支配树 + 支配边界 → 可形式证明优化变换的正确性 + 学术界 30+ 年理论积累 +``` + +#### `->`/`<-` 优势 + +``` +1. 执行协议直通 + add(a, b) -> c + ↓ 零映射 ↓ + {"opcode": "add", "reads": ["/vthread/vt1/a", "/vthread/vt1/b"], "writes": ["/vthread/vt1/c"]} + ↓ 直接发送到后端 ↓ + op-plat / heap-plat 执行 + + 对比 SSA: 需要额外 lowering 步骤 + %1 = add %0, %arg0 → 解析 use-def → 分配内存槽 → 生成 reads/writes + +2. 
分布式状态天然 + /vthread/vt1/a ← KV 路径即变量身份 + 跨节点执行: 同一路径在不同节点指向同一数据分片 + SSA 的值编号 (%0, %1) 是局部的, 无法跨节点引用 + +3. 副作用指令自然嵌入 + tensor.new([4], f32) -> ./a ← 分配 GPU 显存, 这是一个副作用 + tensor.del(./a) ← 释放 GPU 显存, 这也是副作用 + + SSA 中: + %0 = tensor.new [4] : f32 ← MLIR 需要 memref dialect 专门处理 + tensor.del %0 ← 破坏 SSA 纯函数假设, 需要特殊标注 + +4. 多写入直观 + split(x) -> (a, b) ← 一目了然: 一个输入, 两个输出 + 对比 SSA: 需要 tuple 封装再解构, 或自定义 dialect +``` + +### 4.4 同一示例的两种表达 + +``` +问题: 计算 (x + y) * (x - y), 其中 x 是输入参数 + +┌─── SSA ──────────────────────┐ ┌─── ->/<- ────────────────────┐ +│ │ │ │ +│ fn calc(x: f32) -> f32 { │ │ fn calc(x: f32) -> z: f32 { │ +│ %0 = add(x, y) │ │ x + y -> sum │ +│ %1 = sub(x, y) │ │ x - y -> diff │ +│ %2 = mul(%0, %1) │ │ sum * diff -> z │ +│ ret %2 │ │ ret z │ +│ } │ │ } │ +│ │ │ │ +│ 特点: │ │ 特点: │ +│ • %0, %1, %2 各出现 1 次 │ │ • sum 出现 2 次 (写, 读) │ +│ • 数据流: use-def 链隐式 │ │ • 数据流: -> 方向显式 │ +│ • 适合做 CSE/复写传播 │ │ • 适合直接执行 │ +└───────────────────────────────┘ └───────────────────────────────┘ +``` + +这个简单例子中两者等价——SSA 的 `%0` 对应 `sum`, `%1` 对应 `diff`, `%2` 对应 `z`。 +**差异在控制流和副作用场景下才显著。** + +### 4.5 控制流场景下的关键差异 + +``` +场景: if/else 分支后合并使用结果 + +┌─── SSA (MLIR 风格) ──────────────────────────────────────────┐ +│ │ +│ fn branch_example(x: f32, flag: bool) -> f32 { │ +│ br flag ? 
@then : @else │ +│ │ +│ @then: │ +│ %t = add(x, 1.0) ← 定义 %t │ +│ br @merge(%t) ← %t 作为 block argument 传递 │ +│ │ +│ @else: │ +│ %e = sub(x, 1.0) ← 定义 %e │ +│ br @merge(%e) ← %e 作为 block argument 传递 │ +│ │ +│ @merge(%y: f32): ← %y 是 block argument │ +│ %r = mul(%y, 2.0) ← 使用 %y (来自 %t 或 %e) │ +│ ret %r │ +│ } │ +│ │ +│ 机制: block argument 取代 φ 节点 │ +│ 优势: SSA 贯穿始终, %y 有唯一定义 (block parameter) │ +│ 代价: 需要传递 block argument, 前端/VM 需理解支配关系 │ +└────────────────────────────────────────────────────────────────┘ + +┌─── ->/<- (deepxir 风格) ──────────────────────────────────────┐ +│ │ +│ fn branch_example(x: f32, flag: bool) -> result: f32 { │ +│ if flag { │ +│ x + 1.0 -> y ← 写入 y (slot) │ +│ } else { │ +│ x - 1.0 -> y ← 写入 y (同一个 slot) │ +│ } │ +│ y * 2.0 -> result ← 读取 y (无论哪个分支写的) │ +│ ret result │ +│ } │ +│ │ +│ 机制: 可变 slot 复用 │ +│ 优势: 无需 block argument, 前端/VM 逻辑简单 │ +│ 代价: 无值身份追踪, y 被写两次 (违反 SSA), 优化器需额外分析 │ +└────────────────────────────────────────────────────────────────┘ +``` + +**核心洞见**: SSA 的 block argument 和 `->`/`<-` 的可变 slot 是**对偶设计**: +- SSA 用**参数化**解决多来源问题 (值从 block 参数传入) +- `->`/`<-` 用**存储复用**解决多来源问题 (同一个 slot 被不同路径写入) + +### 4.6 副作用场景下的根本分歧 + +``` +场景: 循环内分配和释放临时张量 + +┌─── SSA (MLIR 需 memref dialect) ──────────────────────────────┐ +│ │ +│ %buf = memref.alloc()[4] : f32 ← 分配, 有副作用 │ +│ scf.for %i = 0 to 10 { │ +│ %tmp = memref.alloc()[4] : f32 ← 每次迭代分配新 memref │ +│ linalg.fill %tmp, %i │ +│ linalg.add %buf, %tmp -> %buf │ +│ memref.dealloc %tmp ← 释放 │ +│ } │ +│ │ +│ 问题: │ +│ • %buf 被循环内更新 → 不能用纯 SSA, 需 memref (可变内存抽象) │ +│ • memref 本质是"带副作用的内存槽" → SSA 的纯函数假设在此破裂 │ +│ • MLIR 的解决: 用 memref dialect 隔离副作用, 保持其余 IR 纯 SSA │ +└────────────────────────────────────────────────────────────────┘ + +┌─── ->/<- (deepxir 原生) ──────────────────────────────────────┐ +│ │ +│ tensor.new([4], f32) -> buf ← 堆分配, is_write=1 │ +│ for i in 0..10 { │ +│ tensor.new([4], f32) -> tmp ← 堆分配 │ +│ fill(tmp, i) -> tmp ← 写入 │ +│ buf + tmp -> buf ← 读写 buf │ +│ tensor.del(tmp) ← 堆释放 │ +│ } │ 
+│ │
+│ 优势: │
+│ • 副作用 (new/del) 与计算 (add/fill) 在同一层次 │
+│ • 每条指令天然携带 reads/writes → 执行层直接可用 │
+│ • 无需 dialect 隔离 → 概念完整, 学习成本低 │
+│ • buf 被多次写入 → 可变 slot, 与 GPU 显存操作模型一致 │
+└────────────────────────────────────────────────────────────────┘
+```
+
+**根本分歧点**: SSA 假设纯函数计算模型 (值不可变), 副作用必须隔离到特定 dialect。
+`->`/`<-` 假设可变状态模型 (存储位置可更新), 副作用是第一公民。
+
+deepxir 的核心场景——GPU 显存分配/释放、分布式 KV 读写、多后端异构执行——**本质上就是副作用密集的**。
+因此 `->`/`<-` 是更自然的匹配。
+
+### 4.7 等价性证明
+
+**在纯计算 (无副作用、无控制流汇合) 的线性代码段内, SSA 与 `->`/`<-` 严格等价**:
+
+```
+定理 1 (线性段等价):
+ 对于不含控制流分支和副作用指令的线性指令序列,
+ SSA 表示和 ->/<- 表示之间存在双射 (bijection)。
+
+ 证明:
+ SSA: v0 = op(args...) →/<-: op(args...) -> name
+ 映射: 每个 SSA 虚拟寄存器 %i ←→ 首次出现在写位置的变量名 name
+ 每个 SSA use ←→ 读位置的变量名
+ 映射是双射: 每个 %i 仅有一个定义点, 每个定义点产生一个新 %i,
+ ->/<- 侧每个写位置引入一个变量名, 与 SSA 一一对应。
+```
+
+**在控制流汇合场景下, 语义等价但结构不等价**:
+
+```
+定理 2 (控制流汇合等价):
+ 对于带 φ 节点的 SSA 形式,
+ 存在一个 ->/<- 程序与之计算等价 (compute the same results),
+ 但结构上不等价 (φ 通过 block parameter 或 slot 复用实现)。
+
+ 映射方向:
+ SSA φ: %y = φ(%t, %e) → ->/<-: 在两个分支中写入同一个 slot y
+
+ 逆映射方向:
+ ->/<-: 分支中写 slot y → SSA: 插入 φ 节点或 block argument
+
+ 关键差异:
+ SSA → ->/<- 是信息丢失的 (值身份信息丢失, 只保留了存储位置信息)
+ ->/<- → SSA 是信息恢复的 (需要做 SSA 构造算法, 为标准编译器技术)
+```
+
+**在副作用场景下, 两者不等价**:
+
+```
+定理 3 (副作用分歧):
+ 对于包含内存分配/释放、I/O 等副作用的程序,
+ SSA (纯 MLIR core) 与 ->/<- 不等价。
+
+ 原因:
+ SSA 将副作用隔离到特定 dialect (memref, gpu, async),
+ 副作用指令的语义由 dialect 定义, 与纯 SSA 值流分离。
+
+ ->/<- 将副作用嵌入每条指令的 reads/writes 数组中,
+ 副作用与数据流在同一表示中融合。
+
+ 两种表示的计算结果等价, 但抽象模型不同:
+ SSA 分离了"值计算"和"效应执行"两个世界
+ ->/<- 将世界统一为"对 KV 空间的读写操作"
+```
+
+### 4.8 互补共存:C2 架构中的双重表示
+
+C2 架构的两层 IR 正是利用了两种模型的互补性:
+
+```
+┌────────────────────────────────────────────────────────────────┐
+│ C2 双层 IR 格局 │
+│ │
+│ Layer 1: 结构化 IR (前端输出, ->/<-) │
+│ ┌──────────────────────────────────────────────────┐ │
+│ │ fn example(x: f32[4]) -> y: f32[4] { │ │
+│ │ tensor.new([4], f32) -> a │ │
+│ │ sum(x) -> s │ │
+│ │ if s > 0 { ← 结构化控制流 │ │
+│ │ x + 1.0 -> y │ │
+│ │ } else { │ │
+│ │ y <- x * -1.0 ← C风格箭头 │ │
+│ │ } │ │
+│ 
│ tensor.del(a) │ │ +│ │ ret y │ │ +│ │ } │ │ +│ └──────────────────────────────────────────────────┘ │ +│ │ │ +│ │ Scheduler Lowering │ +│ ▼ │ +│ Layer 2: 基本块 IR (VM 执行, SSA 可选) │ +│ ┌──────────────────────────────────────────────────┐ │ +│ │ @0: │ │ +│ │ %0 = tensor.new [4] : f32 ← 块内 SSA (可选) │ │ +│ │ %1 = sum %x │ │ +│ │ %2 = greater %1, 0 │ │ +│ │ br %2 ? @1 : @2 ← 平铺 CFG │ │ +│ │ │ │ +│ │ @1: │ │ +│ │ %3 = add %x, 1.0 │ │ +│ │ jump @3 │ │ +│ │ │ │ +│ │ @2: │ │ +│ │ %4 = mul %x, -1.0 │ │ +│ │ jump @3 │ │ +│ │ │ │ +│ │ @3(%y: f32): ← block argument │ │ +│ │ tensor.del %0 │ │ +│ │ ret %y │ │ +│ └──────────────────────────────────────────────────┘ │ +│ │ +│ 关键: 两种模型各司其职 │ +│ • Layer 1 (->/<-): 面向人、Agent、前端 — 可读、可生成 │ +│ • Layer 2 (SSA+CFG): 面向 VM、优化器 — 可分析、可变换 │ +│ • Lowering 是有损变换: 高层语义展开, 值身份重构 │ +└────────────────────────────────────────────────────────────────┘ +``` + +### 4.9 为什么 deepxir 选择 `->`/`<-` 作为主语法而非 SSA + +| 决策因素 | `->`/`<-` | SSA | 分析 | +|---------|-----------|-----|------| +| **协议对齐** | 直接映射到 reads[]/writes[] | 需要额外 lowering | deepxir 的协议层就是以 reads/writes 为单位的, 强行用 SSA 增加无谓转换 | +| **前端生成复杂度** | 低 (变量名字直接映射) | 高 (需要 SSA 构造算法, 处理 φ/block arg) | Python/Go 前端生成 `->`/`<-` 只需记录变量名; 生成 SSA 需维护版本计数器 | +| **人可读性** | 高 (名字有意义: `./loss`, `./grad`) | 中 (编号 `%0`, `%1` 无意义, 需注释) | 调试时 `./loss` 比 `%42` 更直观 | +| **分布式语义** | 天然 (路径即全局标识) | 需额外映射 (值编号是局部的) | deepxir 是分布式系统, KV 路径就是全局标识 | +| **副作用表达** | 自然 (new/del 就是指令) | 需要 dialect 隔离 | 不需要引入 dialect 概念, 降低系统复杂度 | +| **优化能力** | 需额外分析 | 天然支持 | 但 deepxir 当前优化需求不高 (计算在 GPU 上), 且 C2 Scheduler 内部可以重建 SSA | + +**结论: `->`/`<-` 是 deepxir 的正确主语法选择。SSA 作为 C2 Scheduler 内部的优化表示, 在 lowering 阶段构建, 对前端和 VM 透明。** + +### 4.10 决策总结 + +``` +┌──────────────────────────────────────────────────────────────┐ +│ │ +│ SSA 和 ->/<- 不是等价变种。 │ +│ │ +│ 它们是两种不同的语义模型: │ +│ • SSA 建模值身份 (适合编译器分析和优化) │ +│ • ->/<- 建模存储效应 (适合分布式执行和协议直通) │ +│ │ +│ 对 deepxir 而言: │ +│ • 主语法用 ->/<- — 面向人、Agent、前端、协议 │ +│ • SSA 作为 C2 Scheduler 内部优化 IR — 对用户透明 │ +│ • 两者通过 
Scheduler lowering 连接 │ +│ • C1 模式不引入 SSA (保持单层简单) │ +│ • C2 模式在 lowering 时重建 SSA (使能优化) │ +│ │ +│ 这不是非此即彼的选择, 而是分层的职责分配。 │ +│ MLIR 的选择也佐证了这一点: scf (结构化) = 类比 ->/<-, │ +│ cf (底层) = SSA + 基本块, 两层并存。 │ +│ │ +└──────────────────────────────────────────────────────────────┘ +``` + +--- + +## 5. 方案 C1:单层关键字 IR + +前后端共用 deepxir 语法。VM 直接解释控制流关键字,管理嵌套作用域。 + +### 5.1 VM 执行模型 + +```go +func Execute(vtid string) { + inst := decode(currentPC) + switch inst.Opcode { + case "tensor.new": + pushToHeapPlat(vtid, inst) // → heap-plat + case "tensor.del": + pushToHeapPlat(vtid, inst) + case "if": + cond := evalOperand(inst.Cond) + if isTruthy(cond) { + pushScope(vtid, inst.Then) + } else { + pushScope(vtid, inst.Else) + } + case "for": + pushLoopScope(vtid, inst.Var, inst.Start, inst.End, inst.Body) + case "while": + pushWhileScope(vtid, inst.Cond, inst.Body) + case "break": + popLoopScope(vtid) + case "continue": + rewindToLoopCond(vtid) + case "ret": + popFrame(vtid) + default: + dispatch(inst) // 计算 → op-plat / VM 求值 + } +} +``` + +### 5.2 架构评估 + +| 准则 | 评分 | +|------|------| +| 概念完整性 | ⭐⭐⭐⭐⭐ | +| 单一抽象级别 | ⭐⭐⭐⭐ | +| 可扩展性 | ⭐⭐⭐⭐ | + +--- + +## 6. 方案 C2:二层 IR + Scheduler Lowering + +``` +前端 → deepxir (C1 语法) + │ + [Scheduler] + Pass 1: Region Flattening → 基本块 + Pass 2: 块内 SSA 构造 + Pass 3: 优化 + Pass 4: Target Emission + │ + ▼ + 平铺基本块 → VM +``` + +基本块 IR 格式(C2 Scheduler 输出,VM 输入): + +``` +fn example(x: f32[4]) -> y: f32[4] { +@0: + tensor.new([4], f32) -> a + sum(x) -> s + greater(s, 0) -> cond # 展开为前缀, 不做中缀 + br cond ? @1 : @2 + +@1: + add(x, 1.0) -> y # 全部前缀, 无中缀 + jump @3 + +@2: + mul(x, -1.0) -> y + jump @3 + +@3: + tensor.del(a) + ret y +} +``` + +每个 block 有恰好一条终止指令:`br cond ? @t : @f` | `jump @target` | `ret [val]`。 +中缀在 lowering 时全部展开为前缀 opcode。 + +### 架构评估 + +| 准则 | 评分 | +|------|------| +| 关注点分离 | ⭐⭐⭐⭐⭐ | +| 单一抽象级别 | ⭐⭐⭐⭐⭐ | +| 可验证性 | ⭐⭐⭐⭐⭐ | +| 可扩展性 | ⭐⭐⭐⭐⭐ | +| 概念完整性 | ⭐⭐⭐⭐⭐ | +| **综合** | **25/25** | + +--- + +## 7. 
方案 C3/C4/C5 精简 + +| 方案 | 核心 | 亮点 | 架构评分 | +|------|------|------|---------| +| C3 单层基本块 | 前端直接输出基本块(C2 VM 输入格式) | VM 极简 | 19/25 | +| C4 单层 Region | VM 原生理解嵌套 region | MLIR 概念对齐 | 22/25 | +| C5 关键字+VM内lower | 前端输出 C1,VM 内部 lowering | 对外简单对内优化 | 19/25 | + +--- + +## 8. 五方案横向对比 + +| 准则 | C1 关键字 | C2 二层+Scheduler | C3 单层块 | C4 Region | C5 关键字+VM内 | +|------|---------|------------------|----------|----------|------------| +| 关注点分离 | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | +| 单一抽象级别 | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | +| 可验证性 | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | +| 可扩展性 | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐ | +| 概念完整性 | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | +| **综合** | **19** | **25** | **19** | **22** | **19** | + +### 关键维度对比 + +``` +前端负担: C3 > C1 = C4 = C5 > C2 +VM 复杂度: C1 = C4 > C5 > C2 = C3 +运维复杂度: C2 > C1 = C3 = C4 = C5 +优化能力: C2 > C5 > C1 = C3 = C4 +``` + +--- + +## 9. 结论 + +**架构最优:C2(二层 IR + Scheduler)** — 25/25,所有维度满分。 + +**C1 与 C2 互补**: +1. **先 C1**:deepxir 语法 → VM 直接解释,调试友好,快速可用 +2. **再 C2**:加 Scheduler 做 lowering + 优化,C1 语法即 C2 的结构化 IR 层格式 +3. 
VM 执行层不变(基本块格式),Scheduler 可插拔 + +**三个核心约束贯穿所有方案**: +- `->` / `<-` 保留方向语义,不可用 `=` 替代 +- `tensor.new` / `tensor.del` 统一命名,栈变量 VM 自推导 +- 中缀运算符需 op-plat 注册 `symbols` 字段 diff --git a/docs/deepxIR/readme.md b/doc/dxlang/test.md similarity index 100% rename from docs/deepxIR/readme.md rename to doc/dxlang/test.md diff --git a/docs/executor/deepx.op.drawio b/doc/executor/deepx.op.drawio similarity index 100% rename from docs/executor/deepx.op.drawio rename to doc/executor/deepx.op.drawio diff --git a/docs/executor/deepx.op.drawio.svg b/doc/executor/deepx.op.drawio.svg similarity index 100% rename from docs/executor/deepx.op.drawio.svg rename to doc/executor/deepx.op.drawio.svg diff --git a/docs/executor/deepx.op.jpg b/doc/executor/deepx.op.jpg similarity index 100% rename from docs/executor/deepx.op.jpg rename to doc/executor/deepx.op.jpg diff --git a/docs/front/aboutop.md b/doc/front/aboutop.md similarity index 100% rename from docs/front/aboutop.md rename to doc/front/aboutop.md diff --git a/docs/front/deepx.jpg b/doc/front/deepx.jpg similarity index 100% rename from docs/front/deepx.jpg rename to doc/front/deepx.jpg diff --git a/docs/front/deepx.op.drawio.svg b/doc/front/deepx.op.drawio.svg similarity index 100% rename from docs/front/deepx.op.drawio.svg rename to doc/front/deepx.op.drawio.svg diff --git a/docs/front/deepxpy.drawio.svg b/doc/front/deepxpy.drawio.svg similarity index 100% rename from docs/front/deepxpy.drawio.svg rename to doc/front/deepxpy.drawio.svg diff --git a/docs/front/front.md b/doc/front/front.md similarity index 100% rename from docs/front/front.md rename to doc/front/front.md diff --git a/docs/front/graph.md b/doc/front/graph.md similarity index 100% rename from docs/front/graph.md rename to doc/front/graph.md diff --git a/docs/front/node.md b/doc/front/node.md similarity index 100% rename from docs/front/node.md rename to doc/front/node.md diff --git a/docs/front/op.md b/doc/front/op.md similarity index 100% rename from docs/front/op.md rename 
to doc/front/op.md
diff --git a/docs/front/py/about.md b/doc/front/py/about.md
similarity index 100%
rename from docs/front/py/about.md
rename to doc/front/py/about.md
diff --git a/docs/front/py/contribute.md b/doc/front/py/contribute.md
similarity index 100%
rename from docs/front/py/contribute.md
rename to doc/front/py/contribute.md
diff --git a/docs/front/py/deepx.rst b/doc/front/py/deepx.rst
similarity index 100%
rename from docs/front/py/deepx.rst
rename to doc/front/py/deepx.rst
diff --git a/doc/heap-plat/README.md b/doc/heap-plat/README.md
new file mode 100644
index 00000000..ad980d70
--- /dev/null
+++ b/doc/heap-plat/README.md
@@ -0,0 +1,205 @@
+# heap-plat 设计
+
+> heap-plat 是 DeepX 元程的 **堆管理平面**,负责 tensor 对象的生命周期管理。
+> 本文档定义 heap-plat 的抽象契约和所有实现的共用规范。
+
+## 1. 定位
+
+在元程 5 核架构中,heap-plat 负责:
+
+| 核心 | 角色 |
+|------|------|
+| KV 空间 (Redis) | 存储 tensor 元信息(dtype, shape, 物理地址) |
+| heap-plat | 管理 tensor 外部存储(创建/删除/克隆 shm) |
+| op-plat | 通过元信息中的物理地址访问 tensor 数据 |
+| VM | 将 newtensor/deltensor 指令路由到 heap-plat |
+
+```
+VM PUSH newtensor ──→ cmd:heap-<platform>:<n>
+                         │
+                    heap-plat 消费
+                         │
+                    分配 POSIX shm
+                    分配 GPU/CPU buffer
+                    写入 tensor 元信息到 Redis
+                         │
+VM BLPOP done:<vtid> ←── LPUSH 完成通知
+```
+
+## 2. 
抽象契约
+
+任何 heap-plat 实现(Metal/CUDA/CPU)必须满足以下契约。
+
+### 2.1 必须实现的命令
+
+| 命令 | 语义 | 输入 | 输出 |
+|------|------|------|------|
+| `newtensor` | 创建 tensor | key, dtype, shape, device | 分配存储,写入 Redis 元信息 |
+| `deltensor` | 删除 tensor | key | 释放存储,删除 Redis key |
+| `clonetensor` | 克隆 tensor 到指定设备 | src_key, dst_key, device | 分配+拷贝,写入新 Redis key |
+
+### 2.2 通信模型
+
+```
+消费: RPOP / BLPOP cmd:heap-<platform>:<n>
+        ↓
+执行: POSIX shm_open / shm_unlink + mmap
+        ↓
+写入: SET <key> = tensor 元信息 (JSON)
+        ↓
+通知: LPUSH done:<vtid> {pc, status, ...}
+```
+
+### 2.3 命令格式
+
+**newtensor:**
+```json
+{
+  "vtid": "1",
+  "pc": "[0,0]",
+  "op": "newtensor",
+  "key": "/models/weights",
+  "dtype": "f32",
+  "shape": [1024, 512],
+  "device": "gpu0"
+}
+```
+
+**deltensor:**
+```json
+{
+  "vtid": "1",
+  "pc": "[5,0]",
+  "op": "deltensor",
+  "key": "/models/weights"
+}
+```
+
+**clonetensor:**
+```json
+{
+  "vtid": "1",
+  "pc": "[0,0]",
+  "op": "clonetensor",
+  "src": "/models/weights",
+  "dst": "/models/weights_gpu1",
+  "device": "gpu1"
+}
+```
+
+### 2.4 完成通知格式
+
+```json
+{
+  "pc": "[0,0]",
+  "status": "ok"
+}
+```
+
+错误:
+```json
+{
+  "pc": "[0,0]",
+  "status": "error",
+  "error": {
+    "code": "SHM_ALLOC_FAILED",
+    "message": "failed to allocate 2GB shared memory"
+  }
+}
+```
+
+### 2.5 进程注册
+
+启动时向 `/sys/heap-plat/<instance>` 注册:
+
+```
+/sys/heap-plat/metal:0 = {"program":"heap-metal", "device":"gpu0", "status":"running", "pid":<pid>, "started_at":<timestamp>}
+```
+
+命名规则: `<platform>:<n>`,实例编号从 0 开始。
+
+### 2.6 命令队列
+
+VM 通过指定队列向特定实例发送指令:
+
+```
+cmd:heap-metal:0 → heap-metal 实例 0 (gpu0)
+cmd:heap-metal:1 → heap-metal 实例 1 (gpu1)
+cmd:heap-cuda:0  → heap-cuda 实例 0 (gpu0)
+cmd:heap-cpu:0   → heap-cpu 实例 0 (cpu)
+```
+
+### 2.7 消费者循环 (所有实现共用)
+
+```
+1. 启动 → 注册 /sys/heap-plat/<instance>
+2. 循环:
+   a) BLPOP cmd:heap-<platform>:<n> (超时 5s)
+   b) 解析 JSON 命令
+   c) 根据 op 字段分发:
+      "newtensor"   → handle_newtensor(req)
+      "deltensor"   → handle_deltensor(req)
+      "clonetensor" → handle_clonetensor(req)
+   d) LPUSH done:<vtid> 完成通知
+3. 退出 → DELETE /sys/heap-plat/<instance>
+```
+
+## 3. 
Tensor 元信息格式 (统一) + +所有 heap-plat 实现写入 Redis 的 tensor 元信息格式必须一致: + +```json +{ + "dtype": "f32", + "shape": [1024, 512], + "byte_size": 2097152, + "device": "gpu0", + "address": { + "node": "n1", + "type": "shm", + "shm_name": "/deepx_t_abc123" + }, + "ctime": 1714000000, + "version": 5 +} +``` + +| 字段 | 类型 | 必需 | 说明 | +|------|------|------|------| +| dtype | string | 是 | 见数据类型表 | +| shape | array[int] | 是 | 多维形状 | +| byte_size | int | 是 | 总字节数 = element_count × dtype_size | +| device | string | 是 | gpu0, cpu | +| address.type | string | 是 | 统一为 "shm" | +| address.shm_name | string | 是 | POSIX shm 路径,如 /deepx_t_ | +| address.node | string | 是 | 机器标识 | +| ctime | int | 否 | 创建时间 | +| version | int | 否 | 每次更新递增 | + +**数据类型大小:** + +| dtype | bytes | +|-------|-------| +| f16, bf16 | 2 | +| f32, i32 | 4 | +| f64, i64 | 8 | +| i8, u8 | 1 | +| i16 | 2 | +| bool | 1 | + +## 4. 各平台实现概览 + +| 实现 | 目录 | 状态 | 说明 | +|------|------|------|------| +| [heap-metal](heap-metal.md) | executor/heap-metal/ | 待开发 | macOS Metal 统一内存 | +| [heap-cuda](heap-cuda.md) | executor/heap-cuda/ | 待开发 | Linux CUDA GPU 显存 | +| [heap-cpu](heap-cpu.md) | executor/heap-cpu/ | 待开发 | 纯 CPU 内存 | + +## 5. 待确定问题 + +| 问题 | 状态 | +|------|------| +| 引用计数 (refcount) | 暂不实现,由上层管理 | +| 跨节点 tensor 迁移 | 暂不实现 | +| 零拷贝 GPU 间传输 | 待评估 | +| shm_name 命名冲突 | 使用随机 hex 后缀 + EXISTS 检查 | diff --git a/doc/heap-plat/heap-cpu.md b/doc/heap-plat/heap-cpu.md new file mode 100644 index 00000000..4b1d730f --- /dev/null +++ b/doc/heap-plat/heap-cpu.md @@ -0,0 +1,68 @@ +# heap-cpu + +> 纯 CPU 内存的 heap-plat 实现。待开发。 +> +> **heap-cpu 的进程维持着 deepx 元程的堆在 CPU 设备平台的高可用。** + +## 1. 平台特性 + +| 特性 | 说明 | +|------|------| +| 设备 | CPU (无 GPU) | +| 内存模型 | 传统虚拟内存 | +| 分配方式 | POSIX shm + 普通 mmap/malloc | +| GPU 访问 | 不适用 | + +## 2. 设计要点 + +最简实现,无需 GPU 相关逻辑: + +``` +1. newtensor: + a) shm_open + ftruncate + mmap → CPU ptr + b) SET Redis 元信息 + +2. deltensor: + a) shm_unlink + b) DELETE Redis key +``` + +无 GPU 上下文,无 VRAM 分配,无 memcpy 传输。 + +## 3. 
Tensor 元信息 + +```json +{ + "dtype": "f32", + "shape": [1024, 512], + "byte_size": 2097152, + "device": "cpu", + "address": { + "type": "shm", + "shm_name": "/deepx_t_abc123" + } +} +``` + +## 4. 进程注册 + +``` +/sys/heap-plat/cpu:0 = {"program":"heap-cpu", "device":"cpu", "status":"running", "pid":} +``` + +## 5. 待开发 + +| 任务 | 说明 | +|------|------| +| shm 分配/释放 | shm_open / shm_unlink | +| 大页支持 | mmap MAP_HUGETLB (可选) | +| NUMA 感知 | mbind (可选) | + +## 6. 依赖 + +- hiredis +- librt (shm_open) + +## 7. 开发量 + +~250 行 C/C++ (最简实现)。 diff --git a/doc/heap-plat/heap-cuda.md b/doc/heap-plat/heap-cuda.md new file mode 100644 index 00000000..be281aaf --- /dev/null +++ b/doc/heap-plat/heap-cuda.md @@ -0,0 +1,91 @@ +# heap-cuda + +> Linux CUDA GPU 显存的 heap-plat 实现。待开发。 +> +> **heap-cuda 的进程维持着 deepx 元程的堆在 CUDA 设备平台的高可用。** + +## 1. 平台特性 + +| 特性 | 说明 | +|------|------| +| 设备 | NVIDIA GPU (CUDA) | +| 内存模型 | 独立显存 (VRAM),需 cudaMemcpy 传输 | +| 分配方式 | cudaMalloc + POSIX shm (寄存器映射) | +| shm 命名 | `/deepx_t_<8位随机hex>` | +| GPU 访问 | CUDA kernel 直接通过 device ptr 访问 | + +## 2. 设计要点 + +### 与 Metal 的关键差异 + +Metal (统一内存): +``` +mmap ptr ←→ GPU 直接访问 (零拷贝) +``` + +CUDA (独立显存): +``` +shm (CPU 可访问) ←→ cudaMemcpy ←→ VRAM (GPU 可访问) +``` + +需要一个额外的共享内存段来记录 VRAM 指针,供 op-cuda 获取。 + +### 方案 + +``` +1. newtensor: + a) 分配 shm (CPU 侧, 用于跨进程交换元信息) + b) cudaMalloc(&dptr, byte_size) → VRAM + c) 将 dptr 写入 shm 头部 (如 offset 0, sizeof(void*) 字节) + d) shm_name 和 dptr 写入 Redis 元信息 + +2. 
tensor 元信息 (多一个 vram_ptr): + { + "dtype": "f32", + "shape": [1024, 512], + "byte_size": 2097152, + "device": "gpu0", + "address": { + "type": "cuda", + "shm_name": "/deepx_t_abc123", + "vram_ptr": "0x7f...", // cudaMalloc 返回的 device ptr + "cuda_ctx": "" // CUDA context (多 GPU 场景) + }, + "ctime": 1714000000, + "version": 1 + } +``` + +### 多 GPU 场景 + +每张 GPU 独立一个实例: + +``` +heap-cuda:0 → gpu0 → 管理 /dev/nvidia0 +heap-cuda:1 → gpu1 → 管理 /dev/nvidia1 +``` + +CUDA context 与实例一一绑定。clonetensor 跨卡时: +```c +// GPU0 → GPU1 的 clone (通过 P2P 或中转) +cudaMemcpyPeer(dst_ptr_gpu1, 1, src_ptr_gpu0, 0, byte_size); +``` + +## 3. 待开发 + +| 任务 | 说明 | +|------|------| +| CUDA 设备初始化 | cudaSetDevice, context 管理 | +| cudaMalloc / cudaFree | VRAM 分配释放 | +| CPU↔GPU 数据桥 | shm ↔ VRAM 的 memcpy 封装 | +| P2P 传输 (clonetensor) | cudaMemcpyPeer 跨 GPU | +| 进程注册 | /sys/heap-plat/cuda:0 | + +## 4. 依赖 + +- CUDA Toolkit (libcuda) +- hiredis + +## 5. 开发量 + +~500 行 C++/CUDA C。 diff --git a/doc/heap-plat/heap-metal.md b/doc/heap-plat/heap-metal.md new file mode 100644 index 00000000..c9e76be5 --- /dev/null +++ b/doc/heap-plat/heap-metal.md @@ -0,0 +1,156 @@ +# heap-metal + +> macOS Metal 统一内存的 heap-plat 实现。 +> Apple Silicon 上 CPU/GPU 共享物理内存,shm_open + mmap 的指针 +> 可直接通过 newBufferWithBytesNoCopy 被 GPU 访问。 +> +> **heap-metal 的进程维持着 deepx 元程的堆在 Metal 设备平台的高可用。** + +## 1. 平台特性 + +| 特性 | 说明 | +|------|------| +| 设备 | Apple Silicon GPU (Metal) | +| 内存模型 | 统一内存 (CPU/GPU 共享物理内存) | +| 分配方式 | POSIX shm_open + ftruncate + mmap | +| shm 命名 | `/deepx_t_<8位随机hex>` | +| GPU 访问 | `newBufferWithBytesNoCopy` 直接包装 mmap 指针 | + +## 2. 代码位置 + +``` +executor/heap-metal/ +├── CMakeLists.txt +├── src/ +│ ├── main.mm 入口: Redis 连接, 主循环 +│ └── lifecycle/ +│ └── lifecycle.h/.mm newtensor / deltensor / clonetensor + +依赖: + executor/common-metal/ + ├── shm_tensor.h/.mm POSIX shm 封装 + ├── registry.h/.mm 注册表抽象 + └── metal_device.h/.mm Metal 设备管理 +``` + +## 3. 
newtensor 流程 + +``` +输入: + {"vtid":"1", "pc":"[0,0]", "op":"newtensor", "key":"/models/W", + "dtype":"f32", "shape":[1024,512], "device":"gpu0"} + +步骤: + 1. 计算 byte_size = 1024 × 512 × 4 = 2,097,152 + 2. 生成 shm_name = "/deepx_t_" + random_hex(8) + 3. shm_tensor_create(shm_name, byte_size): + a) shm_open(shm_name, O_CREAT|O_RDWR, 0600) → fd + b) ftruncate(fd, byte_size) + c) mmap(NULL, byte_size, PROT_READ|PROT_WRITE, MAP_SHARED, fd, 0) → ptr + d) close(fd) + 4. 构造 tensor 元信息 → SET /models/W + 5. LPUSH done: {"pc":"[0,0]", "status":"ok"} +``` + +**Mac 统一内存优势:** +mmap 返回的 CPU 指针在 Apple Silicon 上可直接被 GPU 使用, +无需 `cudaMemcpy` 或显式数据传输。 + +```objc +// op-metal 侧使用示例 +void* cpu_ptr = mmap_result; +id gpu_buf = [device newBufferWithBytesNoCopy:cpu_ptr + length:byte_size + options:MTLResourceStorageModeShared + deallocator:nil]; +``` + +## 4. deltensor 流程 + +``` +输入: + {"vtid":"1", "pc":"[5,0]", "op":"deltensor", "key":"/models/W"} + +步骤: + 1. GET /models/W → 获取 address.shm_name + 2. shm_unlink(shm_name) → 标记删除 + 3. munmap(ptr, byte_size) → 解除映射 (如果有缓存) + 4. UNLINK /models/W → 删除 Redis key + 5. LPUSH done: {"pc":"[5,0]", "status":"ok"} +``` + +## 5. clonetensor 流程 + +``` +输入: + {"vtid":"1", "pc":"[0,0]", "op":"clonetensor", + "src":"/models/W", "dst":"/models/W_gpu1", "device":"gpu1"} + +步骤: + 1. GET /models/W → 源 tensor 元信息 + 2. 调用 newtensor 逻辑创建目标 tensor (新 shm_name, device=gpu1) + 3. shm_open(src_shm_name) → 映射源数据 + 4. memcpy(dst_ptr, src_ptr, byte_size) // 统一内存下直接拷贝 + 5. SET /models/W_gpu1 → 新 tensor 元信息 + 6. LPUSH done: ... +``` + +## 6. 进程注册 + +``` +启动时: + SET /sys/heap-plat/metal:0 = { + "program": "heap-metal", + "device": "gpu0", + "status": "running", + "pid": , + "started_at": + } + +退出时: + DELETE /sys/heap-plat/metal:0 +``` + +## 7. Redis 命令队列 + +| Key | 说明 | +|-----|------| +| `cmd:heap-metal:0` | 默认实例的命令队列 (监听) | +| `cmd:heap-metal:1` | 第 2 个实例 (如有) | +| `done:` | 完成通知 (写入) | + +## 8. 
依赖 + +```bash +brew install hiredis # Redis C 客户端 +``` + +CMakeLists.txt 需链接: +- hiredis +- common-metal (shm_tensor, metal_device) + +## 9. 测试 + +```bash +# 终端1: heap-metal +./heap_metal + +# 终端2: redis-cli 模拟 +redis-cli RPUSH cmd:heap-metal:0 \ + '{"vtid":"t","pc":"[0,0]","op":"newtensor","key":"/test/x","dtype":"f32","shape":[10,10],"device":"gpu0"}' + +redis-cli GET /test/x +# → {"dtype":"f32","shape":[10,10],"byte_size":400,"device":"gpu0","address":{"type":"shm","shm_name":"/deepx_t_a1b2c3d4",...}} + +redis-cli BLPOP done:t 1 + +redis-cli RPUSH cmd:heap-metal:0 \ + '{"vtid":"t","pc":"[1,0]","op":"deltensor","key":"/test/x"}' + +redis-cli GET /test/x +# → (nil) +``` + +## 10. 开发量 + +~300 行 C++/ObjC 新增代码。当前已有 426 行基础设施。 diff --git a/docs/highway.md b/doc/highway.md similarity index 100% rename from docs/highway.md rename to doc/highway.md diff --git a/docs/index.rst b/doc/index.rst similarity index 100% rename from docs/index.rst rename to doc/index.rst diff --git a/docs/language.md b/doc/language.md similarity index 100% rename from docs/language.md rename to doc/language.md diff --git a/doc/metaproc/README.md b/doc/metaproc/README.md new file mode 100644 index 00000000..dc42535c --- /dev/null +++ b/doc/metaproc/README.md @@ -0,0 +1,333 @@ +# 元程 (Metaproc) + +> **程序 = 数据结构 + 函数 + 数据** +> +> 这是元程的核心思想。在分布式系统中,三者被**显式地、可见地、共享地**定义在 +> 一个全局 KV 空间中。数据结构是地址空间的划分规则,函数是可复用的代码单元, +> 数据是计算的实际内容。三者分离,各司其职,通过路径空间统一编排。 + +--- + +## 1. 
为什么是"元程" + +### 1.1 问题的根源 + +在单机上写程序,OS 和编译器已经替我们管理好了一切: + +``` +程序员的视角 (单机): + x = a + b ← 看起来只有"数据"和"运算" + y = f(x) + return y + +OS/编译器实际做的 (对程序员透明): + - 在栈上为 x 分配 4 字节 + - 为 a 找到寄存器 r1 + - call f → 分配新栈帧 → 保存 rbp → 传递参数 → ret → 恢复 rbp + - 所有这些(数据结构)都是隐式的 +``` + +但在分布式系统中,**隐式不再可行**。因为: +- 计算分布在不同的 GPU/CPU 上 +- 数据存储在不同的 shm/显存/磁盘上 +- 进程之间只能通过约定的数据格式通信 +- 没有"OS"替你管理全局状态 + +**你需要一个显式的、全局可见的、所有进程都能读懂的数据结构约定。** +这就是元程的 KV 空间。 + +### 1.2 "元程"的含义 + +"元程"(Metaproc)是"元进程"的缩写: + +- **元 (Meta)**:它定义了一个"进程"的抽象骨架——不是运行某个具体程序的进程, + 而是可以运行**任意程序**的进程框架。类比:OS 内核定义了进程的抽象 + (地址空间、线程、栈、堆),但内核本身不规定进程里跑什么程序。 +- **程 (Proc)**:它本身是一个"进程"——具有地址空间(KV 空间)、 + 执行流(vthread)、代码段(func)、数据段(heap),遵循与 OS 进程相同的 + 抽象模式。 + +``` +OS 进程 = 虚拟地址空间 + 线程 + 代码段 + 数据段 + 栈 +元程 = KV 空间 + vthread + /src/func/ + 堆 + /vthread/ +``` + +**关键区别**:OS 进程的地址空间是私有的、字节寻址的、单机的。 +元程的 KV 空间是**共享的、路径寻址的、分布式的**。 + +--- + +## 2. 核心公式:程序 = 数据结构 + 函数 + 数据 + +### 2.1 三元分解 + +``` +程序 = 数据结构 + 函数 + 数据 + +数据结构:变量的容器、寻址方式、组织规则 + → KV 路径空间 + 保留路径约定 + 二维寻址 + +函数: 可复用的执行逻辑单元 + → /src/func/ + /op//func/(源码层 + 编译层) + +数据: 实际被处理和传输的内容 + → 堆 tensor + 栈局部变量 + 立即数 +``` + +### 2.2 三元在元程中的映射 + +| | 数据结构 | 函数 | 数据 | +|---|---|---|---| +| **是什么** | KV 空间 + 路径约定 | func 定义 | tensor + 基础类型 | +| **存储位置** | 路径本身的组织 | `/src/func/`, `/op/.../func/` | `/models/`, `/vthread//` | +| **生命周期** | 系统级别的常量 | 注册后长期有效 | 取决于类型(堆=长期,栈=短期) | +| **谁定义** | 元程规范 | pysdk(用户代码) | pysdk(用户数据) | +| **谁使用** | 所有进程 | VM + 编译器 | op-plat + heap-plat | +| **类比 OS** | 虚拟地址空间布局 | .text 段 + 共享库 | .data/.bss + 栈 | + +### 2.3 为什么三元必须分离 + +``` +反例:如果把数据和函数混在一起 + /vthread/1/data/weights = ... + /vthread/1/func/forward = ... + → 每个 vthread 都要复制一份代码 → 浪费 + → 函数和数据的生命周期不同 → 管理混乱 + → VM 无法跨 vthread 复用代码 → 不能 CALL + +正解:三元分离 + 函数: /src/func/forward ← 全局一份,所有 vthread CALL + 数据: /models/weights ← 全局共享,vthread 通过路径引用 + 结构: /vthread/1/ ← 每个 vthread 独立,栈 + 局部变量 +``` + +--- + +## 3. 
元程的五个核心 + +元程系统由五个核心组件构成: + +| 核心 | 角色 | 是什么 | 做什么 | +|---|---|---|---| +| **KV 空间** | 全局状态存储 | 数据结构 | 路径空间、命令队列、锁 | +| **pysdk** | 算法前端 | 函数 + 数据的写入者 | 注册 func 源码、创建 vthread | +| **op-plat** | 计算后端 | 函数的执行者 | 被动消费指令、执行 GPU/CPU 张量运算 | +| **heap-plat** | 堆管理 | 数据的生命周期管理 | tensor 创建/删除/克隆 | +| **VM** | 解释执行 | 数据结构的调度者 | func 翻译、指令路由、vthread 状态推进 | + +``` +┌─ pysdk ────────────────────────────────────────────────┐ +│ 函数: 注册源代码到 /src/func/ │ +│ 数据: 通过 heap-plat 创建 tensor │ +│ 结构: 创建 vthread 到 /vthread/ │ +└──────────────────────┬──────────────────────────────────┘ + │ + ▼ +┌─ KV 空间 ───────────────────────────────────────────────┐ +│ /src/func/ 源码 │ /op/.../func/ 编译 │ /vthread/ 执行 │ +│ 堆变量 │ 命令队列 │ 锁 │ +└──┬────────────┬──────────────┬────────────────────────────┘ + │ │ │ + ▼ ▼ ▼ +┌──────┐ ┌──────────┐ ┌───────────┐ +│ VM │ │ op-plat │ │ heap-plat │ +│ │ │ │ │ │ +│ 解释 │ │ 执行张量 │ │ 管理tensor│ +│ 调度 │ │ 计算 │ │ 生命周期 │ +└──────┘ └──────────┘ └───────────┘ +``` + +### 3.1 进程的被动性 + +所有核心进程(VM、op-plat、heap-plat)都**被动**执行——它们不主动发起操作, +而是消费命令队列。这类似于 OS 中的中断驱动模型: + +| 进程 | 等待什么 | 谁唤醒它 | +|---|---|---| +| VM | `BLPOP notify:vm` + `BLPOP done:` | pysdk + op-plat/heap-plat | +| op-plat | `RPOP cmd:op-:` | VM | +| heap-plat | `RPOP cmd:heap-:` | VM | + +**为什么被动?** +- 被动 = 解耦:生产者不需要知道消费者的具体实例 +- 被动 = 可伸缩:同一个命令队列可以有多个消费者竞争 +- 被动 = 容错:消费者崩溃,消息仍在队列中,重启后继续消费 + +--- + +## 4. 
执行全景 + +### 4.1 从用户代码到 GPU 执行 + +``` +用户 Python 代码: + y = x @ w + b + y = y.relu() + + ↓ pysdk 翻译 + +/src/func/forward: + /src/func/forward/0 = matmul(./x, /models/W) -> ./mm + /src/func/forward/1 = add(./mm, /models/b) -> ./mm + /src/func/forward/2 = relu(./mm) -> ./y + + ↓ 编译器优化(可选) + +/op/op-cuda/func/forward: + /op/op-cuda/func/forward/0 = fused_matmul_add_relu(./x, /models/W, /models/b) -> ./y + + ↓ VM CALL → eager translate + +/vthread/1/[0,0]: + /vthread/1/[0,0]/[0,0] = "fused_matmul_add_relu" ← 操作码 + /vthread/1/[0,0]/[0,-1] = "./x" ← 读参数 + /vthread/1/[0,0]/[0,-2] = "/models/W" + /vthread/1/[0,0]/[0,-3] = "/models/b" + /vthread/1/[0,0]/[0, 1] = "./y" ← 写参数 + + ↓ VM dispatch → PUSH 到 op-plat + +cmd:op-cuda:0 ← {"vtid":"1", "opcode":"fused_matmul_add_relu", ...} + + ↓ op-plat 消费 + +op-cuda: RPOP → 解析参数 → GET tensor 元信息 → shm_open → GPU kernel → 完成 + + ↓ 完成通知 + +done:1 ← {"pc":"[0,0]", "status":"ok"} + + ↓ VM 醒来 + +/vthread/1 = {"pc": "[1,0]", "status": "done"} ← 继续下一条 +``` + +### 4.2 控制流示例 + +``` +func dynamic_forward(x, threshold) -> (y) { + sum(./x) -> ./s + greater(./s, threshold) -> ./cond + + if (./cond) { + add(./x, 1.0) -> ./y + } else { + mul(./x, -1.0) -> ./y + } ← VM 直接解释 if,不经过 op-plat +} +``` + +VM 在处理控制流指令时**不需要 PUSH 到 op-plat**——它自己就是控制流的解释器。 + +--- + +## 5. 
与 OS 进程模型的对照 + +元程的设计有意模仿了 OS 进程的抽象,但在分布式维度上重新定义了每个概念: + +| OS 进程概念 | 元程对应 | 关键区别 | +|---|---|---| +| 虚拟地址空间 | KV 空间 | 字符串路径寻址、全局共享、分布式 | +| 线程 | Vthread | 状态存储于 KV 空间,多 VM 可并行拾取 | +| 代码段 (.text) | `/src/func/` + `/op/.../func/` | 人类可读 dxlang、后端可优化 | +| 堆段 (.data/.bss) | 非保留路径(堆变量) | tensor 元信息存 KV,实际数据存 shm | +| 栈段 | `/vthread//` 子树 | CALL 产生子栈嵌套,RETURN 删除 | +| 程序计数器 (PC) | `pc` 字段:`"[3,0]"` 或 `"[3,0]/[0,0]"` | 字符串路径,天然嵌套 | +| 栈帧 | CALL 产生的子维度 | 二维坐标 `[addr0, addr1]` | +| 系统调用 | PUSH 到 op-plat / heap-plat 队列 | 异步、可并行、可批量 | +| 互斥锁 | `LOCK/UNLOCK` | TTL 自动释放 | + +### 5.1 与 x86 指令的对照 + +| x86 指令 | 元程等价 | 执行者 | +|---|---|---| +| `mov [addr], value` | `SET /vthread/1/x = value` | VM | +| `add eax, ebx` | `PUSH add(./a, ./b) -> ./c` 到 op-plat | op-plat | +| `call func` | VM 翻译 `/op/.../func/` → `eager inline` 到子栈 | VM | +| `ret` | DELETE 子栈 + PC 回父栈 | VM | +| `cmp + jcc` | `if (./cond)` → VM 直接分支 | VM | +| `int 0x80` (syscall) | PUSH 到 cmd 队列 | VM→op-plat/heap-plat | + +--- + +## 6. 元程的"元" + +### 6.1 元程不是一种语言 + +元程是一个**分布式计算模型**。它不规定: +- 前端语言(Python、Go、C++ 都可以) +- 后端硬件(CUDA、Metal、CPU 都可以) +- 存储实现(Redis、etcd、自研 KV 都可以) + +它只规定: +- 数据结构怎么组织(KV 路径空间 + 保留路径约定) +- 函数怎么定义(`->`/`<-` 读写分离 + dxlang 指令格式) +- 数据怎么流转(命令队列 + 完成通知) +- 执行流怎么管理(vthread 状态机 + CALL/RETURN) + +### 6.2 元程满足的条件 + +只要一个 KV 存储提供以下能力,就能运行元程程序: + +| 能力 | 说明 | +|---|---| +| `GET/SET/DELETE` | 基本 KV 操作 | +| List (FIFO 队列) | `PUSH/POP/BLPOP` | +| 锁 | `LOCK/UNLOCK` + TTL | +| 事务 | `WATCH/MULTI/EXEC` 原子操作 | +| 基础类型 | int, float, bool, string, JSON | + +当前参考实现:DeepX,使用 Redis 作为 KV 空间。 + +--- + +## 7. 
文档索引 + +| 文档 | 说明 | +|---|---| +| **[spec-v1.md](spec-v1.md)** | 元程规范 v1 — 抽象模型(KV 空间要求、vthread 模型、CALL/RETURN、异步执行) | +| **[metaproc-datastruct.md](metaproc-datastruct.md)** | 数据结构篇 — KV 路径空间、保留路径、vthread、func、tensor、命令队列 | +| **[deepx-design.md](deepx-design.md)** | DeepX 实现设计 — 五个核心组件的具体实现(Redis key 约定、VM 执行循环、op-plat/heap-plat 协议) | +| **[redis-keys.md](redis-keys.md)** | Redis Key 设计速查表 — 所有路径的 value 类型与示例 | +| **[dev-heap-plat.md](dev-heap-plat.md)** | heap-plat 开发指南 | +| **[dev-op-plat.md](dev-op-plat.md)** | op-plat 开发指南 | +| **[dev-pysdk.md](dev-pysdk.md)** | pysdk 开发指南 | + +### 7.1 阅读顺序建议 + +``` +第一遍(理解元程是什么): + 1. 本文 (README.md) ← 核心思想 + 2. metaproc-datastruct.md ← 数据结构全貌 + 3. spec-v1.md ← 抽象规范 + +第二遍(理解 DeepX 怎么实现): + 4. deepx-design.md ← 五个核心实现 + 5. redis-keys.md ← 路径速查 + +第三遍(上手开发): + 6. dev-heap-plat.md / dev-op-plat.md / dev-pysdk.md +``` + +--- + +## 8. 名称的由来 + +**Metaproc** = Meta + Process + +- **Meta**(元):它是"进程的进程"——定义了进程的抽象框架,而不绑定具体程序。 + 就像 OS 定义了进程的抽象但不规定进程里跑什么,元程定义了分布式进程的抽象 + 但不规定具体算法。 +- **Proc**(程):它本身具有进程的全部特征——地址空间、执行流、代码段、数据段、 + 栈。只是这些概念被重新定义为分布式版本。 + +与"元编程"(Metaprogramming)的区别: +- 元编程:程序操作程序(生成代码、反射、宏) +- 元程:定义程序的程序(定义分布式进程的抽象框架) + +--- + +> **核心命题**:在一个多机、多 GPU 的分布式系统中,你不能依赖 OS 替你管理状态。 +> 你必须显式定义:数据结构是什么(KV 路径空间)、函数放在哪(`/src/func/`)、 +> 数据怎么传(命令队列 + 完成通知)、执行流怎么管理(vthread + PC)。 +> 元程就是这三元的统一框架。 diff --git a/doc/metaproc/deepx-design.md b/doc/metaproc/deepx-design.md new file mode 100644 index 00000000..e1f6e3ee --- /dev/null +++ b/doc/metaproc/deepx-design.md @@ -0,0 +1,840 @@ +# DeepX 元程实现设计 + +> DeepX 是元程规范的第一种实现。本文档定义 DeepX 如何用 5 个核心组件 +> 将元程的抽象模型落地为可运行的分布式计算系统。 + +## 1. 
架构总览
+
+### 1.1 五个核心
+
+| 核心 | 角色 | 变体 |
+|------|------|------|
+| Redis | KV 空间 — 全局状态存储、命名空间、命令队列、锁 | — |
+| pysdk | 算法前端 — 注册源代码到 `/src/func/`,创建 vthread | — |
+| op-plat | 计算 — 被动消费指令,执行 GPU/CPU 张量运算 | op-cuda, op-metal, op-cpu |
+| heap-plat | 堆管理 — tensor 对象生命周期:创建/删除/克隆 shm | heap-cuda, heap-metal, heap-cpu |
+| VM | 解释执行 — CALL 翻译、指令路由、状态推进 | — |
+
+```
+┌─ pysdk ───────────────────────────────────────────────────────┐
+│  注册源代码到 /src/func/                                      │
+│  创建 vthread 到 /vthread/                                    │
+└──────────────────────┬────────────────────────────────────────┘
+                       │
+                       ▼
+┌─ Redis (KV 空间) ─────────────────────────────────────────────┐
+│ /src/func/ 源码 │ /op/.../func/ 编译 │ /vthread/ 执行         │
+│ 堆变量          │ List 命令队列      │ Lock 互斥锁            │
+└──┬────────────┬──────────────┬────────────────────────────────┘
+   │            │              │
+   ▼            ▼              ▼
+┌──────┐  ┌──────────┐  ┌───────────┐
+│ VM   │  │ op-plat  │  │ heap-plat │
+│      │  │          │  │           │
+│ 解释 │  │ 执行张量 │  │ 管理tensor│
+│ 调度 │  │ 计算     │  │ 生命周期  │
+│      │  │ op-cuda  │  │ heap-     │
+│      │  │ op-metal │  │   cuda    │
+│      │  │ op-cpu   │  │ heap-     │
+│      │  │          │  │   metal   │
+│      │  │          │  │ heap-     │
+│      │  │          │  │   cpu     │
+└──────┘  └──────────┘  └───────────┘
+```
+
+### 1.2 进程的被动性
+
+所有核心进程(VM、op-plat、heap-plat)都**被动**执行,不主动发起操作:
+
+| 进程 | 驱动方式 | 消费来源 |
+|------|---------|---------|
+| VM | BLPOP 等待通知 | `notify:vm` (新 vthread 创建) + `done:<vtid>` (op-plat 完成) |
+| op-plat | RPOP 或 BLPOP | `cmd:op-<backend>:<n>` (计算指令) |
+| heap-plat | RPOP 或 BLPOP | `cmd:heap-<platform>:<n>` (生命周期指令) |
+
+## 2. 
Redis Key 路径约定 + +### 2.1 保留路径 + +``` +/src/func/ 函数源码 (平台无关 dxlang) +/op//func/ 后端编译产物 +/op//list 后端支持的算子列表 +/op// 算子元数据 +/vthread/ 所有 vthread 执行状态 +/sys/ 系统信息 +/cmd/ 命令队列前缀 +/notify/ 通知队列前缀 +/done/ 完成通知队列前缀 +``` + +### 2.2 函数路径 (三层架构) + +函数有三种表示,分别服务不同角色: + +``` +源码层 (/src/func/): + /src/func/ 函数签名 (dxlang 文本) + /src/func//0 第 0 条指令, 如 "matmul(A, B) -> ./Y" + /src/func//1 第 1 条指令, 如 "mul(./Y, alpha) -> ./Y" + /src/func//1/true/0 分支 true 的第 0 条 + +编译层 (/op//func/): + /op/op-cuda/func/ 编译后的函数签名 + /op/op-cuda/func//0 编译后的指令 (可能已融合/拆分) + /op/op-metal/func//0 Metal 编译产物 (可能不同于 CUDA) + +执行层 (/vthread/): + /vthread//[0,0]... VM CALL 时 eager 翻译 (见 §3.2) +``` + +数据流: pysdk→`/src/func/` → 编译器→`/op//func/` → VM CALL→`/vthread/` + +### 2.3 Vthread 路径 (执行层) + +`/vthread/` 存储**机器优化**的执行层格式。指令展开为 `[addr0, addr1]` 二维坐标。 +命名槽位 (`./mm`) 与指令坐标是平级子 key,互不嵌套。 + +``` +/vthread/ → {"pc":"[3,0]", "status":"running"} ← vthread 自身 +/vthread//[0,0] 指令 #0 操作码 +/vthread//[0,-1] 指令 #0 读参数 #1 +/vthread//[0,1] 指令 #0 写参数 #1 +/vthread//[0,0]/[0,0] 子栈 (CALL 产生) +/vthread//a 命名槽位 (与 [0,0] 平级) +/vthread//mm 局部变量 (./mm 解析结果) +``` + +注:`pc` 和 `status` 是 `/vthread/` 的 value 中的字段,不是独立的子 key。 + +### 2.4 命令队列 + +``` +cmd:op-cuda: op-cuda 的命令队列 (如 cmd:op-cuda:0) +cmd:op-metal: op-metal 的命令队列 +cmd:op-cpu: op-cpu 的命令队列 +cmd:heap-cuda: heap-cuda 的命令队列 +cmd:heap-metal: heap-metal 的命令队列 +done: vthread 的完成通知队列 +notify:vm VM 的唤醒通知队列 +``` + +### 2.5 系统路径 + +``` +/sys/op-plat/ op-plat 进程注册 {type, device, status, load} +/sys/heap-plat/ heap-plat 进程注册 +/sys/vm/ VM 实例注册 +/sys/vtid_counter vthread ID 自增计数器 +/sys/config DeepX 全局配置 +``` + +注:op-plat 的算子能力注册在 `/op//list` 和 `/op//`(§2.7), +与 `/sys/` 下的进程注册(存活性、负载)是分离的。 + +### 2.6 堆变量(隐式命名空间) + +除上述保留路径外,所有其他路径均为堆变量。命名规范由上层决定: + +``` +/models/bert/encoder/0/weights ← 建议: 模型名/层名/参数名 +/data/cifar10/train ← 建议: 数据集/用途 +/checkpoints/run_001/step_1000 ← 建议: 运行ID/步数 +/rl/state_buffer ← 自由命名 +/src/func/forward ← 保留 (函数源码) +/op/op-cuda/func/forward ← 保留 (编译产物) +/vthread/1/... 
← 保留 (执行栈) +``` + +### 2.7 op-plat 算子注册 + +op-plat **程序**的算子能力是静态的,所有**进程实例**共享。 +区分程序级(算子列表、元数据、编译产物)和进程级(运行状态、负载)。 + +**程序级(静态,所有实例共享):** + +``` +/op/op-cuda/list → ["matmul", "add", "relu", "softmax", + "fused_matmul_add_relu", "fused_linear_norm", ...] + +/op/op-cuda/matmul → { + "category": "matmul", + "dtype": ["f32", "f16", "bf16"], + "max_shape": [8192, 8192, 8192], + "fusion_group": "linear" +} + +/op/op-cuda/fused_matmul_add_relu → { + "category": "fused", + "dtype": ["f32", "f16"], + "replaces": ["matmul", "add", "relu"] +} + +/op/op-cuda/func//... → 该程序专属的编译产物 +``` + +**进程级(动态,每个实例独立):** + +``` +/sys/op-plat/cuda:0 → {"program":"op-cuda", "device":"gpu0", "status":"running", "load":0.3} +/sys/op-plat/cuda:1 → {"program":"op-cuda", "device":"gpu1", "status":"running", "load":0.7} +/sys/op-plat/metal:0 → {"program":"op-metal", "device":"gpu0", "status":"running", "load":0.1} +``` + +实例命名规则: `:`,如 `cuda:0`, `cuda:1`, `metal:0`。 + +**编译器的使用:** + +``` +融合: + 1. GET /op/op-cuda/list → 过滤 category="fused" 的算子 + 2. 对每个 fused 算子,读取 replaces 列表 + 3. 扫描 /src/func/ 的指令序列,滑动窗口匹配 replaces + 4. 匹配成功 → 用 fused 算子等价替换 + 5. 写入 /op/op-cuda/func/ + +拆分: + 1. GET /op/op-cuda//max_shape → 单卡能力上限 + 2. 对比 /src/func/ 中该算子的输入 tensor shape + 3. 超出上限 → 按 batch 维或 hidden 维拆分 + 4. 标注子算子目标设备 (如 GPU0/GPU1) + 5. 写入 /op/op-cuda/func/ +``` + +**VM 的指令路由:** + +``` +VM 需要执行 matmul: + 1. GET /op/op-cuda/list → 含 "matmul" → 该程序支持 + 2. GET /sys/op-plat/cuda:0 → {device:"gpu0", load:0.3} + 3. GET /sys/op-plat/cuda:1 → {device:"gpu1", load:0.7} + 4. 选择负载最低的实例 → cuda:0 + 5. PUSH cmd:op-cuda:0 (命令队列与该实例绑定) +``` + +TODO: 编译器实现细节,待编译器设计阶段确定。 + +## 3. 
VM 进程设计 + +### 3.1 VM 状态机 + +``` + ┌─────────┐ + │ start │ + └────┬────┘ + │ 扫描 /vthread/ 找到 status=init 的 vthread + ┌────▼────┐ + ┌───►│ idle │◄──────────────┐ + │ └────┬────┘ │ + │ │ 拾取 vthread │ + │ ┌────▼────┐ │ + │ │ running │ │ + │ └────┬────┘ │ + │ │ │ + │ ┌────┴────┐ │ + │ │ │ │ + │ 计算指令 控制流/生命周期 │ + │ │ │ │ + │ ▼ ▼ │ + │ PUSH VM 直接处理 │ + │ 到op-plat (call/if/ │ + │ │ return/del) │ + │ ▼ │ │ + │ BLPOP │ │ + │ done: │ │ + │ │ │ │ + │ ▼ │ │ + │ PC++ │ │ + │ │ │ │ + └────┴─────────┘ │ + │ │ + │ vthread 全部执行完 │ + ┌────▼────┐ │ + │ wait │────────────────────┘ + │ (BLPOP │ 新 vthread 创建 + │ notify) │ + └─────────┘ +``` + +### 3.2 VM 执行循环 + +``` +1. 扫描 /vthread/ 子树,找到 status=init 的 vthread + → 有: 拾取执行 + → 无: BLPOP notify:vm (等待新 vthread 创建) + +2. 对当前 vthread: + a) GET /vthread/ → {pc: "[n,0]", status: "running", ...} + b) GET /vthread//[n,0] → opcode + +3. 指令分发: + + [张量计算指令] add, matmul, relu, mul, exp, ... + → 构造 op-plat 任务包 (§4.2) + → PUSH cmd:op-cuda: + → SET /vthread/ = {pc: "[n,0]", status: "wait", ...} + → BLPOP done: + → 醒来后检查结果 + → SET /vthread/ = {pc: "[n+1,0]", status: "running", ...} + + [控制流指令] + call (VM 翻译: 编译层→执行层, eager): + → 读取 [n,-1] func_name, [n,-2..] args, [n,1] return_slot + → 批读取 /op//func// 下所有指令 (一次 Redis MGET/Pipeline) + → 逐条翻译: 解析 dxlang → 形参替换为实参 → 展开为 [i,0],[i,-1]...[i,1] + → 批写入 /vthread//[n,0]/ 子栈 + → SET /vthread/ = {pc: "[n,0]/[0,0]", status: "running", ...} (进入子栈) + + return: + → 将返回值写入父栈 CALL 指令的写参数槽位 + → DELETE 当前子栈 KV 路径 + → SET /vthread/ = {pc: "[n+1,0]", status: "running", ...} (父栈下一条) + + if: + → 读取条件: [n,-1] = cond, [n,0]/true/0 和 [n,0]/false/0 + → 根据 cond 结果设置 PC + + for: + → 初始化迭代器, 循环 body 直到耗尽 + + [生命周期指令] + newtensor: + → PUSH cmd:heap: + → 等待完成 (或同步) + + deltensor: + → PUSH cmd:heap: + + var: + → SET /vthread// = (基础类型直存 Redis) + +4. 
vthread 执行完毕 (pc 超出 seq 范围): + → SET /vthread/ = {pc: "...", status: "done", ...} + → 清理 /vthread// 子树 (GC) + → 回到步骤 1 +``` + +### 3.3 多 VM 并行 + +多个 VM 进程可并行运行,每个 VM 独立拾取 status=init 的 vthread: + +``` +VM-1 拾取 /vthread/1/ → 执行 +VM-2 拾取 /vthread/2/ → 执行 +VM-3 空闲 → BLPOP notify:vm + +避免竞争: 使用 Redis 原子操作标记 vthread 已被拾取 + WATCH /vthread/ + GET → {pc: "...", status: "init"} + MULTI + SET /vthread/ = {pc: "...", status: "running"} + EXEC + → 只有一个 VM 能成功 +``` + +TODO: 多 VM 间的负载均衡策略,待开发时确定。 + +### 3.4 性能关键:批读取与本地缓存 + +VM 每次从 Redis GET 一个 key 都有网络往返延迟 (~100μs)。 +相比之下,VM 读自己进程内存仅 ~ns 级。减少 Redis 访问是性能关键。 + +**当前设计的 Redis 访问次数(逐条执行 20 条指令的 func):** + +``` +每步操作 Redis 访问次数 +────────────────────────── ────────────── +CALL 翻译: 读取 func 源码 1 次 (MGET 批量取所有指令) +CALL 翻译: 写入子栈 1 次 (Pipeline 批量写) +逐条执行: GET pc + GET opcode 20 × 2 = 40 次 +发射到 op-plat: PUSH 20 次 +等待完成: BLPOP 20 次 +更新 PC: SET 20 次 +总计 ~102 次 Redis 往返 +``` + +**优化方向:** + +**(a) 批读取 — 减少执行循环的 Redis 次数:** + +``` +当前 (逐条): + GET /vthread/1 → pc=[0,0] ← 1 次 + GET /vthread/1/[0,0] → opcode ← 1 次 + SET /vthread/1 = {pc:[0,0], status:"wait"} + ...等 op-plat... + GET /vthread/1 → pc=[1,0] ← 1 次 + GET /vthread/1/[1,0] → opcode ← 1 次 + +优化后 (VM 本地预取一批指令): + MGET /vthread/1, /vthread/1/[0,0], /vthread/1/[0,-1], /vthread/1/[0,1], + /vthread/1/[1,0], /vthread/1/[1,-1], /vthread/1/[1,1], ... 
+ /vthread/1/[19,0], /vthread/1/[19,-1], /vthread/1/[19,1] + → 1 次往返,拿到整个子栈的所有指令,缓存到 VM 本地内存 + → 后续逐条执行时从本地缓存读取,零 Redis 访问 + → 每批只改 PC 和 status (SET /vthread/) +``` + +**(b) MGET 批量取 tensor 元信息:** + +``` +发射指令前,需要读取参数指向的 tensor 元信息 (dtype, shape, shm): + GET /vthread/1/a ← 1 次 + GET /vthread/1/b ← 1 次 + GET /models/W ← 1 次 + + → MGET /vthread/1/a, /vthread/1/b, /models/W ← 1 次 +``` + +**(c) 算子融合 — 编译器在 /src/func/→/op//func/ 等价替换:** + +算子融合是编译器的职责。编译器读取 `/src/func/` 的源码 + +`/op//list` 的融合算子注册信息,将匹配的连续指令替换为等价的融合指令, +写入 `/op//func/`。VM 和 op-plat 对融合无感知。 + +``` +融合前 (/src/func/forward 源码): + /src/func/forward/0 = matmul(A, B) -> ./mm + /src/func/forward/1 = add(./mm, b) -> ./mm + /src/func/forward/2 = relu(./mm) -> ./out + +编译器读取 /op/op-cuda/list → 含 fused_matmul_add_relu +编译器读取 /op/op-cuda/fused_matmul_add_relu.replaces → ["matmul","add","relu"] +匹配成功 → 等价替换 → 写入编译层: + +融合后 (/op/op-cuda/func/forward): + /op/op-cuda/func/forward/0 = fused_matmul_add_relu(A, B, b) -> ./out + (原来 3 条指令 → 1 条) +``` + +VM 执行时从 `/op/op-cuda/func/forward` 读取,透明受益: + +``` +融合前: VM PUSH matmul → 等 → PUSH add → 等 → PUSH relu → 等 + → 3 次 VM↔Redis 往返 + 3 次 VM↔op-plat 往返 + +融合后: VM PUSH fused_matmul_add_relu → 等 + → 1 次调度, 减少 2/3 的 VM 和 Redis 开销 + → GPU 侧也减少 kernel launch 次数 (3→1) +``` + +TODO: 编译器融合规则的实现,待编译器设计阶段确定。 + +**(d) Tensor 并行拆分 — 编译器在 /src/func/→/op//func/ 等价拆分:** + +当单个算子涉及的 tensor 过大、超出单卡显存或需要跨卡并行时,编译器将其拆分为 +多个等价的子算子,标注目标设备,写入编译层。 + +``` +拆分前 (/src/func/forward 源码): + /src/func/forward/0 = matmul(A, W) -> ./out + +编译器分析 → 发现 A 跨 2 张卡 → 拆分替换 → 写入编译层: + +拆分后 (/op/op-cuda/func/forward): + /op/op-cuda/func/forward/0 = slice(A, 0, 512) -> ./A_shard0 + /op/op-cuda/func/forward/1 = slice(A, 512, 1024) -> ./A_shard1 + /op/op-cuda/func/forward/2 = matmul(./A_shard0, W) -> ./out0 @gpu0 + /op/op-cuda/func/forward/3 = matmul(./A_shard1, W) -> ./out1 @gpu1 + /op/op-cuda/func/forward/4 = concat(./out0, ./out1) -> ./out +``` + +VM 和 op-plat 对拆分无感知——VM 照常执行编译层的指令序列, +根据 `@gpu0`/`@gpu1` 标注路由到对应实例。 + 
+融合与拆分是编译器在 `/op//func/` 层的一对互补操作: + +``` +融合: 多条 → 一条 (减少调度次数, 适用于 GPU 内) +拆分: 一条 → 多条 (增加并行度, 适用于跨 GPU) +``` + +TODO: 编译器拆分策略和设备拓扑感知,待编译器设计阶段确定。 + +**(e) VM 本地子栈缓存:** + +``` +VM 在执行一个 vthread 的某个栈帧时: + 1. CALL 翻译完成后,整个子栈的 [i,j] key 一次性 MGET 到本地 map + 2. 执行期间,opcode 和参数查找全走本地内存 (O(1) hash) + 3. 仅以下操作写 Redis: + - 更新 PC (SET /vthread/) + - PUSH 到 op-plat 命令队列 + - BLPOP 完成通知 + 4. RETURN 时 DELETE 子栈 (1 次批量操作) +``` + +TODO: 本地缓存的失效策略(当外部 WATCH 修改了 vthread 状态时),待开发时确定。 + +## 4. op-plat 协议 + +### 4.1 op-plat 生命周期 + +``` +1. 启动 → 注册到 /sys/op-plat/: + { "type": "op-cuda", "device": "gpu0", "status": "idle", + "capabilities": ["add", "matmul", "relu", "softmax", ...] } + +2. 进入消费循环: + while true: + RPOP cmd:op-cuda:0 (或 BLPOP 阻塞等待) + 解析指令 + 执行 GPU kernel + LPUSH done: 完成通知 + +3. 退出 → DELETE /sys/op-plat/ +``` + +### 4.2 指令格式 + +VM PUSH 到 `cmd:op-cuda:` 的指令: + +```json +{ + "vtid": "1", + "pc": "[3,0]", + "opcode": "matmul", + "inputs": [ + { + "key": "/vthread/1/a", + "dtype": "f32", + "shape": [1024, 512], + "address": { + "node": "n1", + "device": "gpu0", + "type": "shm", + "shm_name": "/deepx_t_abc123", + "byte_size": 2097152 + } + } + ], + "outputs": [ + { + "key": "/vthread/1/c", + "dtype": "f32", + "shape": [1024, 256], + "address": { + "node": "n1", + "device": "gpu0", + "type": "shm", + "shm_name": "/deepx_t_def456", + "byte_size": 1048576 + } + } + ], + "params": { + "transpose_a": false, + "transpose_b": true + } +} +``` + +TODO: 是否直接发送 GPU 指针而非 shm_name,待开发时根据性能需求确定。 + +### 4.3 完成通知格式 + +op-plat 计算完成后 LPUSH 到 `done:`: + +```json +{ + "pc": "[3,0]", + "status": "ok", + "outputs_updated": [ + {"key": "/vthread/1/c", "new_shape": [1024, 256]} + ] +} +``` + +错误情况: +```json +{ + "pc": "[3,0]", + "status": "error", + "error": { + "code": "GPU_OOM", + "message": "out of memory: requested 2GB, available 1.5GB" + } +} +``` + +### 4.4 批量发射 + +VM 可将无依赖的多条指令打包为一批,一次 PUSH: + +```json +{ + "batch": [ + {"pc": "[3,0]", "opcode": "add", ...}, + {"pc": "[4,0]", "opcode": "mul", 
...}, + {"pc": "[5,0]", "opcode": "relu", ...} + ] +} +``` + +op-plat 可以并行执行 batch 内的指令(不同 CUDA stream 或 Metal command queue)。 + +TODO: 批量发射的依赖分析由编译器完成还是 VM 运行时分析,待开发时确定。当前阶段 VM 逐条发送。 + +## 5. heap-plat 协议 + +### 5.1 heap-plat 生命周期 + +``` +1. 启动 → 注册到 /sys/heap-plat/: + { "type": "heap-cuda", "device": "gpu0", "status": "idle" } + +2. 进入消费循环: + while true: + RPOP cmd:heap:0 (或 BLPOP) + 解析指令 + 执行 shm 分配/释放/克隆 + 回复到 done: 或直接 SET 元信息到堆路径 + +3. 退出 → DELETE /sys/heap-plat/ +``` + +### 5.2 指令格式 + +**创建 tensor:** +```json +{ + "vtid": "1", + "pc": "[0,0]", + "op": "newtensor", + "key": "/models/weights", + "dtype": "f32", + "shape": [1024, 512], + "device": "gpu0" +} +``` + +heap-plat 执行: +1. 分配 shm + GPU buffer +2. SET `/models/weights` = 完整元信息(含 shm_name、byte_size、device) + +**删除 tensor:** +```json +{ + "vtid": "1", + "pc": "[5,0]", + "op": "deltensor", + "key": "/models/weights" +} +``` + +**克隆 tensor:** +```json +{ + "vtid": "1", + "pc": "[0,0]", + "op": "clonetensor", + "src": "/models/weights", + "dst": "/models/weights_gpu1", + "device": "gpu1" +} +``` + +TODO: 引用计数机制 (refcount) 是否在 heap-plat 侧实现,还是由 VM 统一管理,待开发时确定。 + +## 6. Tensor 元信息格式 + +堆变量和 vthread 命名槽位的 value 的 tensor 元信息: + +```json +{ + "dtype": "f32", + "shape": [1024, 512], + "byte_size": 2097152, + "device": "gpu0", + "address": { + "node": "n1", + "type": "shm", + "shm_name": "/deepx_t_abc123" + }, + "ctime": 1714000000, + "version": 5 +} +``` + +对于 vthread 命名槽位中的基础类型,value 直接是字面量: +``` +/vthread/1/a = 1 (int) +/vthread/1/b = 3.14 (float) +/vthread/1/flag = true (bool) +``` + +## 7. pysdk 接口设计 + +### 7.1 当前模式:直接发送 IR 序列 + +pysdk 直接将 deepxIR 序列写入 `/src/func/` 和 `/vthread/`: + +```python +# front/py/deepx/nn/functional/ 下的代码模式 (当前) +kv = KVSpace() + +# 1. 定义 func (写入源码层) +kv.set("/src/func/forward", "(forward(A, B) -> (C))") +kv.set("/src/func/forward/0", "matmul(A, B) -> ./mm") +kv.set("/src/func/forward/1", "relu(./mm) -> C") + +# 2. 
创建 vthread +vtid = kv.alloc_vtid() +kv.set(f"/vthread/{vtid}", {"pc": "[0,0]", "status": "init"}) + +# 3. 写入入口指令 (call main) +kv.set(f"/vthread/{vtid}/[0,0]", "call") +kv.set(f"/vthread/{vtid}/[0,-1]", "forward") +kv.set(f"/vthread/{vtid}/[0,-2]", "/models/A") +kv.set(f"/vthread/{vtid}/[0,-3]", "/models/B") +kv.set(f"/vthread/{vtid}/[0,1]", "./C") + +# 4. 通知 VM +kv.push("notify:vm", {"event": "new_vthread", "vtid": vtid}) +``` + +### 7.2 未来模式:经编译器 + +``` +pysdk 写入 /src/func/ (源码层) + +编译器: + 1. 读取 /src/func/, /op//list, /op// + 2. 融合: 扫描指令序列, 匹配 fused 算子的 replaces 模式, 等价替换 + 3. 拆分: 对比 max_shape, 超出则 slice+子算子+concat + 4. 插入 deltensor 指令 + 5. 写入 /op//func/ (编译层) + +VM CALL 时读取 /op//func/, 不再读 /src/func/ +``` + +TODO: 编译器的输入格式 (Python AST? dxlang DSL? deepxIR?) 和输出约定,待设计阶段确定。 + +### 7.3 Func 缓存 + +pysdk 维护本地 func 缓存,避免重复发送: + +```python +class FuncCache: + def set_func(self, name, signature, instructions): + if self._cache.get(name) == hash(instructions): + return # 未变化,跳过 + kv.set(f"/src/func/{name}", signature) + for i, inst in enumerate(instructions): + kv.set(f"/src/func/{name}/{i}", inst) + self._cache[name] = hash(instructions) +``` + +## 8. 启动流程 + +### 8.1 集群启动顺序 + +``` +1. Redis 启动 (已有) +2. heap-plat 启动 × N (每个 GPU 一个) + → 注册到 /sys/heap-plat/ +3. op-plat 启动 × N (每个 GPU 一个) + → 注册到 /sys/op-plat/ + → 开始 BLPOP cmd:op-cuda: +4. VM 启动 × M (可多个) + → 注册到 /sys/vm/ + → 扫描 /vthread/ 或 BLPOP notify:vm +5. pysdk 启动 + → 发送 func 定义到 /src/func/ + → 编译器 (可选) 编译到 /op//func/ + → 创建 vthread + → PUSH notify:vm +``` + +### 8.2 Vthread 创建流程 + +``` +pysdk: + 1. 分配 vtid (Redis INCR /sys/vtid_counter) + 2. SET /vthread/ = {"pc": "[0,0]", "status": "init"} + 3. 写入入口指令序列 + 4. PUSH notify:vm {"event": "new_vthread", "vtid": ""} + +VM: + 1. BLPOP notify:vm → 收到事件 + 2. WATCH /vthread/ + 3. GET → {pc: "...", status: "init"} + 4. MULTI → SET /vthread/ = {pc: "...", status: "running"} → EXEC + 5. 开始执行循环 (§3.2) +``` + +## 9. 
错误处理 + +### 9.1 错误分类 + +| 错误类型 | 示例 | 处理方式 | +|---------|------|---------| +| op-plat 执行失败 | GPU OOM, 数值溢出 | status→error, error 信息写入 /vthread/ 的 value | +| heap-plat 执行失败 | shm 分配失败, 磁盘满 | status→error | +| 超时 | op-plat 无响应 | VM 超时后 status→error | +| 锁冲突 | LOCK 超时 | 调用者决定重试或报错 | + +### 9.2 错误传播 + +``` +op-plat 返回 error: + 1. VM BLPOP 收到 error 完成通知 + 2. SET /vthread/ = {pc: "[n,0]", status: "error", error: {...}} + 3. VM 释放该 vthread, 回到 idle 状态 + +pysdk 可通过 GET /vthread/ 检查 status 字段感知错误 +TODO: 错误恢复策略 (重试 / 跳过 / 降级) 待开发时确定 +``` + +## 10. 监控与运维 + +### 10.1 系统状态查询 + +``` +GET /sys/op-plat/0 → op-plat 状态和负载 +GET /sys/heap-plat/0 → heap-plat 状态 +KEYS /vthread/* → 所有 vthread 列表 +GET /vthread/1 → 特定 vthread 的状态 (含 pc 和 status 字段) +``` + +### 10.2 liveness 检测 + +TODO: 各进程的心跳机制和故障检测,待开发时确定。 + +## 11. 目录结构映射 + +``` +executor/ +├── vm/ VM 进程 (新增) +│ └── src/main.mm +├── op-metal/ op-plat (Metal 实现, 已有) +├── op-cuda/ op-plat (CUDA 实现, 待开发) +├── heap-metal/ heap-plat (Metal 实现, 已有) +├── heap-cuda/ heap-plat (CUDA 实现, 待开发) +└── common-metal/ 共享库 (已有) + +front/ +└── py/deepx/ pysdk + └── nn/functional/ deepxIR 算子序列发送 + +doc/ +└── metaproc/ 元程设计文档 + ├── spec-v1.md 元程规范 v1 (抽象模型) + ├── spec-control-flow-v1.md 控制流与前端代码生成 v1 整合设计 (6 套方案) + ├── deepx-design.md 本文件 (DeepX 实现设计) + └── CONVERSATION.md 设计对话记录 +``` + +## 12. 
待开发时确定的问题 + +| 问题 | 当前状态 | +|------|---------| +| 批量发射的依赖分析 (编译器 vs VM 运行时) | 暂定 VM 逐条发送 | +| 指令中发送物理地址 vs key 引用 | 暂定发送完整 tensor 元信息 | +| 多 VM 负载均衡策略 | 暂定 SETNX 竞争拾取 | +| 引用计数 (heap-plat vs VM) | 暂未实现 | +| 编译器输入格式与输出约定 | pysdk 直接发送 IR 可工作,编译器为后续优化 | +| 多卡/跨节点 tensor 迁移 | 暂不实现,当单机处理 | +| 动态图支持 | 暂不支持 | +| 进程心跳与故障检测 | 暂不实现 | +| 错误恢复策略 | 暂定 status→error,等待人工介入 | diff --git a/doc/metaproc/deepx-speed-strategy.md b/doc/metaproc/deepx-speed-strategy.md new file mode 100644 index 00000000..fc6f3489 --- /dev/null +++ b/doc/metaproc/deepx-speed-strategy.md @@ -0,0 +1,320 @@ +# deepx 速度赶超策略 + +> 分析 deepx/dxlang 如何在训练和推理速度上赶超 PyTorch + vLLM,覆盖单卡→多节点→大规模场景。 + +## 一、速度差距的根因分析 + +### 1.1 当前 deepx 的执行模型 + +``` +dxlang 指令 → VM解释(每指令6次Redis往返) → op-plat(单算子dispatch) → GPU +``` + +**训练场景的具体差距**,以 Llama-70B 一次 forward+backward 为例: + +| 维度 | PyTorch FSDP (8×H100) | deepx 当前架构 | +|------|----------------------|---------------| +| 每层 matmul 执行方式 | 直接 launch cuBLAS kernel | VM 6次Redis → RPUSH cmd → op-plat BLPOP → BLPOP done | +| 算子间数据传递 | GPU 显存直接传递 | 经过 Redis key + JSON marshal/unmarshal | +| 梯度同步 | NCCL all-reduce (900 GB/s NVLink) | 不存在 | +| 参数分片 | FSDP flatten + reduce-scatter | 不存在 | +| MFU (硬件利用率) | ~70% | <1% | +| 1T token 训练耗时 | ~3 天 | 理论不可完成(每op ~1ms Redis开销) | + +**核心矛盾**:当前 deepx 每条 dxlang 指令都走 Redis 控制面,而 PyTorch 的训练图一旦编译完成,就完全在 GPU 上执行。Redis 的单次往返延迟 (~0.15ms) 对于 GPU 计算(一次 matmul 可能 <0.01ms 完成)是天文数字。 + +### 1.2 推理场景的差距 + +| 维度 | vLLM (8×H100) | deepx 当前架构 | +|------|--------------|---------------| +| 单请求延迟 | 预填充+解码 一体化 | VM 解释 + op dispatch,逐 token 串行 | +| KV-cache 管理 | PagedAttention (块级虚拟内存) | 不存在 | +| 批量调度 | Continuous batching (token级) | vthread 逐个执行 | +| 吞吐 (Llama-70B) | ~50K tok/s | <100 tok/s (粗略估) | + +--- + +## 二、五个关键架构缺口(及填平方案) + +### 缺口 1:从指令解释 → 图编译 + +**现状**:VM 逐条 `Execute()` 指令,每条 6 次 Redis 往返。 + +**方案**:在 dxlang 函数粒度做 AOT/JIT 编译。 + +``` +当前: + dxlang func → VM逐条解释 → 每指令 Redis dispatch → op-plat 单算子 → GPU + +改造后: + dxlang func → dxlang Compiler → Fused GPU 
Kernel → 一次 GPU launch + ↑ ↑ + Redis仅存源码 zero Redis in hot path +``` + +具体实现: + +``` +# 编译层新增 (缓存编译产物,避免重复编译) +/op//kernel/ = base64(compiled_kernel_binary) +/op//kernel//meta = {"grid": [N,M], "block": [256], "smem": 48KB} + +# VM CALL 时: +1. 检查 /op/cuda/kernel/ 是否存在 +2. 存在 → 直接 launch kernel (一次 GPU call) +3. 不存在 → 编译: dxlang func → 融合 CUDA kernel → 存入 Redis → launch +``` + +**核心收益**:训练时每条指令从 ~1ms 降到 ~10μs(仅 GPU launch 开销),提升 **100×**。 + +### 缺口 2:无 GPU-to-GPU 通信 + +**现状**:GPU 间数据传输完全不存在。多卡 = 各算各的。 + +**方案**:新增 `comm-plat`(通信平面),封装 NCCL/RDMA。 + +``` +op-plat 新增通信原语: + allreduce(tensor_list, "sum") → reduced_tensors + allgather(tensor) → gathered_tensors + reduce_scatter(tensor) → scattered_tensors + broadcast(tensor, src_rank) → replicated_tensor + send(tensor, dst_node) / recv(src_node) → tensor +``` + +Redis 中的表示(仅存通信拓扑元信息,不存数据): + +``` +/sys/topology/nodes = ["n0:8×H100", "n1:8×H100", ...] +/sys/topology/links = {"n0:n1": "IB_NDR400", ...} +/sys/topology/n0/gpus = {"gpu0": "H100_80GB", ...} +``` + +数据本身通过 NVLink/InfiniBand 直通,Redis 只记录"谁和谁在通信"。 + +### 缺口 3:无分布式并行策略 + +**方案**:利用 dxlang 的声明式语义,在编译阶段自动插入并行策略。 + +``` +# 用户在 dxlang 中写单卡逻辑: +def train_step(x:tensor, y:tensor) -> (loss:tensor): + forward(x) -> ./h + compute_loss(./h, y) -> ./loss + backward(./loss) -> ./grad + +# 编译器根据拓扑自动展开为 (以 2×4 TP+DP 为例): +# Node 0 GPU 0 (TP rank 0, DP rank 0): +def train_step_shard(x_shard_0:tensor, y:tensor) -> (loss:tensor): + col_parallel_linear(x_shard_0, W_col_0) -> ./h0 + allreduce(./h0) -> ./h + row_parallel_linear(./h, W_row_0) -> ./out0 + compute_loss(./out0, y) -> ./loss0 + backward(./loss0) -> ./grad0 + allreduce(./grad0) -> ./grad +``` + +Redis 核心价值:**所有参数在 KV 空间有全局路径**,编译器可以看到: + +- 哪些 tensor 需要分片(shape 大 → 自动 TP) +- 哪些参数可以复制(小参数 → DP 广播) +- 数据怎么路由(`/data/shard_0` → node 0, `/data/shard_1` → node 1) + +vs PyTorch:PyTorch 的 FSDP/DeepSpeed 配置分散在 rank 配置中,deepx 做到**配置集中 + 编译器自动决策**。 + +### 缺口 4:推理侧无 KV-cache 管理 + +**方案**:heap-plat 升级为 PagedAttention 式块管理器。 + +``` 
+heap-plat 新增: + alloc_block(block_size) → block_id # 分配一个 KV-cache 块 + free_block(block_id) # 释放 + copy_block(src_id, dst_id) # CoW 前缀共享 + defrag() # 整理碎片 +``` + +Redis 中的映射(轻量元信息): + +``` +/heap/kv_cache/blocks/free = [0, 1, 2, ...] # List: 空闲块 +/heap/kv_cache/blocks/used = ["req_42:0→3", ...] # 已分配映射 +/heap/kv_cache/block_size = 16 # 每块 token 数 +/heap/kv_cache/total_blocks = 4096 +``` + +vthread 天然映射到 continuous batching: + +``` +# VM scheduler 同时推进多个 vthread: +/vthread/req_1/ pc="[5,0]" status="wait_decode" # 等待采样 +/vthread/req_2/ pc="[3,0]" status="wait_decode" # 同上 +/vthread/req_3/ pc="[7,0]" status="running" # 正在执行 attention + +# 一个 batch 中混合 prefill + decode,动态加入/退出 +``` + +### 缺口 5:无混合精度 / 量化基础设施 + +**方案**:在 dxlang 类型系统层统一处理。 + +``` +# 类型系统支持精度标注: +def forward(x:tensor, w:tensor) -> (y:tensor): + matmul(x, w) -> ./y + +# 编译器自动插入 cast (BF16 训练 / FP8 推理): +def forward(x:tensor, w:tensor) -> (y:tensor): + cast(x, master_fp32) -> ./x32 # 编译器插入 + cast(w, master_fp32) -> ./w32 + matmul(./x32, ./w32) -> ./y32 + cast(./y32, bf16) -> ./y +``` + +量化权重作为特殊 heap 变量,Redis 记录量化元信息: + +``` +/models/llama70b/layer0/q_weight = { + "dtype": "int4", + "shape": [8192, 28672], + "zero_point": 8, + "scale": 0.037, + "group_size": 128, + "address": {"type": "shm", "shm_name": "/deepx_q_abc"} +} +``` + +--- + +## 三、多节点大规模:deepx 的结构性优势 + +PyTorch 的多节点方案本质是 **MPI 思想**:rank 编号、process group、显式通信。这在小规模(<100 GPU)效果好,但在以下场景吃力: + +### 场景 A:异构集群(不同 GPU 型号混部 + CPU 节点) + +PyTorch 的 FSDP 假设同构 GPU。deepx 的路径寻址天然支持异构: + +``` +# H100 节点上的 tensor 分片 +/data/shard_h100_0 → node:0, device:gpu0 (H100, 80GB) +# A100 节点上的 tensor 分片 +/data/shard_a100_0 → node:1, device:gpu0 (A100, 40GB) + +# 编译器根据 device 能力自动调整分片大小: +# H100 → 分片更大,A100 → 分片更小 +``` + +Redis 集中了"全局物理视图",编译器做分片决策时无需 rank 协商。 + +### 场景 B:弹性训练(节点动态加入/退出) + +``` +# 节点加入: +/sys/topology/nodes ← 追加 "n3:8×H100" +# 编译器重编译,redistribute shards +# 无需手动改 rank 配置 + +# 节点故障: +/sys/topology/nodes ← 标记 n1 为 "dead" +# 自动触发 checkpoint 恢复 + 重分片 +``` + +### 
场景 C:超大规模 MoE(专家路由) + +``` +# MoE router 输出 (存在 Redis 中): +/vthread/req_42/route = {"token_0": "expert_3", "token_1": "expert_7", ...} + +# op-plat 消费路由信息,专家分布在不同节点: +/expert/3/forward → node:5, gpu:2 +/expert/7/forward → node:2, gpu:0 + +# 通信模式: all-to-all (scatter tokens, gather results) +# Redis 只记录路由决策,token 数据走 RDMA +``` + +PyTorch 实现 MoE 需要手动管理专家映射和 all-to-all 通信。deepx 用路径空间做声明式路由。 + +--- + +## 四、量化对比:能否赶超? + +### 训练速度 + +| 阶段 | PyTorch FSDP | deepx 改造后 | 关键 | +|------|-------------|-------------|------| +| 单卡单层 matmul | cuBLAS 直调 | 融合 CUDA kernel(同底层库) | 持平 | +| 多卡梯度同步 | NCCL all-reduce | comm-plat NCCL(同底层库) | 持平 | +| 图编译优化 | torch.compile (Inductor) | dxlang compiler(融合+内存规划) | 可达 90%+ | +| 异构调度 | 手动配置 | Redis 集中 + 自动分片 | **领先** | +| 弹性容错 | Elastic 启动器 | Redis 拓扑热更新 | **领先** | +| MFU 上限 | ~70% | ~65-70% | 接近持平 | + +**结论**:训练速度,单卡硬上限相同(都是 cuBLAS/NCCL),**deepx 可以在 <16 节点内追平,>64 节点的异构/弹性场景可能反超**。 + +### 推理速度 + +| 阶段 | vLLM | deepx 改造后 | 关键 | +|------|------|-------------|------| +| Attention kernel | FlashAttention-3 | 复用(C++ FFI 对接) | 持平 | +| KV-cache 管理 | PagedAttention | heap-plat 块管理 | 可达同等 | +| Continuous batching | 内置调度器 | VM multi-vthread 调度 | 设计更灵活 | +| Prefix caching | 自动 | Redis 显式全局共享 | **领先** | +| 多模型混合服务 | 需多实例 | 共享 KV 空间路由 | **领先** | +| 吞吐上限 | ~50K tok/s | ~45-50K tok/s | 接近持平 | + +--- + +## 五、路线图 + +``` +Phase 1: 单卡图编译 (当前 → 3个月) + ├─ dxlang func → fused Metal/CUDA kernel 编译器 + ├─ 保留 Redis 用于源码存储和 kernel cache + └─ 目标: 单卡训练吞吐 ≥ PyTorch eager 的 80% + +Phase 2: 单机多卡 (3个月 → 6个月) + ├─ comm-plat: NCCL 封装为 opcode (allreduce / broadcast / reduce_scatter / allgather) + ├─ heap-plat: FSDP-style shard/reconstruct (参数分片 + 收集) + ├─ dxlang compiler: 自动插入 TP/DP 切分 pass + └─ 目标: 8×H100 训练 MFU ≥ 60% + +Phase 3: 多节点 (6个月 → 12个月) + ├─ RDMA backend for comm-plat (InfiniBand / RoCE) + ├─ 分层 all-reduce (node内 NVLink → node间 IB) + ├─ Redis 拓扑管理 + 弹性容错 (热加入/热退出/自动恢复) + └─ 目标: 64节点线性加速比 ≥ 85% + +Phase 4: 推理优化 (并行 Phase 2-3) + ├─ heap-plat PagedAttention 
(块分配/释放/CoW/碎片整理) + ├─ VM continuous batching scheduler (多 vthread 并发推进) + ├─ 量化路径 (AWQ/GPTQ → dxlang 类型系统 + 编译器自动 cast) + └─ 目标: Llama-70B ≥ 40K tok/s +``` + +--- + +## 六、核心结论 + +**deepx 赶超 PyTorch/vLLM 的路径不是"做得更快",而是"做得不同"**: + +| 维度 | PyTorch/vLLM | deepx 改造后 | +|------|-------------|-------------| +| 编程模型 | 命令式 Python | 声明式 dxlang + 编译器 | +| 分布式协调 | 去中心化 (rank/group) | **集中式 KV 空间** + 去中心化执行 | +| 优化方式 | 手动配置 (FSDP/TP/PP) | **编译器自动决策** | +| 异构支持 | 勉强(同构假设,异构需手动调参) | **原生(路径寻址,按 device 能力自适应)** | +| 弹性 | 重启式(node failure → 全量重启) | **热更新(拓扑 KV 变更,增量恢复)** | +| 硬上限 | cuBLAS/NCCL 物理极限 | 同左(物理无法超越) | + +**一句话**:在均质纯 GPU 集群上 deepx 只能追平,不可能超越(共享相同的数学和硬件上限)。但在 **异构混合集群、弹性训练、超大规模 MoE、多模型混合服务** 这些 PyTorch 架构不擅长的场景中,deepx 的集中式 KV 空间 + 声明式 dxlang 有结构性优势。 + +--- + +## 参考 + +- 当前 VM 吞吐分析: `.claude/skills/debug-kvspace.md`(每指令 6 次 Redis 往返,~800 native inst/s) +- Redis Key 布局: `doc/metaproc/redis-keys.md` +- dxlang 控制流设计: `doc/dxlang/spec-control-flow-v1.md` +- heap-plat 设计: `doc/heap-plat/README.md` +- op-plat 设计: `doc/op-plat/README.md` diff --git a/doc/metaproc/dev-heap-plat.md b/doc/metaproc/dev-heap-plat.md new file mode 100644 index 00000000..b3f73107 --- /dev/null +++ b/doc/metaproc/dev-heap-plat.md @@ -0,0 +1,203 @@ +# heap-plat 开发指南 + +> 开发 heap-metal (macOS 统一内存)。heap-cuda / heap-cpu 暂不开发。 +> Redis 连通测试后,最先开发此组件——先有内存,才能计算。 + +## 1. 角色与职责 + +**heap-\* 的进程维持着 deepx 元程的堆在 \* 设备平台的高可用。** + +heap-plat 管理 tensor 对象的生命周期:创建、删除、克隆。 + +| 能力 | 说明 | +|------|------| +| Tensor 创建 (newtensor) | 分配 POSIX shm + GPU buffer,写入元信息到 Redis | +| Tensor 删除 (deltensor) | 释放 shm,删除 Redis key | +| Tensor 克隆 (clonetensor) | 在指定设备上创建 tensor 副本 | +| 进程注册 | 启动时向 `/sys/heap-plat/` 注册 | + +## 2. 当前状态 + +``` +executor/heap-metal/ +├── src/ +│ ├── main.mm 入口 (已连接 Redis, 有主循环) +│ └── lifecycle/ +│ └── lifecycle.h 命令处理 (创建/删除/查询 tensor) + +依赖 common-metal: + src/shm_tensor.h/.mm POSIX shm 创建/打开/关闭 + src/registry.h 注册表抽象基类 + +代码量: 426 行 +``` + +## 3. 
通信模型 + +``` + VM heap-metal + ── ───────────── + PUSH cmd:heap-metal:0 ──→ RPOP/BLPOP 消费 + │ + 分配/释放 shm + 写入 Redis 元信息 + │ + BLPOP done: ←────────── LPUSH 完成事件 +``` + +## 4. 待开发任务 + +### 任务 H1: Redis 命令消费循环 (main.mm) + +```cpp +// 伪代码 +while (true) { + auto cmd = redis.blpop("cmd:heap-metal:0", timeout_sec=5); + if (!cmd) continue; // 超时,下一轮 + + auto req = json::parse(cmd); + string vtid = req["vtid"]; + string pc = req["pc"]; + string op = req["op"]; + + if (op == "newtensor") { + handle_newtensor(req); + } else if (op == "deltensor") { + handle_deltensor(req); + } else if (op == "clonetensor") { + handle_clonetensor(req); + } + + // 回复完成 + json done = {{"pc", pc}, {"status", "ok"}}; + redis.lpush("done:" + vtid, done.dump()); +} +``` + +**依赖:** hiredis (Redis C 客户端)。Mac 上 `brew install hiredis`。 + +### 任务 H2: newtensor 实现 + +``` +输入: {vtid, pc, op:"newtensor", key:"/models/weights", dtype:"f32", shape:[1024,512], device:"gpu0"} + +处理: + 1. 计算 byte_size = element_count(shape) × dtype_size(dtype) + 例如: 1024×512×4 = 2,097,152 bytes (f32=4bytes) + 2. 生成 shm_name = "/deepx_t_" + random_hex(8) + 3. 调用 shm_tensor_create(shm_name, byte_size) → 分配 POSIX shm + 4. 构造 tensor 元信息 → SET /models/weights + 5. 回复 LPUSH done: {"pc":"...", "status":"ok"} +``` + +**写入 Redis 的 tensor 元信息:** +```json +{ + "dtype": "f32", + "shape": [1024, 512], + "byte_size": 2097152, + "device": "gpu0", + "address": { + "node": "n1", + "type": "shm", + "shm_name": "/deepx_t_a1b2c3d4" + }, + "ctime": 1714000000, + "version": 1 +} +``` + +**Mac 统一内存说明:** +Mac Apple Silicon 上 CPU 和 GPU 共享物理内存。shm_open + mmap 返回的指针 +可以直接被 Metal 使用(通过 newBufferWithBytesNoCopy 包装)。 + +### 任务 H3: deltensor 实现 + +``` +输入: {vtid, pc, op:"deltensor", key:"/models/weights"} + +处理: + 1. GET /models/weights → 获取 shm_name + 2. shm_tensor_unlink(shm_name) → 释放 POSIX shm + 3. UNLINK /models/weights → 删除 Redis key + 4. 
回复 LPUSH done: {"pc":"...", "status":"ok"} +``` + +### 任务 H4: clonetensor 实现 + +``` +输入: {vtid, pc, op:"clonetensor", src:"/models/weights", dst:"/models/weights_gpu1", device:"gpu1"} + +处理: + 1. GET src → 获取源 tensor 元信息 + 2. 分配新 shm → shm_tensor_create(new_shm_name, src.byte_size) + 3. memcpy(src_ptr, dst_ptr, src.byte_size) // Mac 统一内存下直接 memcpy + 4. SET dst → 新 tensor 元信息 (含 new_shm_name, device) + 5. 回复 LPUSH done: {"pc":"...", "status":"ok"} +``` + +### 任务 H5: 进程注册 + +启动时注册到 Redis: + +``` +SET /sys/heap-plat/metal:0 = { + "program": "heap-metal", + "device": "gpu0", + "status": "running", + "pid": , + "started_at": +} +``` + +退出时清理: +``` +DELETE /sys/heap-plat/metal:0 +``` + +## 5. 编译与运行 (macOS) + +```bash +# 安装依赖 +brew install hiredis + +# 构建 +cd executor/heap-metal +mkdir -p build && cd build +cmake .. && make + +# 运行 +./heap_metal +``` + +## 6. 验证方法 + +```bash +# 终端1: 启动 heap-metal +./heap_metal + +# 终端2: 通过 redis-cli 发送测试命令 +redis-cli RPUSH cmd:heap-metal:0 '{"vtid":"test1","pc":"[0,0]","op":"newtensor","key":"/test/x","dtype":"f32","shape":[100,100],"device":"gpu0"}' + +# 查看结果 +redis-cli GET /test/x +# 应返回: {"dtype":"f32","shape":[100,100],"byte_size":40000,"device":"gpu0","address":{...}} + +# 检查完成通知 +redis-cli BLPOP done:test1 1 +# 应返回: {"pc":"[0,0]","status":"ok"} + +# 删除 +redis-cli RPUSH cmd:heap-metal:0 '{"vtid":"test1","pc":"[1,0]","op":"deltensor","key":"/test/x"}' +``` + +## 7. 开发量评估 + +| 任务 | 新增代码 | 难度 | +|------|---------|------| +| H1: 命令消费循环 | ~80 行 | 低 | +| H2: newtensor | ~80 行 | 低 | +| H3: deltensor | ~50 行 | 低 | +| H4: clonetensor | ~50 行 | 低 | +| H5: 进程注册 | ~40 行 | 低 | +| **合计** | **~300 行** | **低** | diff --git a/doc/metaproc/dev-op-plat.md b/doc/metaproc/dev-op-plat.md new file mode 100644 index 00000000..86ed09f2 --- /dev/null +++ b/doc/metaproc/dev-op-plat.md @@ -0,0 +1,300 @@ +# op-plat 开发指南 + +> 开发 op-metal (macOS Metal GPU)。op-cuda / op-cpu 暂不开发。 +> heap-plat 完成后,有内存分配能力了再开发此组件。 + +## 1. 
角色与职责 + +op-plat 是执行张量计算指令的被动进程。 + +| 能力 | 说明 | +|------|------| +| 指令消费 | 从 `cmd:op-metal:0` 队列 RPOP/BLPOP 消费指令 | +| 张量计算 | 执行 GPU kernel (elementwise, matmul, reduce, changeshape) | +| 完成通知 | 计算完成后 LPUSH 到 `done:` | +| 算子注册 | 启动时向 `/op/op-metal/` 注册支持的算子 | + +## 2. 当前状态 + +``` +executor/op-metal/ +├── src/ +│ ├── client/main.mm 入口 (占位, 待改造) +│ ├── deepx/ +│ │ ├── metal_context.hpp/mm Metal 设备管理 +│ │ ├── mem/mem_metal.hpp 内存管理 (shm 包装) +│ │ ├── dtype_metal.hpp 数据类型映射 +│ │ └── tensorfunc/ +│ │ ├── elementwise_miaobyte.hpp add/sub/mul/div +│ │ ├── elementwise_common.hpp relu/sigmoid/tanh/gelu +│ │ ├── init_miaobyte.hpp zeros/ones/arange +│ │ ├── metal_common.hpp Metal 工具函数 +│ │ └── tensorlife_miaobyte.hpp newtensor/deltensor/clone +│ └── test/shm/ 跨进程 shm 测试验证通过 + +代码量: 1,325 行 +``` + +## 3. 通信模型 + +``` + VM op-metal + ── ──────── + PUSH cmd:op-metal:0 ───→ RPOP/BLPOP 消费 + │ │ + │ 根据 key 从 Redis GET tensor 元信息 + │ 通过 shm_name 获取 GPU 指针 + │ 执行 Metal GPU kernel + │ │ + BLPOP done: ←────── LPUSH 完成事件 +``` + +## 4. 待开发任务 + +### 任务 O1: Redis 命令消费循环 (main.mm) + +```cpp +// 伪代码 +while (true) { + auto cmd = redis.rpop("cmd:op-metal:0"); + if (!cmd) { usleep(100); continue; } + + auto req = json::parse(cmd); + string opcode = req["opcode"]; // "add", "matmul", "relu", ... + string vtid = req["vtid"]; + string pc = req["pc"]; + + // 1. 获取 tensor 的 GPU 指针 + vector input_ptrs; + for (auto& inp : req["inputs"]) { + auto meta = redis.get(inp["key"]); // GET tensor 元信息 + auto [fd, ptr] = shm_open_and_map(meta["address"]["shm_name"]); + MTL::Buffer* buf = device->newBufferWithBytesNoCopy(ptr, meta["byte_size"], ...); + input_ptrs.push_back(buf); + } + + vector output_ptrs; + for (auto& out : req["outputs"]) { + auto meta = redis.get(out["key"]); + auto [fd, ptr] = shm_open_and_map(meta["address"]["shm_name"]); + MTL::Buffer* buf = device->newBufferWithBytesNoCopy(ptr, meta["byte_size"], ...); + output_ptrs.push_back(buf); + } + + // 2. 
分发到 GPU kernel + json result = dispatch_kernel(opcode, inputs, outputs, req["params"]); + + // 3. 更新输出 tensor 元信息 (如果 shape 变了) + for (auto& out : req["outputs"]) { + auto meta = redis.get(out["key"]); + meta["version"] = meta["version"].get() + 1; + redis.set(out["key"], meta); + } + + // 4. 通知 VM 完成 + redis.lpush("done:" + vtid, { + {"pc", pc}, + {"status", result["status"]}, // "ok" or "error" + {"outputs_updated", req["outputs"]} + }); +} +``` + +**依赖:** hiredis + +### 任务 O2: Tensor 元信息获取 + +op-plat 根据 Redis key 获取 tensor 的 shm 地址: + +``` +指令中的 inputs: [{"key": "/vthread/1/a", ...}] + → GET /vthread/1/a → {"dtype","shape","device","address":{"shm_name":"...","byte_size":...}} + → shm_open(shm_name) → fd + → mmap(fd, ..., byte_size) → CPU ptr + → device->newBufferWithBytesNoCopy(ptr, byte_size, ...) → MTLBuffer (GPU ptr) +``` + +可选的本地路径缓存 (减少 Redis GET): +```cpp +struct PathCacheEntry { + void* gpu_ptr; + size_t byte_size; + int version; +}; +unordered_map path_cache; +``` + +当前版本每次重新 GET。缓存失效策略后续处理。 + +### 任务 O3: 完成通知 + +计算完成后,更新输出 tensor 元信息并通知 VM: + +```cpp +// 更新输出 tensor version +redis.set("/vthread/1/c", updated_meta); + +// 通知 VM +redis.lpush("done:" + vtid, json{ + {"pc", pc}, + {"status", "ok"}, + {"outputs_updated", json::array({"/vthread/1/c"})} +}.dump()); +``` + +### 任务 O4: 算子注册 (程序级) + +启动时向 `/op/op-metal/` 注册支持的算子: + +``` +启动时注册: + SET /op/op-metal/list = [ + "add", "sub", "mul", "div", + "relu", "sigmoid", "tanh", "gelu", + "zeros", "ones", "arange" + ] + + SET /op/op-metal/add = { + "category": "elementwise", + "dtype": ["f32", "f16", "i32"] + } + + SET /op/op-metal/relu = { + "category": "activation", + "dtype": ["f32", "f16"] + } + + SET /op/op-metal/matmul = { + "category": "matmul", + "dtype": ["f32", "f16"], + "max_shape": [8192, 8192, 8192] + } + + 注意: 算子注册在 /op/op-metal/ (程序级),与 /sys/op-plat/ (进程级) 分离 +``` + +### 任务 O5: 算子路由 (dispatch_kernel) + +```cpp +json dispatch_kernel(string opcode, vector inputs, + vector outputs, json params) { + 
if (opcode == "add") return kernel_add(inputs, outputs); + if (opcode == "sub") return kernel_sub(inputs, outputs); + if (opcode == "mul") return kernel_mul(inputs, outputs); + if (opcode == "div") return kernel_div(inputs, outputs); + if (opcode == "relu") return kernel_relu(inputs, outputs); + if (opcode == "sigmoid") return kernel_sigmoid(inputs, outputs); + if (opcode == "tanh") return kernel_tanh(inputs, outputs); + if (opcode == "gelu") return kernel_gelu(inputs, outputs); + if (opcode == "zeros") return kernel_zeros(inputs, outputs); + if (opcode == "ones") return kernel_ones(inputs, outputs); + if (opcode == "arange") return kernel_arange(inputs, outputs); + if (opcode == "matmul") return kernel_matmul(inputs, outputs, params); + if (opcode == "sum") return kernel_sum(inputs, outputs, params); + + return {{"status", "error"}, {"error", {{"code", "UNKNOWN_OP"}, {"message", opcode}}}}; +} +``` + +### 任务 O6: 进程注册 + +启动时注册到 `/sys/op-plat/`: + +``` +SET /sys/op-plat/metal:0 = { + "program": "op-metal", + "device": "gpu0", + "status": "running", + "load": 0.0, + "pid": , + "started_at": +} +``` + +## 5. GPU Kernel 覆盖情况 + +| 类别 | 已实现 | 需补充 | +|------|-------|--------| +| elementwise | add/sub/mul/div (miaobyte) | — | +| activation | relu/sigmoid/tanh/gelu (common) | — | +| init | zeros/ones/arange (miaobyte) | — | +| matmul | — | **需开发** (MPSMatrixMultiplication 或 Metal shader) | +| reduce | — | **需开发** (sum/mean/max) | +| changeshape | — | **需开发** (reshape/transpose/concat/slice) | + +### matmul 开发建议 + +macOS 上矩阵乘法推荐使用 MPS (Metal Performance Shaders): + +```objc +// 使用 MPSMatrixMultiplication +#import + +MPSMatrixMultiplication* matmul = [[MPSMatrixMultiplication alloc] + initWithDevice:device + transposeLeft:false + transposeRight:true + resultRows:M resultColumns:N interiorColumns:K + alpha:1.0 beta:0.0]; + +[matmul encodeToCommandBuffer:commandBuffer + leftMatrix:matrixA rightMatrix:matrixB resultMatrix:matrixC]; +``` + +## 6. 
编译与运行 (macOS) + +```bash +# 安装依赖 +brew install hiredis + +# 构建 +cd executor/op-metal +mkdir -p build && cd build +cmake .. && make + +# 运行 +./op_metal +``` + +## 7. 验证方法 + +```bash +# 终端1: 启动 op-metal +./op_metal + +# 终端2: 通过 redis-cli 检查算子注册 +redis-cli GET /op/op-metal/list +# → ["add","sub","mul","div","relu","sigmoid","tanh","gelu","zeros","ones","arange"] + +# 发送测试指令 (需先通过 heap-plat 创建 tensor) +redis-cli RPUSH cmd:op-metal:0 '{ + "vtid":"test1", + "pc":"[0,0]", + "opcode":"add", + "inputs":[ + {"key":"/test/a","dtype":"f32","shape":[100],"address":{"shm_name":"/deepx_t_xxx","byte_size":400}} + ], + "outputs":[ + {"key":"/test/c","dtype":"f32","shape":[100],"address":{"shm_name":"/deepx_t_yyy","byte_size":400}} + ], + "params":{} +}' + +# 检查完成通知 +redis-cli BLPOP done:test1 1 +``` + +## 8. 开发量评估 + +| 任务 | 新增代码 | 难度 | +|------|---------|------| +| O1: Redis 命令循环 | ~200 行 | 中 | +| O2: Tensor 元信息获取 + shm 映射 | ~100 行 | 低 | +| O3: 完成通知 | ~50 行 | 低 | +| O4: 算子注册 | ~50 行 | 低 | +| O5: 算子路由 (dispatch) | ~100 行 | 低 | +| O6: 进程注册 | ~40 行 | 低 | +| matmul kernel | ~150 行 | 中 | +| reduce kernel | ~100 行 | 中 | +| changeshape kernel | ~50 行 | 低 | +| **合计** | **~840 行** | **中** | diff --git a/doc/metaproc/dev-pysdk.md b/doc/metaproc/dev-pysdk.md new file mode 100644 index 00000000..dc86798a --- /dev/null +++ b/doc/metaproc/dev-pysdk.md @@ -0,0 +1,350 @@ +# pysdk 开发指南 + +> pysdk 是 DeepX 的 Python 算法前端,负责注册源码到 Redis、创建 vthread。 +> 当前 front/py/deepx 已有 9,524 行代码,需要在现有基础上增量改造。 + +## 1. 角色与职责 + +pysdk 是用户侧接口,负责将用户的模型代码翻译为 `/src/func/` 下的函数定义和 `/vthread/` 下的执行单元。 + +| 能力 | 说明 | +|------|------| +| 函数注册 | 将 func 签名 + 指令序列写入 `/src/func//` | +| Vthread 创建 | 分配 vtid,写入入口 CALL 指令,唤醒 VM | +| Tensor 管理 | 创建/删除堆 tensor,通过 heap-plat | +| 状态查询 | GET `/vthread/` 检查执行状态和结果 | +| Func 缓存 | 避免重复发送未变化的函数 | + +## 2. 当前状态 + +``` +front/py/deepx/ +├── tensor/ tensor 操作 (elementwise, matmul, reduce, ...) +├── nn/ deepxIR 生成, parameter +├── optim/ sgd, adam +└── ... 
+ +代码量: 9,524 行 Python,约 30 个文件 +``` + +当前模式: 通过 HTTP 直接发送 deepxIR 序列到 Redis HTTP 代理。 +改造后: 通过 KVSpace 客户端写入 `/src/func/` 和 `/vthread/`。 + +## 3. 改造策略 + +**增量改造** — 不重写现有 9,524 行代码,而是: + +1. 新增 `KVSpace` 客户端类 (封装 Redis 操作) +2. 新增 `FuncCache` 类 (避免重复发送) +3. 新增 `VThreadCreator` 类 (创建 vthread) +4. 现有 `tensor/` 算子逐步适配,从 HTTP 模式切换到 KVSpace 模式 + +## 4. 待开发任务 + +### 任务 P1: KVSpace 客户端抽象 + +封装 Redis 操作为 KVSpace 语义接口: + +```python +import redis +import json + +class KVSpace: + """ + KV 空间客户端,封装 Redis 操作为元程路径语义。 + + 用法: + kv = KVSpace("redis://localhost:6379") + kv.set("/models/weights", {"dtype": "f32", "shape": [1024, 512]}) + meta = kv.get("/models/weights") + """ + + def __init__(self, redis_url="redis://localhost:6379"): + self.redis = redis.from_url(redis_url) + + # === 基本 KV 操作 === + def get(self, key: str): + val = self.redis.get(key) + return json.loads(val) if val else None + + def set(self, key: str, value): + return self.redis.set(key, json.dumps(value)) + + def delete(self, key: str): + return self.redis.delete(key) + + def exists(self, key: str) -> bool: + return self.redis.exists(key) > 0 + + def keys(self, pattern: str) -> list: + return self.redis.keys(pattern) + + # === 函数源码写入 === + def set_func(self, name: str, signature: str, instructions: list): + """ + 将函数定义写入 /src/func// + + Args: + name: 函数名, 如 "gemm", "forward" + signature: dxlang 签名, 如 "(gemm(A:tensor, B:tensor) -> (Y:tensor))" + instructions: 指令列表, 每条是 dxlang 字符串 + 如 ["matmul(A, B) -> ./Y", "relu(./Y) -> ./out"] + """ + self.set(f"/src/func/{name}", signature) + for i, inst in enumerate(instructions): + self.set(f"/src/func/{name}/{i}", inst) + + def get_func(self, name: str) -> dict: + """读取函数定义 (调试用)""" + sig = self.get(f"/src/func/{name}") + insts = [] + i = 0 + while True: + inst = self.get(f"/src/func/{name}/{i}") + if inst is None: + break + insts.append(inst) + i += 1 + return {"signature": sig, "instructions": insts} + + # === Vthread 管理 === + def alloc_vtid(self) -> str: + """分配新的 vthread 
ID""" + return str(self.redis.incr("/sys/vtid_counter")) + + def get_vthread(self, vtid: str) -> dict: + """获取 vthread 状态""" + return self.get(f"/vthread/{vtid}") + + def wait_vthread(self, vtid: str, timeout: float = 30.0, + poll_interval: float = 0.05) -> dict: + """ + 等待 vthread 执行完成 + + Returns: + vthread 最终状态 (status 为 "done" 或 "error") + """ + import time + deadline = time.time() + timeout + while time.time() < deadline: + state = self.get_vthread(vtid) + if state is None: + raise RuntimeError(f"vthread {vtid} not found") + if state["status"] in ("done", "error"): + return state + time.sleep(poll_interval) + raise TimeoutError(f"vthread {vtid} timeout after {timeout}s") + + # === 命令队列 === + def push(self, queue: str, value): + return self.redis.rpush(queue, json.dumps(value)) + + def pop(self, queue: str, timeout: int = 0): + result = self.redis.blpop(queue, timeout) + return json.loads(result[1]) if result else None +``` + +### 任务 P2: FuncCache + +避免重复发送相同 func 定义: + +```python +class FuncCache: + """函数缓存,基于内容 hash 避免重复发送""" + + def __init__(self, kv: KVSpace): + self.kv = kv + self._hashes: dict = {} + + @staticmethod + def _compute_hash(signature: str, instructions: list) -> str: + import hashlib + content = signature + "".join(instructions) + return hashlib.md5(content.encode()).hexdigest() + + def set_if_changed(self, name: str, signature: str, instructions: list) -> bool: + """ + 仅在函数内容变化时才写入 KV 空间 + + Returns: + True 如果有变更并写入, False 如果未变化跳过 + """ + h = self._compute_hash(signature, instructions) + if self._hashes.get(name) == h: + return False + self.kv.set_func(name, signature, instructions) + self._hashes[name] = h + return True + + def invalidate(self, name: str): + """强制下次写入 (用于调试)""" + self._hashes.pop(name, None) +``` + +### 任务 P3: Vthread 创建器 + +```python +class VThreadCreator: + def __init__(self, kv: KVSpace): + self.kv = kv + + def create(self, entry_func: str, bindings: dict, + entry_inst: str = "[0,0]") -> str: + """ + 创建 vthread 并通知 VM + + 
写入 /vthread// 的入口 CALL 指令: + [0, 0] = "call" + [0,-1] = entry_func + [0,-2] = bindings[绑定1] + ... + [0, 1] = "./out" (默认返回值) + + Args: + entry_func: 入口函数名, 如 "forward" + bindings: 参数绑定, 如 {"A": "./a", "B": "./b", "alpha": 1.0} + entry_inst: 入口指令坐标, 默认 "[0,0]" + + Returns: + vtid 字符串 + """ + vtid = self.kv.alloc_vtid() + + # 设置 vthread 状态 + self.kv.set(f"/vthread/{vtid}", { + "pc": entry_inst, + "status": "init" + }) + + # 写入入口 CALL 指令 + base = f"/vthread/{vtid}/{entry_inst}" + self.kv.set(base, "call") # [0,0] = opcode + self.kv.set(f"{base}/-1", entry_func) # [0,-1] = func_name + + # 写入实参: [0,-2], [0,-3], ... + for i, (param_name, param_value) in enumerate(bindings.items()): + self.kv.set(f"{base}/{-i-2}", str(param_value)) + + # 默认返回值槽位 + self.kv.set(f"{base}/1", "./out") # [0,1] = 返回值 + + # 唤醒 VM + self.kv.push("notify:vm", {"event": "new_vthread", "vtid": vtid}) + + return vtid +``` + +### 任务 P4: 现有算子适配 + +当前 `tensor/elementwise.py` 等文件直接发送 HTTP 请求。 +改造为通过 KVSpace 写入 `/src/func/`: + +```python +# === 改造前 (当前模式) === +def add(A, B) -> C: + # 通过 HTTP 发送 deepxIR + send_ir({"op": "add", "A": A, "B": B, "C": C}) + +# === 改造后 (KVSpace 模式) === +def add(kv: KVSpace, A: str, B: str, out_name: str): + """ + 注册 add 函数源码到 KV 空间 + + Args: + kv: KVSpace 客户端 + A: 输入 tensor 的路径 (如 "./a", "/models/X") + B: 输入 tensor 的路径 + out_name: 输出变量名 (如 "c") + """ + func_name = f"add_{A.replace('/', '_')}_{B.replace('/', '_')}" + signature = f"(add(A:tensor, B:tensor) -> ({out_name}:tensor))" + instructions = [f"add({A}, {B}) -> ./{out_name}"] + kv.set_func(func_name, signature, instructions) + return func_name +``` + +**适配策略:** 先不改现有业务代码,新增一个 `kv_adapter.py` 模块, +提供与原接口兼容的适配函数。后续逐步将业务逻辑迁移到新接口。 + +### 任务 P5: Tensor 元信息辅助 + +```python +def tensor_meta(dtype: str, shape: list, device: str = "gpu0") -> dict: + """构造 tensor 元信息 (用于 newtensor 请求)""" + dtype_sizes = { + "f16": 2, "f32": 4, "f64": 8, "bf16": 2, + "i8": 1, "i16": 2, "i32": 4, "i64": 8, "u8": 1 + } + import math + count = 
math.prod(shape) + byte_size = count * dtype_sizes[dtype] + return { + "dtype": dtype, + "shape": shape, + "byte_size": byte_size, + "device": device + } +``` + +## 5. 依赖 + +```bash +pip install redis +``` + +## 6. 验证方法 + +```python +# test_kvspace.py +from deepx.kvspace import KVSpace, FuncCache, VThreadCreator + +def test_kvspace(): + kv = KVSpace("redis://localhost:6379") + + # 1. 写入函数源码 + kv.set_func("add_test", + signature="(add_test(A:tensor, B:tensor) -> (C:tensor))", + instructions=["add(A, B) -> C"] + ) + + # 2. 验证写入 + func = kv.get_func("add_test") + assert func["signature"] == "(add_test(A:tensor, B:tensor) -> (C:tensor))" + assert func["instructions"][0] == "add(A, B) -> C" + + # 3. FuncCache 测试 + cache = FuncCache(kv) + assert cache.set_if_changed("add_test", func["signature"], func["instructions"]) == True + # 第二次应跳过 + assert cache.set_if_changed("add_test", func["signature"], func["instructions"]) == False + + # 4. 创建 vthread (需要 VM 运行) + creator = VThreadCreator(kv) + kv.set("/vthread/1/a", "tensor_ref_a") # 模拟已有局部变量 + kv.set("/vthread/1/b", "tensor_ref_b") + vtid = creator.create("add_test", {"A": "./a", "B": "./b"}) + print(f"Created vthread: {vtid}") + + # 5. 等待执行 (VM 需运行) + state = kv.wait_vthread(vtid, timeout=5) + print(f"Vthread state: {state}") +``` + +## 7. 开发量评估 + +| 任务 | 新增代码 | 难度 | +|------|---------|------| +| P1: KVSpace 客户端 (~150 行 Python) | ~150 行 | 低 | +| P2: FuncCache (~40 行) | ~40 行 | 低 | +| P3: Vthread 创建器 (~60 行) | ~60 行 | 低 | +| P4: 现有算子适配 (kv_adapter 模块) | ~150 行 | 中 | +| P5: Tensor 元信息辅助 (~30 行) | ~30 行 | 低 | +| 现有 9,524 行业务代码逐步迁移 | 待评估 | 中 | +| **合计 (新增)** | **~500 行 Python** | **低** | + +## 8. 注意 + +1. **pysdk 是增量改造** — 不重写现有代码,先加 KVSpace 层,逐步迁移 +2. **FuncCache 关键** — 避免每轮训练重复发送相同 func 定义 +3. **Vthread 创建后通知 VM** — 通过 `notify:vm` 队列 +4. 
**编译阶段暂跳过** — 当前 pysdk 直接写 `/src/func/`,VM 在读时做 eager 展开。未来编译器阶段再引入 `/op//func/` diff --git a/doc/metaproc/metaproc-datastruct.md b/doc/metaproc/metaproc-datastruct.md new file mode 100644 index 00000000..c011aaf3 --- /dev/null +++ b/doc/metaproc/metaproc-datastruct.md @@ -0,0 +1,361 @@ +# 元程 — 数据结构 + +> 定义元程的基础数据类型、扩展数据类型,以及在 KV 空间中的布局。 + +--- + +## 1. 为什么单列一篇数据结构 + +C 语言的数据结构定义在 `` 和 `struct` 中,编译器在编译期确定每个变量的 +字节偏移。CUDA 继承了 C 的类型体系,加上 `__shared__` / `__constant__` 内存空间标注。 +这些类型定义是语言的基石——写了 `int x`,编译器就知道分配 4 字节。 + +元程没有编译器。数据类型不是编译器规则,而是**所有进程之间共享的约定**。 +基础类型(int / float / bool / string)直接存为 KV 空间的 value。 +扩展类型(tensor)的 value 是元信息 JSON,实际数据通过 shm 指针引用。 +KV 空间的路径布局就是元程的"内存映射"。 + +--- + +## 2. 基础数据类型 + +元程的基础类型直接以字面量或短字符串形式存储在 KV value 中,无需额外编码: + +| 类型 | 示例 value | value 大小 | 对应 C 类型 | 对应 CUDA | +|------|-----------|-----------|------------|----------| +| `int` | `1`, `-42`, `0` | ~1-11 字节(字符串) | `int` (4B) | `int` | +| `float` | `3.14`, `-0.5`, `1e-5` | ~4-15 字节 | `float` (4B) / `double` (8B) | `float` / `double` | +| `bool` | `true` / `false` | 4-5 字节 | `bool` (1B,通常) | `bool` | +| `string` | `"hello"`, `"f32"` | 取决于内容,建议 < 1KB | `char[]` / `char*` | `char*` | + +**基础类型可以直接作为 KV 的 value**: + +``` +/vthread/1/flag = true ← bool +/vthread/1/count = 42 ← int +/vthread/1/rate = 0.001 ← float +/vthread/1/name = "f32" ← string +``` + +与 C 的关键区别: + +| | C / CUDA | 元程 KV | +|---|---|---| +| 存储位置 | 栈 / 寄存器 / 显存 | Redis key 的 value | +| 寻址方式 | 指针 / 偏移量 | 字符串路径 | +| 类型检查 | 编译期 | 运行时(VM 解析 value) | +| 生命周期 | 作用域退出即释放 | 显式 DELETE | +| 精度控制 | `int` vs `int32_t` 编译器决定 | 约定决定(建议标注 dtype string) | + +**注意**:基础类型的 value 在 Redis 中存为字符串。Redis 的 int 操作(INCR)可用, +但建议统一走 GET→解析→运算→SET 路径以保证类型安全。 + +--- + +## 3. 
扩展数据类型:Tensor + +### 3.1 定义 + +Tensor 是元程唯一的一级扩展数据类型。它是一个多维数组,其实际数据存储于外部内存 +(POSIX shm / GPU 显存),KV 空间中仅存储**元信息引用**: + +```json +{ + "dtype": "f32", + "shape": [1024, 512], + "byte_size": 2097152, + "device": "gpu0", + "address": { + "node": "n1", + "type": "shm", + "shm_name": "/deepx_t_abc123" + }, + "ctime": 1714000000, + "version": 5 +} +``` + +| 字段 | 类型 | 必需 | 说明 | +|------|------|------|------| +| `dtype` | string | 是 | `f16`, `f32`, `f64`, `bf16`, `bf8`, `i8`, `i16`, `i32`, `i64`, `u8`, `bool` | +| `shape` | array[int] | 是 | 如 `[1024, 512]`, `[2, 3, 224, 224]` | +| `byte_size` | int | 是 | element_count × dtype_size。对于 bool,1 bit → 向上取整到字节 | +| `device` | string | 是 | `gpu0`, `gpu1`, `cpu` | +| `address` | object | 是 | 物理地址信息 | +| `address.node` | string | 是 | 机器标识,如 `"n1"` | +| `address.type` | string | 是 | `"shm"` (POSIX共享内存) 或 `"gpu"` (GPU显存直接引用) | +| `address.shm_name` | string | 是 | POSIX shm 名称,如 `"/deepx_t_abc123"` | +| `ctime` | int | 否 | 创建时间戳 | +| `version` | int | 否 | 每次写入后递增 | + +### 3.2 dtype 字节宽度 + +| dtype | 字节数 | C/CUDA 对应 | 说明 | +|-------|--------|-----------|------| +| `f16` | 2 | `half` / `__half` | IEEE 754 half | +| `f32` | 4 | `float` | IEEE 754 single | +| `f64` | 8 | `double` | IEEE 754 double | +| `bf16` | 2 | `__nv_bfloat16` | Brain floating point | +| `bf8` | 1 | — | 8-bit brain float (E4M3 / E5M2) | +| `i8` | 1 | `int8_t` / `char` | | +| `i16` | 2 | `int16_t` / `short` | | +| `i32` | 4 | `int` / `int32_t` | | +| `i64` | 8 | `long long` / `int64_t` | | +| `u8` | 1 | `uint8_t` | | +| `bool` | 1 | `bool` | 每元素 1 bit,byte_size = ceil(n/8) | + +``` +byte_size 计算: + element_count = prod(shape) // shape 各维乘积 + byte_size = element_count * dtype_size // 普通类型 + byte_size = ceil(element_count / 8) // bool 类型 +``` + +### 3.3 与 C / CUDA 的数据结构对比 + +#### 3.3.1 数组声明对比 + +``` +C: + float A[1024][512]; ← 编译期确定大小,栈或数据段 + float* A = malloc(1024 * 512 * 4); ← 运行时分配,堆 + +CUDA: + float* d_A; + cudaMalloc(&d_A, 1024 * 512 * 4); ← 显存分配 + __shared__ float 
s_A[256]; ← 共享内存 + +元程: + GET /models/A → {dtype:"f32", shape:[1024,512], address:{shm_name:"/deepx_t_xxx"}} + ← KV 中存元信息 + shm_open("/deepx_t_xxx") → ptr ← 实际数据在 shm 中 +``` + +#### 3.3.2 核心差异 + +| 维度 | C | CUDA | 元程 | +|------|---|------|------| +| **数据位置** | 栈/堆(CPU 内存) | global/shared/constant 显存 | POSIX shm / GPU 显存(由 device 字段决定) | +| **寻址方式** | 指针 `float*` + 偏移 | 指针 `float*` + 偏移 | KV 路径 `GET /models/A` → 元信息 → shm ptr | +| **形状信息** | 丢失(退化为指针) | 丢失(退化为指针) | **保留**(shape 字段始终可查) | +| **类型信息** | 编译期已知,运行期丢失 | 同 C | **保留**(dtype 字段始终可查) | +| **跨进程传递** | 需序列化 + IPC | 需 cudaIpcOpenMemHandle | 路径字符串 + shm name(元程原生) | +| **生命周期** | 作用域/手动 free | 手动 cudaFree | heap-plat 管理 newtensor/deltensor | +| **多设备** | 不涉及 | 单 GPU,显式 cudaMemcpy | 元信息标注 device,heap-plat 负责跨设备克隆 | +| **并发安全** | 锁 | 原子操作 + 同步 | Redis LOCK / WATCH + 原子操作 | + +#### 3.3.3 形状/类型不丢失的设计收益 + +C 和 CUDA 中,数组一旦传给函数就退化为指针,丢失 shape 和类型信息: + +```c +// C: 函数签名看不到 shape +void matmul(float* A, float* B, float* C, int M, int N, int K); +// ↑ 仅指针 ↑ shape 靠额外参数传递 + +// CUDA: 同样问题,加上显存地址语义 +__global__ void matmul_kernel(float* A, float* B, float* C, int M, int N, int K); +``` + +元程中,shape 和 dtype 是 tensor 元信息的**一等字段**,不会丢失: + +``` +op-plat 收到的指令: +{ + "opcode": "matmul", + "inputs": [ + {"key": "/models/A", "dtype": "f32", "shape": [1024, 512], "address": {...}}, + {"key": "/models/B", "dtype": "f32", "shape": [512, 256], "address": {...}} + ], + "outputs": [ + {"key": "/vthread/1/c", "dtype": "f32", "shape": [1024, 256], "address": {...}} + ] +} +``` + +op-plat 从指令中直接获取 shape 和 dtype,无需额外查询。这消除了 C/CUDA 中 +"传指针的同时还要传 shape 参数"的惯例。 + +--- + +## 4. 
KV 空间中的类型布局 + +### 4.1 概念上的"内存映射" + +KV 空间的路径不是线性的字节地址,而是一棵键值树。不同类型的值分布在不同的子树中: + +``` +/ (根) +│ +├── 基础类型存放在 vthread 栈或系统路径中,value 是字面量 +├── Tensor 元信息分布在堆路径和 vthread 栈中,value 是 JSON +└── Tensor 实际数据不存 KV 空间,通过 shm_name 引用 +``` + +### 4.2 各路径区间的类型分布 + +| 路径区间 | 存放类型 | value 形式 | 读写者 | +|----------|---------|-----------|--------| +| `/vthread//` | **基础类型** 或 **tensor 元信息** | 字面量 或 JSON | VM, op-plat | +| `/vthread/` | **JSON**(状态对象) | `{"pc":"...","status":"..."}` | VM(写),pysdk(读) | +| `/vthread//[addr0, addr1]` | **string**(指令片段) | `"matmul"`, `"./a"`, `"3.14"` | VM(读写) | +| `/models/*`, `/data/*` | **tensor 元信息** | JSON | heap-plat(写),VM/op-plat(读) | +| `/src/func//N` | **string**(指令文本) | `"add(A, B) -> ./C"` | pysdk(写),VM(读) | +| `/op//func//N` | **string**(编译后指令) | 同上 | 编译器(写),VM(读) | +| `/sys/*` | **基础类型** 或 **JSON** | int 或 JSON | 各进程(注册/读取) | +| `/cmd/*`, `/done/*`, `/notify/*` | **JSON**(消息) | 见下方 | 各进程 | + +### 4.3 具体布局示例 + +``` +# === 基础类型 === +/vthread/1/flag = true ← bool, 直接存 +/vthread/1/iter = 0 ← int, 直接存 +/vthread/1/lr = 0.001 ← float, 直接存 +/sys/vtid_counter = 42 ← int, 直接存 (Redis INCR 可用) + +# === Tensor 元信息 (JSON) === +/models/bert/W = {"dtype":"f32","shape":[768,3072],"device":"gpu0","address":{...},"byte_size":...} +/models/bert/b = {"dtype":"f32","shape":[3072],"device":"gpu0","address":{...},"byte_size":...} +/vthread/1/mm = {"dtype":"f32","shape":[1024,256],"device":"gpu0","address":{...},"byte_size":...} +/vthread/1/out = {"dtype":"f32","shape":[1024,10],"device":"cpu","address":{...},"byte_size":...} + +# === 指令 (string) === +/src/func/gemm/0 = "matmul(A, B) -> ./Y" +/vthread/1/[0,0] = "matmul" +/vthread/1/[0,-1] = "/models/A" +/vthread/1/[0,-2] = "/models/B" +/vthread/1/[0, 1] = "./Y" + +# === vthread 状态 (JSON) === +/vthread/1 = {"pc":"[0,0]","status":"running"} + +# === 系统注册 (JSON) === +/sys/op-plat/cuda:0 = {"program":"op-cuda","device":"gpu0","status":"running","load":0.3} +/sys/config = {"max_vthreads":100,"timeout_ms":30000} + +# === 
命令消息 (JSON, 存储在 List 中) === +cmd:op-cuda:0 ← RPUSH {"vtid":"1","pc":"[0,0]","opcode":"matmul","inputs":[...],"outputs":[...]} +done:1 ← LPUSH {"pc":"[0,0]","status":"ok"} +``` + +### 4.4 值大小约束 + +| 类型 | 最大 value 大小 | 原因 | +|------|----------------|------| +| 基础类型 (int/float/bool) | ~15 字节 | 字面量字符串 | +| string | 1 KB | 避免 Redis 大 key 性能问题 | +| Tensor 元信息 (JSON) | ~10 KB | 主要是 address 字段,shm_name 通常 < 64B | +| Tensor 实际数据 | **不存 KV** | 存外部 shm/显存,KV 仅存引用 | +| 命令消息 (JSON) | ~10 KB | 包含 tensor 元信息的完整拷贝 | + +**设计原则**:KV 空间存**描述**,外部存储存**数据**。这是元程与 C/CUDA 的又一 +根本差异——C 的 struct 是数据本身,元程的 value 是数据的**引用和描述**。 + +--- + +## 5. Tensor 元信息 vs C struct vs CUDA 内存描述符 + +### 5.1 三方对照 + +``` +C struct (在 CPU 内存中表示一个数组): + struct { + float* data; ← 指向实际数据 + int dims[4]; ← 形状(可选,取决于库) + int ndim; ← 维度数(可选) + int dtype; ← 类型枚举(可选,通常编译期已知) + } + 问题: 每个库自己定义 (PyTorch ATen, NumPy, Eigen 各不同) + +CUDA 内存描述符: + struct cudaPointerAttributes { + enum cudaMemoryType memoryType; ← host / device / managed + int device; ← GPU 编号 + void* devicePointer; ← 设备指针 + void* hostPointer; ← 主机指针(统一内存时) + } + 问题: 没有 shape / dtype,仅描述"指针属性",不是"数组属性" + +元程 Tensor 元信息 (JSON, 存于 KV): + { + "dtype": "f32", ← 类型始终可查 + "shape": [1024, 512], ← 形状始终可查 + "byte_size": 2097152, ← 预计算,O(1) 获取 + "device": "gpu0", ← 设备标注 + "address": { + "node": "n1", ← 机器 + "type": "shm", ← 内存类型 + "shm_name": "/deepx_t_xxx" ← 全局唯一标识 + }, + "version": 5 ← 变更追踪 + } + 优势: 自包含、跨语言、跨进程、携带完整描述 +``` + +### 5.2 跨设备传递对比 + +``` +场景: GPU0 上的 tensor 需要传给 GPU1 + +CUDA: + cudaMemcpyPeer(dst, 1, src, 0, size); ← 显式 P2P 拷贝 + // 需要知道: 源指针, 目标指针, 大小, 两个 GPU 编号 + // 调用者自己管理这些信息 + +元程: + 1. VM 发送 clonetensor 到 heap-plat: + {"op":"clonetensor", "src":"/models/W", "dst":"/models/W_gpu1", "device":"gpu1"} + 2. heap-plat GET /models/W → 获取完整元信息 (含当前 device) + 3. heap-plat 在 gpu1 上分配 shm → memcpy → 写入新元信息 + 4. 完成 + + // VM 只需要知道路径和目标 device,heap-plat 处理其余全部 +``` + +--- + +## 6. 
类型安全与运行时检查 + +C/CUDA 的类型检查在编译期,运行时只有原始字节。元程的类型检查在运行时: + +| | C / CUDA | 元程 | +|---|---|---| +| 检查时机 | 编译期 `float* p = (float*)malloc(...)` | 运行时 op-plat 解析 dtype 字段 | +| 类型错误后果 | 编译失败(安全)或未定义行为(cast) | op-plat 返回 error: `"dtype mismatch: f32 vs f16"` | +| shape 不匹配 | 段错误 / 静默错误 | op-plat 返回 error: `"shape mismatch: [10] vs [20]"` | +| 内存越界 | 段错误 / 安全漏洞 | shm 大小由 byte_size 约束,op-plat 可做边界检查 | + +--- + +## 7. 汇总:元程数据类型体系 + +``` +元程数据类型 +├── 基础类型 (value = 字面量) +│ ├── int → "42" +│ ├── float → "3.14" +│ ├── bool → "true" +│ └── string → "hello" +│ +├── 扩展类型 (value = JSON 元信息,数据在外部) +│ └── tensor +│ ├── dtype : f16 | f32 | f64 | bf16 | bf8 | i8 | i16 | i32 | i64 | u8 | bool +│ ├── shape : [d1, d2, ...] +│ ├── byte_size: element_count × dtype_size +│ ├── device : gpu0 | gpu1 | cpu +│ └── address : {node, type, shm_name} +│ +└── 结构化类型 (value = JSON 对象,数据在 KV 中) + ├── vthread 状态 : {pc, status, error?} + ├── 进程注册 : {program, device, status, load} + ├── 算子元数据 : {category, dtype, max_shape, replaces?} + └── 命令消息 : {vtid, pc, opcode, inputs, outputs} +``` + +--- + +> **关联文档**: +> - [README.md](README.md) — 元程总篇(核心思想:程序=数据结构+函数+数据) +> - [spec-v1.md](spec-v1.md) — 元程规范 v1(抽象模型) +> - [redis-keys.md](redis-keys.md) — Redis Key 布局速查表 diff --git a/doc/metaproc/redis-keys.md b/doc/metaproc/redis-keys.md new file mode 100644 index 00000000..881702d9 --- /dev/null +++ b/doc/metaproc/redis-keys.md @@ -0,0 +1,522 @@ +# Redis Key 设计列表 + +> DeepX 元程系统使用的所有 Redis key 路径、value 类型和约定。 +> Redis 作为 KV 空间,key 是路径空间的唯一标识,value 存储结构化数据 (JSON)。 + +## 1. 路径空间总览 + +``` +/ (根) +├── src/func/ 函数源码层 (pysdk 写入, 人类可读) +├── op/ 算子注册与编译产物 +│ ├── op-cuda/ +│ ├── op-metal/ +│ └── op-cpu/ +├── vthread/ vthread 执行状态 (栈) +├── sys/ 系统信息 +├── cmd/ 命令队列 +├── notify/ 通知队列 +├── done/ 完成通知队列 +├── lock/ 互斥锁 +├── models/ 堆变量 (示例) +├── data/ 堆变量 (示例) +└── checkpoints/ 堆变量 (示例) +``` + +--- + +## 2. 
源码层: /src/func/ + +pysdk 写入的函数源码,dxlang 人类可读文本格式。 + +| Key | Value 类型 | 示例 | 说明 | +|-----|-----------|------|------| +| `/src/func/` | string | `(add(A:tensor, B:tensor) -> (C:tensor))` | 函数签名 (dxlang) | +| `/src/func//N` | string | `add(A, B) -> ./C` | 第 N 条指令 | +| `/src/func//N/true/0` | string | `add(A, B) -> ./Y` | 分支 true 子块第 0 条 | +| `/src/func//N/false/0` | string | `sub(A, B) -> ./Y` | 分支 false 子块第 0 条 | +| `/src/func//N/body/0` | string | `add(./i, ./a) -> ./b` | for 循环体第 0 条 | + +**指令格式 (左读右写):** +``` +opcode(read_param_1, read_param_2, ...) -> write_param_1, write_param_2 +``` + +**示例:** +``` +/src/func/gemm = "(gemm(A:tensor, B:tensor, alpha:f32, beta:f32, C:tensor) -> (Y:tensor))" +/src/func/gemm/0 = "matmul(A, B) -> ./Y" +/src/func/gemm/1 = "mul(./Y, alpha) -> ./Y" +/src/func/gemm/2 = "mul(C, beta) -> ./C" +/src/func/gemm/3 = "add(./Y, ./C) -> ./Y" +``` + +--- + +## 3. 编译层: /op//func/ + +编译器读取 `/src/func/` 源码后产出的后端专属编译产物。VM CALL 时读取此层。 + +### 3.1 函数编译产物 + +| Key | Value 类型 | 示例 | 说明 | +|-----|-----------|------|------| +| `/op//func/` | string | `(gemm(...) 
-> (...))` | 编译后的函数签名 | +| `/op//func//N` | string | `fused_matmul_add(A, B, b) -> ./out` | 编译后第 N 条指令 | + +**注意:** 编译层的指令序号 N 和源码层的 N 可能不同(融合后 N 减少,拆分后 N 增加)。 + +``` +示例 (CUDA 融合): + /op/op-cuda/func/gemm/0 = "fused_matmul_mul_mul_add(A, B, alpha, C, beta) -> ./Y" + (源码层 4 条 → 编译层 1 条) + +示例 (Tensor 并行拆分): + /op/op-cuda/func/forward/0 = "slice(A, 0, 512) -> ./A_shard0" + /op/op-cuda/func/forward/1 = "slice(A, 512, 1024) -> ./A_shard1" + /op/op-cuda/func/forward/2 = "matmul(./A_shard0, W) -> ./out0 @gpu0" + /op/op-cuda/func/forward/3 = "matmul(./A_shard1, W) -> ./out1 @gpu1" + /op/op-cuda/func/forward/4 = "concat(./out0, ./out1) -> ./out" +``` + +### 3.2 算子注册 (程序级 — 同一程序的所有进程实例共享) + +| Key | Value 类型 | 示例 | 说明 | +|-----|-----------|------|------| +| `/op//list` | JSON array | `["matmul", "add", "relu", "sigmoid"]` | 该程序支持的全部算子 | +| `/op//` | JSON object | 见下方 | 单个算子的元数据 | + +**算子元数据格式:** +```json +{ + "category": "matmul | elementwise | reduce | changeshape | activation | fused | init", + "dtype": ["f32", "f16", "bf16"], + "max_shape": [8192, 8192, 8192], + "fusion_group": "linear" +} +``` + +**融合算子额外字段:** +```json +{ + "category": "fused", + "dtype": ["f32", "f16"], + "replaces": ["matmul", "add", "relu"] +} +``` + +**完整示例:** +``` +/op/op-cuda/list = ["matmul", "add", "mul", "relu", "sigmoid", "fused_matmul_add_relu"] +/op/op-cuda/matmul = {"category":"matmul", "dtype":["f32","f16","bf16"], "max_shape":[8192,8192,8192], "fusion_group":"linear"} +/op/op-cuda/relu = {"category":"activation", "dtype":["f32","f16","bf16"]} +/op/op-cuda/fused_matmul_add_relu = {"category":"fused", "dtype":["f32","f16"], "replaces":["matmul","add","relu"]} + +/op/op-metal/list = ["add", "sub", "mul", "div", "relu", "sigmoid", "tanh", "zeros", "ones"] +/op/op-metal/add = {"category":"elementwise", "dtype":["f32","f16","i32"]} +``` + +--- + +## 4. 
执行层: /vthread/ + +VM 管理的 vthread 执行状态。指令展开为二维坐标 `[addr0, addr1]`,命名槽位为平级子 key。 + +### 4.1 Vthread 自身 + +| Key | Value 类型 | 示例 | 说明 | +|-----|-----------|------|------| +| `/vthread/` | JSON object | `{"pc":"[3,0]", "status":"running"}` | vtid 自身状态:pc + status + 可选 error | + +**pc 字段格式:** +- 根栈: `"[0,0]"`, `"[3,0]"` +- 子栈: `"[n,0]/[0,0]"`, `"[n,0]/[3,0]"` +- 深层嵌套: `"[2,0]/[1,0]/[0,0]"` + +**status 字段值:** + +| status | 含义 | +|--------|------| +| `init` | 已创建,待 VM 拾取 | +| `running` | VM 正在调度执行 | +| `wait` | 等待异步操作 (op-plat / heap-plat 完成) | +| `error` | 执行出错 | +| `done` | 执行完毕,可 GC | + +**error 字段 (仅 status=error 时存在):** +```json +{ + "pc": "[3,0]", + "status": "error", + "error": { + "code": "GPU_OOM", + "message": "out of memory: requested 2GB, available 1.5GB" + } +} +``` + +### 4.2 指令坐标 (二维寻址) + +| Key 模式 | Value 类型 | 含义 | addr1 规则 | +|----------|-----------|------|-----------| +| `/vthread//[addr0, 0]` | string | 操作码 | `0` = opcode | +| `/vthread//[addr0, -1]` | string | 读参数 #1 | `-N` = 第 N 个读取参数 | +| `/vthread//[addr0, -2]` | string | 读参数 #2 | | +| `/vthread//[addr0, 1]` | string | 写参数 #1 | `+N` = 第 N 个写入参数 | +| `/vthread//[addr0, 2]` | string | 写参数 #2 | | + +addr0 是序列维整数,表示指令在栈帧内的顺序位置。 + +**示例 (指令 `add(./a, ./b) -> ./c`):** +``` +/vthread/1/[0, 0] = "add" +/vthread/1/[0,-1] = "./a" +/vthread/1/[0,-2] = "./b" +/vthread/1/[0, 1] = "./c" +``` + +**示例 (指令 `matmul(A, B) -> ./Y`):** +``` +/vthread/1/[0, 0] = "matmul" +/vthread/1/[0,-1] = "/models/A" +/vthread/1/[0,-2] = "/models/B" +/vthread/1/[0, 1] = "./Y" +``` + +**CALL 指令示例:** +``` +/vthread/1/[0, 0] = "call" +/vthread/1/[0,-1] = "gemm" # func_name +/vthread/1/[0,-2] = "/models/A" # 实参1 +/vthread/1/[0,-3] = "/models/B" # 实参2 +/vthread/1/[0,-4] = "1.0" # 实参3 (立即数) +/vthread/1/[0, 1] = "./out" # 返回值绑定的槽位 +``` + +### 4.3 子栈 (CALL 产生) + +``` +/vthread//[n,0]/ ← 子栈根 +/vthread//[n,0]/[0,0] ← 子栈指令 #0 操作码 +/vthread//[n,0]/[0,-1] ← 子栈指令 #0 读参数 #1 +/vthread//[n,0]/[1,0] ← 子栈指令 #1 操作码 +/vthread//[n,0]/[m,0]/[0,0] ← 更深层嵌套 + 
+嵌套路径与 pc 的对应: + 根栈: pc = "[0,0]" + 根栈 CALL n 后: pc = "[n,0]/[0,0]" + 子栈 CALL m 后: pc = "[n,0]/[m,0]/[0,0]" +``` + +### 4.4 命名槽位 (局部变量) + +命名槽位是 `/vthread//` 下的平级子 key,与指令坐标 `[addr0, addr1]` 互不嵌套。 + +| Key | Value | 说明 | +|-----|-------|------| +| `/vthread//` | 基础类型 或 tensor 元信息 | 局部变量,与 dxlang 源码变量名一致 | + +**基础类型直存:** +``` +/vthread/1/a = 1 (int) +/vthread/1/b = 3.14 (float) +/vthread/1/flag = true (bool) +/vthread/1/name = "hello" (string, < 1KB) +``` + +**Tensor 类型:** +``` +/vthread/1/mm = {"dtype":"f32", "shape":[1024,256], "device":"gpu0", "address":{"type":"shm","shm_name":"/deepx_t_abc123"}, "byte_size":1048576} +``` + +命名槽位与指令坐标的平级关系: +``` +/vthread/1/ ← vtid 自身 +/vthread/1/[0,0] ← 指令坐标 +/vthread/1/[0,-1] ← 指令坐标 +/vthread/1/[0,1] ← 指令坐标 +/vthread/1/a ← 命名槽位 (平级) +/vthread/1/mm ← 命名槽位 (平级) +/vthread/1/[0,0]/[0,0] ← 子栈 (嵌套) +``` + +--- + +## 5. 系统路径: /sys/ + +| Key | Value 类型 | 示例 | 说明 | +|-----|-----------|------|------| +| `/sys/vtid_counter` | int | `42` | vthread ID 自增计数器 (原子 INCR) | +| `/sys/config` | JSON object | `{"max_vthreads": 100, "timeout_ms": 30000}` | 全局配置 | +| `/sys/op-plat/` | JSON object | 见下方 | op-plat 进程实例注册 | +| `/sys/heap-plat/` | JSON object | 见下方 | heap-plat 进程实例注册 | +| `/sys/vm/` | JSON object | `{"status":"running", "pid":12349, "started_at":1714000004}` | VM 实例注册 | + +**进程级注册格式:** + +``` +/sys/op-plat/cuda:0 = {"program":"op-cuda", "device":"gpu0", "status":"running", "load":0.3, "pid":12345, "started_at":1714000000} +/sys/op-plat/cuda:1 = {"program":"op-cuda", "device":"gpu1", "status":"running", "load":0.7, "pid":12346, "started_at":1714000001} +/sys/op-plat/metal:0 = {"program":"op-metal", "device":"gpu0", "status":"running", "load":0.1, "pid":12347, "started_at":1714000002} +/sys/heap-plat/metal:0 = {"program":"heap-metal", "device":"gpu0", "status":"running", "pid":12348, "started_at":1714000003} +/sys/vm/0 = {"status":"running", "pid":12349, "started_at":1714000004} +``` + +**实例命名规则:** `:`,如 `cuda:0`, `metal:0`,对应命令队列 
`cmd:op-cuda:0`。 + +--- + +## 6. 命令队列: /cmd/ 和 /done/ 和 /notify/ + +使用 Redis List (FIFO 队列)。生产者 RPUSH,消费者 RPOP / BLPOP。 + +### 6.1 op-plat 命令队列 + +| Key | 消费方 | 生产者 | 说明 | +|-----|--------|--------|------| +| `cmd:op-cuda:0` | op-cuda 实例 0 | VM | CUDA 计算指令 | +| `cmd:op-metal:0` | op-metal 实例 0 | VM | Metal 计算指令 | +| `cmd:op-cpu:0` | op-cpu 实例 0 | VM | CPU 计算指令 | + +**消息格式 (JSON):** +```json +{ + "vtid": "1", + "pc": "[3,0]", + "opcode": "matmul", + "inputs": [ + { + "key": "/vthread/1/a", + "dtype": "f32", + "shape": [1024, 512], + "address": { + "node": "n1", + "device": "gpu0", + "type": "shm", + "shm_name": "/deepx_t_abc123", + "byte_size": 2097152 + } + } + ], + "outputs": [ + { + "key": "/vthread/1/c", + "dtype": "f32", + "shape": [1024, 256], + "address": { + "node": "n1", + "device": "gpu0", + "type": "shm", + "shm_name": "/deepx_t_def456", + "byte_size": 1048576 + } + } + ], + "params": {} +} +``` + +**批量消息 (可选):** +```json +{ + "batch": [ + {"pc": "[3,0]", "opcode": "add", "inputs": [...], "outputs": [...]}, + {"pc": "[4,0]", "opcode": "mul", "inputs": [...], "outputs": [...]} + ] +} +``` + +### 6.2 heap-plat 命令队列 + +| Key | 消费方 | 说明 | +|-----|--------|------| +| `cmd:heap-metal:0` | heap-metal | Metal 内存管理指令 | +| `cmd:heap-cuda:0` | heap-cuda | CUDA 内存管理指令 | +| `cmd:heap-cpu:0` | heap-cpu | CPU 内存管理指令 | + +**newtensor:** +```json +{"vtid":"1", "pc":"[0,0]", "op":"newtensor", "key":"/models/weights", "dtype":"f32", "shape":[1024,512], "device":"gpu0"} +``` + +**deltensor:** +```json +{"vtid":"1", "pc":"[5,0]", "op":"deltensor", "key":"/models/weights"} +``` + +**clonetensor:** +```json +{"vtid":"1", "pc":"[0,0]", "op":"clonetensor", "src":"/models/weights", "dst":"/models/weights_gpu1", "device":"gpu1"} +``` + +### 6.3 完成通知队列 + +| Key | 消费方 | 生产者 | 说明 | +|-----|--------|--------|------| +| `done:` | VM | op-plat / heap-plat | vthread 的完成通知 | + +**成功:** +```json +{"pc": "[3,0]", "status": "ok", "outputs_updated": [{"key": "/vthread/1/c", "new_shape": [1024, 
256]}]} +``` + +**失败:** +```json +{"pc": "[3,0]", "status": "error", "error": {"code": "GPU_OOM", "message": "out of memory: requested 2GB, available 1.5GB"}} +``` + +### 6.4 VM 唤醒通知 + +| Key | 消费方 | 生产者 | 说明 | +|-----|--------|--------|------| +| `notify:vm` | VM | pysdk | 新 vthread 创建后唤醒 VM | + +```json +{"event": "new_vthread", "vtid": "42"} +``` + +--- + +## 7. 堆变量 (隐式命名空间) + +除保留路径外,KV 空间中的所有其他路径均为堆变量。value 存储 tensor 元信息。 + +### 7.1 Tensor 元信息格式 + +```json +{ + "dtype": "f32", + "shape": [1024, 512], + "byte_size": 2097152, + "device": "gpu0", + "address": { + "node": "n1", + "type": "shm", + "shm_name": "/deepx_t_abc123" + }, + "ctime": 1714000000, + "version": 5 +} +``` + +| 字段 | 类型 | 必需 | 说明 | +|------|------|------|------| +| dtype | string | 是 | `f32`, `f16`, `f64`, `bf16`, `i8`, `i16`, `i32`, `i64`, `u8`, `bool` | +| shape | array[int] | 是 | 如 `[1024, 512]` 或 `[2, 3, 224, 224]` | +| byte_size | int | 是 | element_count × dtype_size | +| device | string | 是 | `gpu0`, `gpu1`, `cpu` | +| address | object | 是 | 物理地址信息 | +| address.node | string | 是 | 机器标识 | +| address.type | string | 是 | `shm`, `gpu`, `cpu` | +| address.shm_name | string | 是 | POSIX shm 名称 | +| ctime | int | 否 | 创建时间 (unix timestamp) | +| version | int | 否 | 版本号 (每次写入递增) | + +### 7.2 命名约定建议 + +``` +/models/// +/data// +/checkpoints// +/rl// +``` + +示例: +``` +/models/bert/encoder_0/weights = {"dtype":"f32", "shape":[768,3072], ...} +/models/bert/encoder_0/bias = {"dtype":"f32", "shape":[3072], ...} +/data/cifar10/train = {"dtype":"f32", "shape":[50000,3,32,32], ...} +``` + +--- + +## 8. 锁 + +| Key | Value | 操作 | 说明 | +|-----|-------|------|------| +| `/lock/` | holder 标识符 | `SET NX EX` / Lua DEL | 排他锁 | + +锁 key 命名约定: `/lock/` + +``` +/lock/tensor:/models/weights → 保护 weights 的并发修改 +/lock/vthread:1 → 保护 vthread 1 的并发拾取 (也可用 WATCH/MULTI/EXEC) +``` + +--- + +## 9. 
Key 操作复杂度与约束 + +### 9.1 批量操作 + +| 模式 | Redis 命令 | 说明 | +|------|-----------|------| +| 读多个指令 | `MGET key1 key2 ...` | 一次往返 | +| 读多个 tensor 元信息 | `MGET /vthread/1/a /vthread/1/b /models/W` | 一次往返 | +| 写多条指令 | `Pipeline: SET ... SET ...` | 一次往返 | +| 原子抢占 vthread | `WATCH ... MULTI ... EXEC` | 事务 | +| 扫描 vthread | `KEYS /vthread/*` | 避免高频调用 | +| 递归删除子栈 | `KEYS prefix*` → 批量 `DEL` | 两步 | + +### 9.2 值大小约束 + +| 类型 | 约束 | 建议 | +|------|------|------| +| 短字符串 | < 1KB | 指令、签名 | +| JSON 对象 | < 10KB | tensor 元信息、状态 | +| Tensor 实际数据 | **不存 Redis** | 存外部 shm,Redis 仅存引用 | + +### 9.3 命名约束 + +- key 中不使用空格,用 `/` 分隔层级 +- 指令坐标: `[addr0, addr1]`,addr0 非负整数,addr1 整数 +- 命名槽位: 字母数字 + 下划线,与 dxlang 变量名一致 +- 实例名: `:`,如 `cuda:0`, `metal:0` + +--- + +## 10. Key 路径速查表 + +``` +# 源码层 +/src/func/ string 函数签名 +/src/func//N string 第 N 条指令 +/src/func//N/true/0 string 分支 true 子块 +/src/func//N/false/0 string 分支 false 子块 +/src/func//N/body/0 string for 循环体 + +# 编译层 +/op//func/ string 编译后签名 +/op//func//N string 编译后指令 + +# 算子注册 (程序级,所有实例共享) +/op//list array 算子列表 +/op// object 算子元数据 + +# 执行层 +/vthread/ object {pc, status, error?} +/vthread//[addr0, 0] string 操作码 +/vthread//[addr0, -N] string 读参数 (N=1,2,...) +/vthread//[addr0, +N] string 写参数 (N=1,2,...) 
/vthread/<vtid>/[addr0,0]/[sub0,0]  string 子栈操作码
/vthread/<vtid>/<slot>              any    命名槽位 (局部变量)

# 系统
/sys/vtid_counter                   int    vthread ID 计数器
/sys/config                         object 全局配置
/sys/op-plat/<instance>             object op-plat 进程注册
/sys/heap-plat/<instance>           object heap-plat 进程注册
/sys/vm/<instance>                  object VM 实例注册

# 命令队列 (List)
cmd:op-<backend>:<instance>         list   op-plat 命令队列
cmd:heap-<backend>:<instance>       list   heap-plat 命令队列
done:<vtid>                         list   完成通知队列
notify:vm                           list   VM 唤醒通知

# 锁
/lock/<resource>                    string 排他锁

# 堆变量 (隐式)
/<非保留路径>                        object tensor 元信息
KV 空间要求 + +元程要求 KV 空间提供以下能力。实现方可使用任何满足这些要求的存储系统。 + +### 2.1 基本操作 + +| 操作 | 语义 | +|------|------| +| `GET(key)` | 读取 key 对应的 value | +| `SET(key, value)` | 写入 value,覆盖或新建 | +| `DELETE(key)` | 删除 key 及其所有子 key(递归) | +| `EXISTS(key)` | 判断 key 是否存在 | + +### 2.2 通知队列 + +KV 空间必须提供类似消息队列的通知机制: + +| 操作 | 语义 | +|------|------| +| `PUSH(list_key, value)` | 向队列尾部追加一条消息 | +| `POP(list_key, timeout)` | 阻塞地从队列头部取出一条消息。超时返回空 | +| `LEN(list_key)` | 返回队列当前长度 | + +队列是 FIFO 的。支持多个消费者竞争 POP(每条消息仅被一个消费者获取)。 + +### 2.3 锁 + +| 操作 | 语义 | +|------|------| +| `LOCK(key, holder, ttl)` | 尝试获取排他锁。若已被持有则失败。支持 TTL 自动释放 | +| `UNLOCK(key, holder)` | 释放锁 | + +### 2.4 值类型 + +KV 空间必须能直接存储以下基础类型(无需外部存储): + +- 整数 (int32, int64) +- 浮点数 (float32, float64) +- 布尔值 +- 短字符串 (建议 < 1KB) +- 结构化数据 (JSON 或等价格式) + +对于超大二进制数据(如 tensor 的实际数据),KV 空间仅存**引用**(物理地址),数据本身存于外部存储系统。 + +## 3. 路径空间约定 + +### 3.1 保留路径 + +以下路径前缀由元程运行时保留: + +| 路径 | 用途 | +|------|------| +| `/src/func/` | 函数源码 (平台无关的 dxlang 文本) | +| `/op//func/` | 后端编译产物 (融合、拆分、设备标注后的 func) | +| `/vthread/` | 所有 vthread 的执行状态 | +| `/sys/` | 系统信息(op-plat、heap-plat 注册等) | + +### 3.2 堆的隐式命名空间 + +除保留路径外,KV 空间中的所有其他路径均为堆路径。 +堆路径的具体命名空间划分由上层(如 pysdk)自行管理。 + +``` +例: + /models/bert/weights ← 堆 + /data/cifar10/train ← 堆 + /checkpoints/step1000 ← 堆 + /src/func/forward ← 函数源码 (保留) + /op/op-cuda/func/forward ← 编译产物 (保留) + /vthread/1/... ← 栈 (保留) +``` + +### 3.3 相对路径 + +vthread 栈内使用 `./` 前缀表示相对路径,解析为 `/vthread//` 下的平级命名槽位。 + +``` +./mm → /vthread//mm (命名槽位, 与 [0,0] 等指令坐标平级) +./bias → /vthread//bias +../a → /vthread//../a (不推荐,但技术上可行) +``` + +命名槽位和指令坐标 `[addr0, addr1]` 是 `/vthread//` 下的**平级子 key**: +- 命名槽位:人读的命名空间,存放 tensor 元信息 +- 指令坐标:机器执行的指令序列,VM 据此推进 PC + +## 4. 
函数定义 + +### 4.1 三层架构 + +函数有三种表示形式,分别服务于不同的角色: + +| 层 | 位置 | 角色 | 格式 | +|----|------|------|------| +| 源码层 | `/src/func/` | pysdk 写入,人类可读 | dxlang 文本 | +| 编译层 | `/op//func/` | 编译器产出,VM 读取 | 后端优化后的 dxlang 文本 | +| 执行层 | `/vthread//` | VM 翻译后,机器执行 | `[addr0, addr1]` 二维坐标 | + +``` +数据流: + pysdk → /src/func/forward (源码) + │ + ▼ 编译器 (融合、拆分、设备标注) + /op/op-cuda/func/forward (CUDA 编译产物) + /op/op-metal/func/forward (Metal 编译产物, 可能不同) + │ + ▼ VM CALL 时读取 + eager 翻译 + /vthread//[n,0]/[0,0]... (执行层) +``` + +### 4.2 路径结构 + +``` +/src/func/ → 函数签名 (dxlang 文本) +/src/func//0 → 第 0 条指令 +/src/func//1 → 第 1 条指令 +... + +/op/op-cuda/func/ → 编译后的函数签名 +/op/op-cuda/func//0 → 第 0 条指令 (可能已融合/拆分) +... +``` + +### 4.3 函数签名 + +`/src/func/` 的值定义函数的类型签名: + +``` +(func_name((ro_p1:type1, ro_p2:type2, ...) -> (w_p1:type3, w_p2:type4, ...)) +``` + +- 左侧 `()` 内为只读参数 +- 右侧 `()` 内为写入参数(返回值) + +### 4.4 指令 + +`/src/func//N` 的值是一条指令。 + +**指令格式(左读右写):** + +``` +opcode(read_param_1, read_param_2, ...) -> write_param_1, write_param_2 +``` + +- `->` 左侧为读取的输入 +- `->` 右侧为写入的输出 +- 参数可以是 `./` 相对路径(局部变量)、绝对堆路径、或立即数 + +**控制流指令:** + +``` +if(cond) -> ← 分支入口 + /src/func//N/true/0 ← true 分支第一条指令 + /src/func//N/false/0 ← false 分支第一条指令 + +for(iterator) -> ← 循环入口 + /src/func//N/body/0 ← 循环体第一条指令 +``` + +### 4.5 编译层 + +编译器读取 `/src/func/` 的源码,产出后端专属的编译产物到 `/op//func/`: + +``` +/src/func/gemm/0 = matmul(A, B) -> ./Y +/src/func/gemm/1 = mul(./Y, alpha) -> ./Y +/src/func/gemm/2 = mul(C, beta) -> ./C +/src/func/gemm/3 = add(./Y, ./C) -> ./Y + + ↓ 编译器: 算子融合 + +/op/op-cuda/func/gemm/0 = fused_matmul_mul_mul_add(A, B, alpha, C, beta) -> ./Y +/op/op-metal/func/gemm/0 = matmul(A, B) -> ./tmp1 +/op/op-metal/func/gemm/1 = mul(./tmp1, alpha) -> ./tmp2 +/op/op-metal/func/gemm/2 = mul(C, beta) -> ./tmp3 +/op/op-metal/func/gemm/3 = add(./tmp2, ./tmp3) -> ./Y +``` + +不同后端的编译产物可以不同(CUDA 融合了 4→1,Metal 保持 4 条)。 +VM 在 CALL 时读取对应后端的 `/op//func/`,而非 `/src/func/`。 + +### 4.6 示例 + +``` +/src/func/gemm = (gemm(A:tensor, B:tensor, alpha:f32, 
beta:f32, C:tensor) -> (Y:tensor)) + +/src/func/gemm/0 = matmul(A, B) -> ./Y +/src/func/gemm/1 = mul(./Y, alpha) -> ./Y +/src/func/gemm/2 = mul(C, beta) -> ./C +/src/func/gemm/3 = add(./Y, ./C) -> ./Y +``` + +### 4.7 源码层与执行层 + +源码层(`/src/func/`、`/op//func/`)与执行层(`/vthread/`)的对比: + +| | 源码层 / 编译层 | 执行层 | +|---|---|---| +| 格式 | dxlang 文本:`matmul(A,B)->./Y` | 二维坐标:`[0,0]=matmul`, `[0,-1]=/models/A` | +| 目标 | 人读 / 编译器读写 | 机器高效寻址、零解析 | +| 翻译 | — | VM CALL 时 eager 翻译(§6.2) | + +### 4.8 op-plat 算子注册 + +op-plat **程序**(如 op-cuda)定义了一套算子能力。该程序的多个**进程实例** +(如 GPU0 上的 op-cuda、GPU1 上的 op-cuda)共享同一套算子注册。 + +**程序级路径(静态能力,所有实例共享):** + +``` +/op//list → ["matmul", "add", "relu", "fused_matmul_add_relu", ...] +/op// → 算子元数据 +/op//func/ → 该程序专属的编译产物 +``` + +**进程级路径(动态状态,每个实例独立):** + +``` +/sys/op-plat/ → {pid, device, status, load, started_at} +``` + +例:一台机器有 2 张 GPU,运行 2 个 op-cuda 实例: + +``` +程序级 (共享): + /op/op-cuda/list = ["matmul", "add", "relu", "fused_matmul_add_relu"] + /op/op-cuda/matmul = {"category":"matmul", "dtype":["f32","f16"], ...} + /op/op-cuda/func/forward/0 = fused_matmul_add_relu(A,B,b)->./out + +进程级 (独立): + /sys/op-plat/cuda:0 = {"device":"gpu0", "status":"running", "load":0.3} + /sys/op-plat/cuda:1 = {"device":"gpu1", "status":"running", "load":0.7} +``` + +**算子元数据:** + +``` +/op/op-cuda/matmul = { + "category": "matmul", + "dtype": ["f32", "f16", "bf16"], + "max_shape": [8192, 8192, 8192], + "fusion_group": "linear" +} + +/op/op-cuda/fused_matmul_add_relu = { + "category": "fused", + "dtype": ["f32", "f16"], + "replaces": ["matmul", "add", "relu"] +} +``` + +**编译器使用注册信息:** + +``` +融合决策: + 1. GET /op/op-cuda/list → 过滤 category="fused" 的算子 + 2. 对每个 fused 算子,读取 replaces 列表 + 3. 扫描 /src/func/ 指令序列,滑动窗口匹配 replaces 模式 + 4. 匹配成功 → 等价替换 → 写入 /op/op-cuda/func/ + +拆分决策: + 1. GET /op/op-cuda//max_shape → 单卡能力上限 + 2. 对比 tensor 实际 shape → 超出则拆分为多个子算子 + 3. 
子算子标注目标设备 → 写入 /op/op-cuda/func/ +``` + +**VM 使用注册信息:** + +``` +指令路由: + GET /op/op-cuda/list → 含 "matmul" + GET /op/op-metal/list → 不含 "matmul" + → VM 将 matmul 路由到 op-cuda 的某个空闲实例 + → 实例选择基于 /sys/op-plat/cuda:* 的 load 信息 +``` + +## 5. Vthread 执行模型 + +### 5.1 Vthread 路径结构 + +``` +/vthread/ → {"pc":"[2,0]", "status":"running"} ← vthread 自身 (含 PC 和状态) +/vthread//[0,0] ← 指令 #0 的操作码 +/vthread//[0,-1] ← 指令 #0 的读参数 #1 +/vthread//[0,-2] ← 指令 #0 的读参数 #2 +/vthread//[0, 1] ← 指令 #0 的写参数 #1 +/vthread//[1,0] ← 指令 #1 的操作码 +... +/vthread//a ← 命名槽位 a (与 [0,0] 平级,非嵌套) +/vthread//b ← 命名槽位 b +/vthread//mm ← 局部变量 mm (./mm 解析结果) +/vthread//[2,0]/[0,0] ← 子栈: 指令 #2 是 CALL,其子栈的指令 #0 +``` + +命名槽位 (`/vthread//a`) 与指令坐标 (`/vthread//[0,0]`) 是 `/vthread//` 下的**平级子 key**。 +它们互不嵌套:命名槽位供 tensor 读写,指令坐标供 VM 推进 PC。 + +### 5.2 二维寻址 + +vthread 的指令序列使用二维坐标 `[addr0, addr1]` 寻址: + +| addr1 | 含义 | 示例 | +|-------|------|------| +| `0` | 操作码 | `call`, `add`, `matmul`, `newtensor` | +| `-1, -2, ...` | 读取参数 (左值) | func 名、输入 tensor、立即数 | +| `1, 2, ...` | 写入参数 (右值) | 输出 tensor、返回值 | + +addr0 是序列维,表示指令在栈帧内的顺序位置。 + +### 5.3 命名槽位 + +命名槽位是 `/vthread//` 下的平级子 key(如 `/vthread//a`), +与指令坐标 `[addr0, addr1]` 互不嵌套,用于存放局部变量的值。 + +基础类型(int, float, bool, 短 string)的值直接存储在槽位中。 +Tensor 类型的值存储 tensor 元信息(dtype, shape, 物理地址)。 + +命名槽位在 `/vthread/` 下的路径名与 dxlang 源码中的变量名一致,保证可调试性: +``` +源码: matmul(A, B) -> ./mm +执行: /vthread/1/[0,0] = matmul, /vthread/1/[0,-1] = /models/A, ... + /vthread/1/mm = {dtype:"f32", shape:[1024,256], address:{...}} +``` + +### 5.4 程序计数器 (PC) + +PC 是 `/vthread/` 的 value 中的 `pc` 字段,指向当前待执行指令的坐标。 + +``` +/vthread/1 = {"pc": "[3,0]", "status": "running"} +``` + +`pc` 的值如 `[3,0]` 表示当前位于栈帧指令序列的第 3 条指令。 + +### 5.5 Vthread 状态 + +Vthread 状态是 `/vthread/` 的 value 中的 `status` 字段: + +``` +/vthread/1 = {"pc": "[3,0]", "status": "wait"} +``` + +| 状态 | 含义 | +|------|------| +| `init` | 已创建,待 VM 拾取 | +| `running` | VM 正在调度执行 | +| `wait` | 等待异步操作(op-plat 完成 / 锁释放) | +| `error` | 执行出错 | +| `done` | 执行完毕,栈待 GC | + +## 6. 
CALL 与子栈 + +### 6.1 CALL 指令 + +当 vthread 执行到 `call` 指令时: + +``` +/vthread//[n, 0] = call +/vthread//[n,-1] = ← 被调用的函数名 +/vthread//[n,-2] = ← 只读参数 +/vthread//[n,-3] = +... +/vthread//[n, 1] = ← 返回值绑定的槽位 +``` + +### 6.2 子栈创建 — VM 翻译 (eager) + +VM 在 CALL 时**一次性**将 `/op//func//` 的编译层指令翻译为执行层格式, +复制到 vthread 的子栈。 + +**翻译是 eager 的**(而非 lazy):所有指令在 CALL 时完成翻译,执行时零解析开销。 + +VM 读取的是编译层(`/op//func/`),而非源码层(`/src/func/`)。 +编译层已经是编译器优化后的产物(融合、拆分、设备标注已完成), +VM 只需做形参替换和坐标展开。 + +``` +翻译前 (/op/op-cuda/func/ 编译层): 翻译后 (/vthread/ 执行层, 子栈): +──────────────────────────────── ────────────────────────────── +/op/op-cuda/func/gemm = (签名) /vthread/1/[n,0]/ ← 子栈根 + +/op/op-cuda/func/gemm/0 /vthread/1/[n,0]/[0,0] = matmul + = matmul(A, B) -> ./Y /vthread/1/[n,0]/[0,-1] = /models/A + /vthread/1/[n,0]/[0,-2] = /models/B + /vthread/1/[n,0]/[0, 1] = ./Y + +/op/op-cuda/func/gemm/1 /vthread/1/[n,0]/[1,0] = mul + = mul(./Y, alpha) -> ./Y /vthread/1/[n,0]/[1,-1] = ./Y + /vthread/1/[n,0]/[1,-2] = 1.0 + /vthread/1/[n,0]/[1, 1] = ./Y +``` + +**翻译步骤:** + +``` +1. 从 CALL 指令的读参数中获取 backend 和 func_name +2. GET /op//func/ → 函数签名,获得形参列表 +3. 从 CALL 指令的读参数中提取实参: [n,-2]=/models/A, [n,-3]=/models/B, ... +4. 建立形参→实参映射: {A: /models/A, B: /models/B, alpha: 1.0, ...} +5. 批量 MGET /op//func//0, /op/.../1, ... (一次 Redis 往返) +6. 对每条编译层指令: + a) 解析 dxlang 字符串: opcode + 读参数列表 + 写参数列表 + b) 形参替换: 将形参名替换为实参值 + c) 展开为执行层 key: [i,0]=opcode, [i,-1]=param1, [i,1]=out1, ... + d) 批量 SET 到子栈路径 /vthread//[n,0]/ +7. VM 设置 PC 指向子栈第一条指令: SET pc = "[n,0]/[0,0]" +``` + +**复制规则:** + +- 编译层内部的 `./` 相对路径参数,保持 `./` 形式不变 +- 编译层引用外部堆变量的形参,替换为调用者传入的实参 +- 立即数形参替换为具体数值 +- 命名槽位(如 `./mm`)的 key 在 `/vthread//` 下创建(平级),不受子栈嵌套影响 + +### 6.3 嵌套调用 + +func A CALL func B,func B CALL func C,形成多层嵌套: + +``` +/vthread//[0,0]/ ← A 调用 B 的子栈 +/vthread//[0,0]/[3,0]/ ← B 调用 C 的子栈 (假设 B 的第3条指令是 call) +``` + +最大嵌套深度由实现定义(建议 ≥ 32)。 + +### 6.4 RETURN 语义 + +当 func 执行完毕(到达指令序列末尾,或遇到 `return` 指令): + +1. 返回值写入调用者 CALL 指令指定的写参数槽位 +2. 子栈的 KV 路径被 DELETE(递归删除所有子 key) +3. 
PC 恢复到调用者 CALL 指令的下一条指令(addr0 + 1) + +## 7. 异步执行模型 + +元程的执行是全异步的。VM 不直接执行计算,而是将计算指令分发给 op-plat。 + +### 7.1 执行循环 + +``` +VM 循环: + 1. GET /vthread/ → {pc: "[n,0]", status: "running", ...} + 2. 读取 /vthread//[n,0] → opcode + 3. 判断指令类型: + a) 张量计算指令 (add, matmul, relu, ...): + → 将指令 + 参数打包,PUSH 到目标 op-plat 的命令队列 + → SET /vthread/ = {pc: "[n,0]", status: "wait", ...} + → BLPOP vthread 的完成通知队列 + → op-plat 计算完毕,LPUSH 完成事件 + → VM 醒来,SET /vthread/ = {pc: "[n+1,0]", status: "running", ...} + + b) 控制流指令 (call, if, for, return): + → VM 直接处理 (无需 op-plat) + → call: 读取 /op//func/ (编译层), eager 翻译到子栈 + → 更新 /vthread/ 的 pc 字段 + + c) 生命周期指令 (newtensor, deltensor): + → PUSH 到 heap-plat 的命令队列 + → 轻量操作,通常同步完成 +``` + +### 7.2 异步通知机制 + +``` +VM 与 op-plat 之间的通信: + + VM op-plat 命令队列 op-plat + ── ────────────── ─────── + PUSH cmd: ───→ [cmd1, cmd2, cmd3] ───→ RPOP 消费 + │ + GPU 计算 + │ + BLPOP done: ←─── [done_event] ←────────────── LPUSH 完成事件 + │ + 醒来,PC++ +``` + +### 7.3 批量发射 + +VM 可分析指令序列,将连续的多条无依赖指令批量 PUSH 到 op-plat 队列,减少往返。 + +``` +无依赖序列: [add(a,b)->c, mul(d,e)->f, relu(g)->h] + → 三条指令的输入互不依赖对方的输出 + → VM 一次 PUSH 三条,op-plat 可并行或流水线执行 + → 三条全部完成后,一次 LPUSH 通知 +``` + +## 8. 
op-plat 抽象契约 + +op-plat 是执行张量计算指令的被动进程。 + +### 8.1 必须实现 + +| 能力 | 说明 | +|------|------| +| 指令消费 | 从指定的通知队列 RPOP 消费指令 | +| 张量计算 | 执行至少一种张量运算(如 elementwise、matmul、reduce) | +| 完成通知 | 计算完成后 LPUSH 完成事件到指定队列 | +| Tensor 访问 | 根据 tensor 元信息中的物理地址,访问外部存储中的实际数据 | + +### 8.2 指令格式 + +``` +op-plat 消费的指令包含: + { + "opcode": "add", ← 操作码 + "vtid": "1", ← 所属 vthread + "pc": "[3,0]", ← 对应 vthread 中的指令坐标 + "inputs": [ ← 输入 tensor 元信息 + {"key": "/vthread/1/a", "dtype": "f32", "shape": [1024], + "address": {"node": "n1", "device": "gpu0", "shm": "/deepx_t_xxx"}} + ], + "outputs": [ ← 输出 tensor 元信息 + {"key": "/vthread/1/c", "dtype": "f32", "shape": [1024], + "address": {"node": "n1", "device": "gpu0", "shm": "/deepx_t_yyy"}} + ] + } +``` + +### 8.3 完成通知格式 + +``` +op-plat 计算完成后: + LPUSH done: { + "pc": "[3,0]", + "status": "ok", + "outputs_updated": ["/vthread/1/c"] + } +``` + +## 9. heap-plat 抽象契约 + +heap-plat 管理 tensor 的生命周期。 + +### 9.1 必须实现 + +| 能力 | 说明 | +|------|------| +| Tensor 创建 | 分配外部存储空间,返回物理地址 | +| Tensor 删除 | 释放外部存储空间 | +| Tensor 克隆 | 在指定设备上创建 tensor 副本 | +| 元信息查询 | 根据 key 返回 tensor 的 dtype、shape、物理地址 | + +### 9.2 指令格式 + +``` +heap-plat 消费的指令: + newtensor: {op: "newtensor", key: "/models/weights", dtype: "f32", shape: [1024,512], device: "gpu0"} + deltensor: {op: "deltensor", key: "/models/weights"} + clone: {op: "clone", src: "/models/weights", dst: "/models/weights_copy", device: "gpu1"} +``` + +## 10. 
生命周期 + +### 10.1 Vthread 生命周期 + +``` +CREATE: + pysdk 或编译器在 /vthread/ 下创建新的 vtid 子树 + SET /vthread/ = {"pc": "[0,0]", "status": "init"} + 写入入口指令 (call main) + +EXECUTE: + VM 拾取 status="init" 的 vthread + 进入执行循环 (§7.1) + SET /vthread/ = {..., "status": "running"} + +WAIT: + 张量计算指令发射后 + SET /vthread/ = {..., "status": "wait"} + VM 阻塞在完成通知队列上 + +ERROR: + 指令执行失败 + SET /vthread/ = {..., "status": "error", "error": "..."} + +DONE: + vthread 执行完毕 + SET /vthread/ = {..., "status": "done"} + VM 清理 /vthread// 子树 (GC) +``` + +### 10.2 堆变量生命周期 + +``` +创建: newtensor → heap-plat 分配外部存储 → SET 元信息到堆路径 +使用: vthread 通过堆路径引用 tensor +删除: deltensor → heap-plat 释放外部存储 → DELETE 堆路径 + +引用计数 (实现可选): + 多个 vthread 可能引用同一堆 tensor + 实现方可使用引用计数管理,refcount=0 时自动回收 +``` + +--- + +# 第二部分:DeepX 实现(待起草) + +> 本部分将定义元程模型在 DeepX 中的具体实现,包括: +> - Redis 作为 KV 空间的具体 key 路径约定 +> - op-cuda / op-metal 的具体协议 +> - heap-cuda / heap-metal 的具体协议 +> - VM 进程的启动与调度实现 +> - pysdk 与编译器的接口 + +--- + +## 附录 A:与 OS 进程的对照 + +| OS 进程概念 | 元程对应 | +|------------|---------| +| 虚拟地址空间 | KV 空间 | +| 进程 | 一个 KV 空间实例 | +| 线程 | Vthread | +| 代码段 (.text) | /src/func/ (源码) + /op//func/ (编译) | +| 堆段 (.data/.bss) | 非保留路径 (堆变量) | +| 栈段 | /vthread// 子树 | +| 程序计数器 (PC) | /vthread/ 的 value.pc 字段 | +| 栈帧 | CALL 产生的子维度 | +| 系统调用 | heap-plat / op-plat 命令 | +| 文件描述符 | 堆 tensor 引用 | +| 互斥锁 | LOCK/UNLOCK | + +## 附录 B:与 x86 指令的对照 + +| x86 指令 | 元程等价 | +|----------|---------| +| `mov [addr], value` | `SET /heap/x = value` | +| `add eax, ebx` | `add(./a, ./b) -> ./c` | +| `push rax` | 写入命名槽位 | +| `call func` | `call` 指令 → 复制 func 到子栈 | +| `ret` | RETURN → 删除子栈 | +| `jmp label` | PC 跳转 | +| `cmp + jcc` | `if(cond) ->` | +| `int 0x80` | PUSH 到 op-plat 命令队列 | + +## 附录 C:术语表 + +| 术语 | 英文 | 定义 | +|------|------|------| +| 元程 | Metaproc | 一个 KV 空间实例,分布式计算的边界 | +| 元线程 | Vthread | 元程内的执行流 | +| 函数 | Func | /src/func/ 下的可复用代码单元 (编译器优化后→/op//func/) | +| 堆变量 | Heap Variable | 全局共享的长生命周期数据 | +| 栈帧 | Stack Frame | CALL 产生的子维度 | +| 二维寻址 | 2D Addressing | [addr0, addr1] 
指令坐标系统 | +| 计算平面 | op-plat | 执行张量计算的后端进程 | +| 堆平面 | heap-plat | 管理 tensor 生命周期的后端进程 | diff --git a/doc/op-plat/README.md b/doc/op-plat/README.md new file mode 100644 index 00000000..ce1f37c1 --- /dev/null +++ b/doc/op-plat/README.md @@ -0,0 +1,238 @@ +# op-plat 设计 + +> op-plat 是 DeepX 元程的 **计算平面**,负责执行张量运算。 +> 本文档定义 op-plat 的抽象契约和所有实现的共用规范。 + +## 1. 定位 + +在元程 5 核架构中,op-plat 负责: + +| 核心 | 角色 | +|------|------| +| VM | 将计算指令路由到 op-plat | +| op-plat | 被动消费指令,执行 GPU/CPU 张量计算 | +| Redis | 存储算子注册信息、tensor 元信息 | +| heap-plat | 提供 tensor 的物理存储 (shm) | + +``` +VM PUSH 计算指令 ──→ cmd:op-: + │ + op-plat 消费 + │ + 根据 key 从 Redis GET tensor 元信息 + 通过 shm_name 映射 GPU/CPU 指针 + 执行 GPU kernel + 更新输出 tensor 元信息 + │ +VM BLPOP done: ←── LPUSH 完成通知 +``` + +## 2. 抽象契约 + +任何 op-plat 实现必须满足以下契约。 + +### 2.1 能力矩阵 + +| 能力 | 必须 | 说明 | +|------|------|------| +| 指令消费 | 是 | RPOP/BLPOP 从命令队列消费指令 | +| 张量计算 | 是 | 至少实现一种张量运算 | +| 完成通知 | 是 | LPUSH 完成事件到 done: | +| Tensor 访问 | 是 | 根据元信息中 shm_name 映射物理地址 | +| 算子注册 | 是 | 启动时注册到 /op//list | +| 进程注册 | 是 | 启动时注册到 /sys/op-plat/ | +| 批量执行 | 否 | 可选,支持 batch 指令并行执行 | + +### 2.2 通信模型 + +``` +消费: RPOP / BLPOP cmd:op-: + ↓ +解析: 提取 opcode, vtid, pc, inputs, outputs, params + ↓ +获取: GET → tensor 元信息 → shm_open + mmap → GPU ptr + ↓ +执行: dispatch_kernel(opcode) → GPU kernel + ↓ +更新: SET → 更新 tensor 元信息 (version++, shape) + ↓ +通知: LPUSH done: {pc, status, outputs_updated} +``` + +### 2.3 指令格式 + +VM 发送到 `cmd:op-:` 的指令: + +```json +{ + "vtid": "1", + "pc": "[3,0]", + "opcode": "matmul", + "inputs": [ + { + "key": "/vthread/1/a", + "dtype": "f32", + "shape": [1024, 512], + "address": { + "node": "n1", + "device": "gpu0", + "type": "shm", + "shm_name": "/deepx_t_abc123", + "byte_size": 2097152 + } + } + ], + "outputs": [ + { + "key": "/vthread/1/c", + "dtype": "f32", + "shape": [1024, 256], + "address": { + "node": "n1", + "device": "gpu0", + "type": "shm", + "shm_name": "/deepx_t_def456", + "byte_size": 1048576 + } + } + ], + "params": { + "transpose_a": false, + 
"transpose_b": true + } +} +``` + +**批量指令 (可选):** +```json +{ + "batch": [ + {"pc": "[3,0]", "opcode": "add", "inputs": [...], "outputs": [...]}, + {"pc": "[4,0]", "opcode": "mul", "inputs": [...], "outputs": [...]} + ] +} +``` + +### 2.4 完成通知格式 + +```json +{ + "pc": "[3,0]", + "status": "ok", + "outputs_updated": [ + {"key": "/vthread/1/c", "new_shape": [1024, 256]} + ] +} +``` + +错误: +```json +{ + "pc": "[3,0]", + "status": "error", + "error": { + "code": "GPU_OOM", + "message": "out of memory: requested 2GB, available 1.5GB" + } +} +``` + +### 2.5 算子注册 (程序级) + +op-plat **程序**的算子能力是静态的,所有**进程实例**共享: + +``` +/op//list → JSON array 支持的算子列表 +/op// → JSON object 算子元数据 +/op//func// → dxlang text 编译产物 (编译器写入) +``` + +启动时注册——第一个实例负责写入 (SET NX),后续实例只读取: + +``` +SET /op/op-metal/list = ["add", "sub", "mul", "div", "relu", "sigmoid", ...] + +SET /op/op-metal/add = { + "category": "elementwise", + "dtype": ["f32", "f16", "i32"], + "inputs": 2, + "outputs": 1 +} + +SET /op/op-metal/matmul = { + "category": "matmul", + "dtype": ["f32", "f16", "bf16"], + "max_shape": [8192, 8192, 8192], + "fusion_group": "linear", + "inputs": 2, + "outputs": 1, + "params": ["transpose_a", "transpose_b"] +} +``` + +### 2.6 算子分类 + +| category | 算子示例 | 输入 | 输出 | +|----------|---------|------|------| +| elementwise | add, sub, mul, div | 2 | 1 | +| activation | relu, sigmoid, tanh, gelu | 1 | 1 | +| matmul | matmul | 2 | 1 | +| reduce | sum, mean, max | 1 | 1 | +| changeshape | reshape, transpose, concat, slice | N | 1 | +| init | zeros, ones, arange | 0 | 1 | +| fused | fused_matmul_add_relu | N | 1 | + +### 2.7 进程注册 + +``` +/sys/op-plat/metal:0 = { + "program": "op-metal", + "device": "gpu0", + "status": "running", + "load": 0.3, + "pid": , + "started_at": +} +``` + +## 3. 消费者循环 (所有实现共用) + +``` +1. 启动: + a) 设备初始化 (GPU context / Metal device) + b) 算子注册: SET /op//list + /op// + c) 进程注册: SET /sys/op-plat/ + +2. 
循环: + a) RPOP/BLPOP cmd:op-: + b) 解析 JSON → opcode, vtid, pc, inputs, outputs, params + c) 获取 tensor 指针: + for each input: + GET → tensor 元信息 + shm_open(shm_name) + mmap → CPU ptr + GPU context 包装 → GPU ptr (newBufferWithBytesNoCopy / cudaHostRegister) + d) dispatch_kernel(opcode, inputs, outputs, params) + e) 更新输出 tensor 元信息 (version++, shape) + f) LPUSH done: 完成通知 + +3. 退出: + a) DELETE /sys/op-plat/ +``` + +## 4. 各平台实现概览 + +| 实现 | 目录 | 状态 | GPU | +|------|------|------|-----| +| [op-metal](op-metal.md) | executor/op-metal/ | 待改造 (已有 1,325 行) | Metal (Apple Silicon) | +| [op-cuda](op-cuda.md) | executor/op-cuda/ | 已成熟 (47 文件) | CUDA (NVIDIA) | +| [op-cpu](op-cpu.md) | executor/op-cpu/ | 待开发 | 纯 CPU | + +## 5. 待确定问题 + +| 问题 | 状态 | +|------|------| +| 指令中发 GPU 指针 vs shm_name | VM 发完整 tensor 元信息,op-plat 自行映射 | +| 批量发射依赖分析 | 当前 VM 逐条发送 | +| op-plat 负载上报 | load 字段由 op-plat 自行更新到 /sys/ | +| 多 stream 并行 | Metal command queue / CUDA stream,待实现 | +| 路径缓存失效 | 当前每次重新 GET Redis | diff --git a/docs/benchmark/broadcast.md b/doc/op-plat/benchmark/broadcast.md similarity index 100% rename from docs/benchmark/broadcast.md rename to doc/op-plat/benchmark/broadcast.md diff --git a/docs/benchmark/matmul.md b/doc/op-plat/benchmark/matmul.md similarity index 100% rename from docs/benchmark/matmul.md rename to doc/op-plat/benchmark/matmul.md diff --git a/docs/benchmark/reduce.md b/doc/op-plat/benchmark/reduce.md similarity index 100% rename from docs/benchmark/reduce.md rename to doc/op-plat/benchmark/reduce.md diff --git a/docs/executor/executor.md b/doc/op-plat/contribute.md similarity index 100% rename from docs/executor/executor.md rename to doc/op-plat/contribute.md diff --git a/docs/executor/op-mem-ompsimd/contribute.md b/doc/op-plat/cpu/contribute.md similarity index 100% rename from docs/executor/op-mem-ompsimd/contribute.md rename to doc/op-plat/cpu/contribute.md diff --git a/docs/executor/op-mem-ompsimd/list.md b/doc/op-plat/cpu/op-list.md similarity index 100% rename from 
docs/executor/op-mem-ompsimd/list.md rename to doc/op-plat/cpu/op-list.md diff --git a/docs/executor/op-mem-ompsimd/range.md b/doc/op-plat/cpu/range.md similarity index 100% rename from docs/executor/op-mem-ompsimd/range.md rename to doc/op-plat/cpu/range.md diff --git a/docs/executor/op-mem-cuda/cublas/api.md b/doc/op-plat/cuda/cublas-api.md similarity index 100% rename from docs/executor/op-mem-cuda/cublas/api.md rename to doc/op-plat/cuda/cublas-api.md diff --git a/docs/executor/op-mem-cuda/cublaslt/api.md b/doc/op-plat/cuda/cublaslt-api.md similarity index 100% rename from docs/executor/op-mem-cuda/cublaslt/api.md rename to doc/op-plat/cuda/cublaslt-api.md diff --git a/docs/executor/op-mem-cuda/list.md b/doc/op-plat/cuda/op-list.md similarity index 100% rename from docs/executor/op-mem-cuda/list.md rename to doc/op-plat/cuda/op-list.md diff --git a/docs/executor/mix_precision.md b/doc/op-plat/mix-precision.md similarity index 100% rename from docs/executor/mix_precision.md rename to doc/op-plat/mix-precision.md diff --git a/doc/op-plat/op-cpu.md b/doc/op-plat/op-cpu.md new file mode 100644 index 00000000..83b1ae56 --- /dev/null +++ b/doc/op-plat/op-cpu.md @@ -0,0 +1,80 @@ +# op-cpu + +> 纯 CPU 计算的 op-plat 实现。待开发。作为最小化参考实现和测试基准。 + +## 1. 平台特性 + +| 特性 | 说明 | +|------|------| +| 设备 | CPU (无 GPU) | +| 后端 | C++ 循环 (可集成 BLAS) | +| 内存 | 系统虚拟内存 | +| 算子 | 基础: elementwise, matmul (BLAS), reduce | + +## 2. 设计要点 + +最简实现,不依赖任何 GPU API: + +``` +1. 消费指令: RPOP cmd:op-cpu:0 +2. 映射数据: shm_open → mmap → CPU ptr +3. 执行计算: C++ 循环 (可能调用 OpenBLAS) +4. 写回结果: 直接写入 mmap 区域 +5. 通知完成: LPUSH done: +``` + +**性能优化 (可选):** +- OpenMP 多线程并行 +- SIMD 向量化 (编译优化 -O3 -march=native) +- 大页内存 (MAP_HUGETLB) + +## 3. 
计划算子 + +| 类别 | 算子 | 实现方式 | +|------|------|---------| +| elementwise | add, sub, mul, div | 简单 C++ 循环 | +| activation | relu, sigmoid, tanh, gelu | C++ 数学库 | +| matmul | matmul | OpenBLAS SGEMM | +| reduce | sum, mean, max | C++ 循环 / BLAS | +| changeshape | reshape | 零拷贝 (共享 shm, 改 shape 元信息) | +| changeshape | transpose | 索引重排循环 | +| init | zeros, ones, arange | memset / 简单循环 | + +## 4. Tensor 访问 + +``` +GET /vthread/1/a → tensor 元信息 +shm_open(shm_name) + mmap → float* data_ptr +直接读写 data_ptr (CPU 零开销) +``` + +## 5. 算子注册 + +``` +/op/op-cpu/list = [ + "add", "sub", "mul", "div", + "relu", "sigmoid", "tanh", "gelu", + "matmul", + "sum", "mean", "max", + "reshape", "transpose", + "zeros", "ones", "arange" +] +``` + +## 6. 进程注册 + +``` +/sys/op-plat/cpu:0 = {"program":"op-cpu", "device":"cpu", "status":"running", "load":0.0, "pid":} +``` + +## 7. 依赖 + +| 依赖 | 说明 | +|------|------| +| hiredis | Redis 客户端 | +| OpenBLAS (可选) | 高性能矩阵乘法 | +| OpenMP (可选) | 多线程并行 | + +## 8. 开发量 + +~500 行 C++ (基础实现) + 整合 BLAS。 diff --git a/doc/op-plat/op-cuda.md b/doc/op-plat/op-cuda.md new file mode 100644 index 00000000..4b3f371c --- /dev/null +++ b/doc/op-plat/op-cuda.md @@ -0,0 +1,117 @@ +# op-cuda + +> Linux NVIDIA GPU 的 op-plat 实现。当前最成熟的 op-plat,47 个文件。 + +## 1. 平台特性 + +| 特性 | 说明 | +|------|------| +| GPU | NVIDIA (CUDA) | +| 后端 | CUDA C++ (nvcc) | +| 内存 | 独立显存 (VRAM) | +| 传输 | cudaMemcpy (CPU↔GPU) | +| 算子 | elementwise, changeshape, reduce, matmul, init, io 全覆盖 | + +## 2. 代码位置 + +``` +executor/op-cuda/ (47 文件, 已成熟) +├── elementwise/ add, sub, mul, div, relu, sigmoid, ... +├── changeshape/ reshape, transpose, concat, slice +├── reduce/ sum, mean, max, min +├── matmul/ matmul (cuBLAS) +├── init/ zeros, ones, arange +└── io/ 数据读写 + +依赖 common-cuda +``` + +## 3. 
算子清单 (全覆盖) + +| 类别 | 算子 | 实现方式 | +|------|------|---------| +| elementwise | add, sub, mul, div | CUDA kernel | +| activation | relu, sigmoid, tanh, gelu, silu | CUDA kernel | +| matmul | matmul | cuBLAS | +| reduce | sum, mean, max, min | CUDA kernel (parallel reduction) | +| changeshape | reshape, transpose, concat, slice | CUDA kernel / cudaMemcpy | +| init | zeros, ones, arange, constant | CUDA kernel | +| fused | fused_matmul_add_relu, fused_linear_norm | CUDA kernel (融合) | +| io | load, save | cudaMemcpy | + +## 4. 算子注册 + +``` +/op/op-cuda/list = [ + "add", "sub", "mul", "div", + "relu", "sigmoid", "tanh", "gelu", "silu", + "matmul", + "sum", "mean", "max", "min", + "reshape", "transpose", "concat", "slice", + "zeros", "ones", "arange", "constant", + "fused_matmul_add_relu", "fused_linear_norm", + "load", "save" +] + +/op/op-cuda/fused_matmul_add_relu = { + "category": "fused", + "dtype": ["f32", "f16", "bf16"], + "replaces": ["matmul", "add", "relu"] +} +``` + +## 5. VRAM 访问 + +与 Metal 统一内存不同,CUDA 有独立显存。tensor 元信息中需包含 VRAM 指针: + +```json +{ + "dtype": "f32", + "shape": [1024, 512], + "byte_size": 2097152, + "device": "gpu0", + "address": { + "type": "cuda", + "shm_name": "/deepx_t_abc123", + "vram_ptr": "0x7f1234000000", + "cuda_ctx": "0" + } +} +``` + +**VRAM 指针共享方式:** +heap-cuda 分配 VRAM 后,将 `vram_ptr` 写入 shm 头部固定 offset。 +op-cuda 通过 shm_open 读取该字段获得 VRAM 指针,无需重新 cudaMalloc。 + +``` +shm layout: + [0..7]: uint64 vram_ptr ← heap-cuda 写入 + [8..15]: uint64 cuda_context + [16..]: tensor 实际数据 (CPU 可见副本, 可选) +``` + +## 6. 多 GPU 场景 + +``` +/sys/op-plat/cuda:0 → gpu0 → cmd:op-cuda:0 +/sys/op-plat/cuda:1 → gpu1 → cmd:op-cuda:1 +``` + +每个进程实例绑定一张 GPU (cudaSetDevice)。跨 GPU 的数据通过编译器插入的 +clonetensor (heap-cuda P2P) 或 cudaMemcpyPeer 完成。 + +## 7. 
待改造 + +当前 op-cuda 代码已成熟,主要改造点: + +| 改造项 | 说明 | +|------|------| +| Redis 命令循环 | 替换当前通信方式为 Redis List 消费 | +| 算子注册 | 启动时 SET /op/op-cuda/list 和元数据 | +| 完成通知 | LPUSH done: | +| 进程注册 | SET /sys/op-plat/cuda:N | +| 批量执行 | batch 指令并行 (多 CUDA stream) | + +## 8. 开发量 + +主要是适配层 (~400 行 C++/CUDA C),不涉及新 kernel 开发。 diff --git a/doc/op-plat/op-metal.md b/doc/op-plat/op-metal.md new file mode 100644 index 00000000..3bde5c4e --- /dev/null +++ b/doc/op-plat/op-metal.md @@ -0,0 +1,212 @@ +# op-metal + +> macOS Metal GPU 的 op-plat 实现。Apple Silicon 上使用 Metal Shading Language。 +> 当前已有 1,325 行 C++/Metal 代码,需增加 Redis 命令循环和算子路由。 + +## 1. 平台特性 + +| 特性 | 说明 | +|------|------| +| GPU | Apple Silicon (M1/M2/M3/M4) | +| 后端 | Metal 3 (MSL - Metal Shading Language) | +| 内存 | 统一内存 (CPU/GPU 共享) | +| 算子 | 已有: elementwise, activation, init | +| 缺 | matmul, reduce, changeshape | + +## 2. 代码位置 + +``` +executor/op-metal/ +├── CMakeLists.txt +├── src/ +│ ├── client/main.mm 入口 (占位, 待改造为 Redis 循环) +│ ├── deepx/ +│ │ ├── metal_context.hpp/mm Metal 设备管理 (MTLDevice, MTLCommandQueue) +│ │ ├── mem/mem_metal.hpp 内存 (shm → MTLBuffer 包装) +│ │ ├── dtype_metal.hpp 数据类型映射 (f32↔MTLDataTypeFloat) +│ │ └── tensorfunc/ +│ │ ├── elementwise_miaobyte.hpp add/sub/mul/div +│ │ ├── elementwise_common.hpp relu/sigmoid/tanh/gelu +│ │ ├── init_miaobyte.hpp zeros/ones/arange +│ │ ├── tensorlife_miaobyte.hpp newtensor/deltensor/clone (历史代码) +│ │ └── metal_common.hpp Metal 工具 +│ └── test/shm/ 跨进程 shm 已验证 + +依赖 common-metal: + shm_tensor.h/.mm + metal_device.h/.mm +``` + +## 3. 
算子清单 + +### 已实现 + +| 类别 | 算子 | GPU kernel 文件 | +|------|------|----------------| +| elementwise | add, sub, mul, div | elementwise_miaobyte.hpp | +| activation | relu | elementwise_common.hpp | +| activation | sigmoid | elementwise_common.hpp | +| activation | tanh | elementwise_common.hpp | +| activation | gelu | elementwise_common.hpp | +| init | zeros, ones, arange | init_miaobyte.hpp | + +### 需开发 + +| 类别 | 算子 | 优先级 | 难度 | +|------|------|------|------| +| matmul | matmul | 高 | 中 — 推荐用 MPSMatrixMultiplication | +| reduce | sum, mean, max | 高 | 中 — Metal parallel reduction | +| changeshape | reshape | 中 | 低 — 零拷贝 (改 shape 元信息) | +| changeshape | transpose | 中 | 低 — Metal shader 索引重排 | +| changeshape | concat | 低 | 低 — memcpy 拼接 | +| changeshape | slice | 低 | 低 — 共享 shm + offset | +| activation | softmax | 中 | 中 — reduce + exp + div 组合 | +| norm | layernorm, rmsnorm | 低 | 中 | + +## 4. 算子注册 + +启动时 (SET NX): + +``` +/op/op-metal/list = [ + "add", "sub", "mul", "div", + "relu", "sigmoid", "tanh", "gelu", + "zeros", "ones", "arange", + "matmul", "sum", "mean", "max", + "reshape", "transpose", "concat", "slice" +] + +/op/op-metal/add = { + "category": "elementwise", + "dtype": ["f32", "f16", "i32"], + "inputs": 2, + "outputs": 1 +} + +/op/op-metal/matmul = { + "category": "matmul", + "dtype": ["f32", "f16"], + "max_shape": [8192, 8192, 8192], + "fusion_group": "linear", + "inputs": 2, + "outputs": 1, + "params": ["transpose_a", "transpose_b"] +} +``` + +## 5. Tensor 访问流程 + +``` +1. 从指令获取 input key: "/vthread/1/a" +2. GET /vthread/1/a → tensor 元信息 +3. shm_open(shm_name) + mmap → cpu_ptr +4. [device newBufferWithBytesNoCopy:cpu_ptr + length:byte_size + options:MTLResourceStorageModeShared + deallocator:nil] + → id gpu_buf +5. GPU kernel 读写 gpu_buf +``` + +Apple Silicon 统一内存下,步骤 4 是零拷贝的。Metal 直接使用 CPU 物理内存页。 + +## 6. 
算子调度 (dispatch_kernel) + +```cpp +// 伪代码 +json dispatch_kernel(string opcode, vector inputs, + vector outputs, json params) { + + if (opcode == "add") { + auto a = inputs[0].mtl_buffer(); + auto b = inputs[1].mtl_buffer(); + auto c = outputs[0].mtl_buffer(); + auto N = inputs[0].element_count(); + + auto encoder = command_buffer->computeCommandEncoder(); + encoder->setComputePipelineState(add_pipeline); + encoder->setBuffer(a, 0, 0); + encoder->setBuffer(b, 0, 1); + encoder->setBuffer(c, 0, 2); + encoder->dispatchThreads(MTL::Size(N, 1, 1), + MTL::Size(256, 1, 1)); + encoder->endEncoding(); + command_buffer->commit(); + command_buffer->waitUntilCompleted(); + } + // ... 其他算子 + + return {{"status", "ok"}}; +} +``` + +## 7. matmul 实现方案 + +推荐使用 MPS (Metal Performance Shaders): + +```objc +#import + +// 初始化 (启动时) +MPSMatrixMultiplication* matmul_kernel = [[MPSMatrixMultiplication alloc] + initWithDevice:device + resultRows:M resultColumns:N interiorColumns:K + alpha:1.0 beta:0.0]; + +// 每次调用 +id bufA = [device newBufferWithBytesNoCopy:a_ptr + length:M*K*sizeof(float) options:MTLResourceStorageModeShared deallocator:nil]; +id bufB = [device newBufferWithBytesNoCopy:b_ptr + length:K*N*sizeof(float) options:MTLResourceStorageModeShared deallocator:nil]; +id bufC = [device newBufferWithBytesNoCopy:c_ptr + length:M*N*sizeof(float) options:MTLResourceStorageModeShared deallocator:nil]; + +MPSMatrixDescriptor* descA = [MPSMatrixDescriptor matrixDescriptorWithRows:M + columns:K rowBytes:K*sizeof(float) dataType:MPSDataTypeFloat32]; +MPSMatrixDescriptor* descB = [MPSMatrixDescriptor matrixDescriptorWithRows:K + columns:N rowBytes:N*sizeof(float) dataType:MPSDataTypeFloat32]; +MPSMatrixDescriptor* descC = [MPSMatrixDescriptor matrixDescriptorWithRows:M + columns:N rowBytes:N*sizeof(float) dataType:MPSDataTypeFloat32]; + +MPSMatrix* matA = [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA]; +MPSMatrix* matB = [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB]; 
+MPSMatrix* matC = [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC]; + +[matmul_kernel encodeToCommandBuffer:commandBuffer + leftMatrix:matA rightMatrix:matB resultMatrix:matC]; +``` + +## 8. 进程注册 + +``` +/sys/op-plat/metal:0 = { + "program": "op-metal", + "device": "gpu0", + "status": "running", + "load": 0.0, + "pid": , + "started_at": +} +``` + +## 9. 编译与运行 + +```bash +brew install hiredis +cd executor/op-metal && mkdir build && cd build +cmake .. && make +./op_metal +``` + +## 10. 开发量 + +| 模块 | 新增代码 | +|------|---------| +| Redis 命令循环 | ~200 行 | +| 算子路由 dispatch | ~150 行 | +| Tensor 元信息映射 | ~100 行 | +| matmul (MPS) | ~150 行 | +| reduce kernel | ~100 行 | +| changeshape | ~50 行 | +| 算子注册 | ~50 行 | +| **合计** | **~800 行** | diff --git a/docs/executor/welcome.md b/doc/op-plat/welcome.md similarity index 100% rename from docs/executor/welcome.md rename to doc/op-plat/welcome.md diff --git a/doc/vm/README.md b/doc/vm/README.md new file mode 100644 index 00000000..879e0023 --- /dev/null +++ b/doc/vm/README.md @@ -0,0 +1,332 @@ +# VM 设计 + +> VM 是 DeepX 元程的 **解释执行引擎**,负责调度 vthread、翻译 CALL、路由指令。 +> VM 是 5 核架构中的核心调度者,连接 pysdk、op-plat、heap-plat。 + +**实现语言:Go** + +## 1. 定位 + +``` +pysdk ──写入──→ /src/func/ + /vthread/ + notify:vm + │ │ │ + ▼ ▼ ▼ + Redis (KV 空间) VM BLPOP + │ │ │ + ▼ ▼ ▼ + op-plat 消费 堆变量 VM 拾取 vthread → 执行 +``` + +| 核心 | VM 如何与之交互 | +|------|----------------| +| Redis | 全部状态存储、命令队列、通知阻塞 | +| pysdk | VM 拾取 pysdk 创建的 vthread,通过 notify:vm 唤醒 | +| op-plat | VM PUSH 计算指令到 cmd:op-*,BLPOP done:\ 等待 | +| heap-plat | VM PUSH 生命周期指令到 cmd:heap-* | + +## 2. 
项目结构 + +``` +executor/vm/ +├── go.mod / go.sum +├── cmd/vm/main.go 入口: Redis 连接, 注册 VM, 启动 worker pool +├── internal/ +│ ├── engine/engine.go 核心编排: RunWorker, Execute, dispatchControl +│ ├── state/state.go VThreadState + Redis 状态读写 +│ ├── picker/picker.go Vthread 原子拾取 (WATCH/MULTI/EXEC) +│ ├── dispatch/ +│ │ ├── dispatch.go 指令分发 (compute / lifecycle / if) +│ │ └── native.go 原生算子求值 (算术/比较/逻辑) +│ ├── translate/translate.go CALL eager 翻译 + RETURN 处理 + 签名解析 +│ ├── ir/ +│ │ ├── instruction.go 指令数据结构 + 解码 + ParseDxlang +│ │ └── native.go 原生算子注册表 +│ ├── route/router.go 算子路由 (负载感知选实例) +│ └── cache/cache.go 子栈本地缓存 (可选优化) +├── testdata/*.dx dxlang 测试函数 +├── testutil/ 测试工具 (LoadDxFile, RegisterFunc, CreateVThread) +├── engine_test.go 单元测试 +└── integration_test.go 集成测试 +``` + +依赖: `github.com/redis/go-redis/v9` + +包依赖关系 (无循环): +``` +engine → picker, state, dispatch, translate, ir, route +picker → state +dispatch → state, ir, route +translate → ir, route +state → (rdb only) +ir → (rdb only) +route → (rdb only) +``` + +## 3. 并发模型 + +单进程内启动 N 个 worker goroutine (默认 `GOMAXPROCS`)。 +每个 worker 独立循环:拾取 vthread → 执行到底 → 再拾取。 +多 worker 通过 Redis WATCH/MULTI/EXEC 自然竞争,无需中心调度器。 + +### 状态机 + +``` + start → 注册 /sys/vm/ + │ + ┌──────▼──────┐ + │ idle │◄──────────────────┐ + │ PickVthread │ │ + │ BLPOP wait │ │ + └──────┬──────┘ │ + │ status=init │ + ┌──────▼──────┐ │ + │ running │ │ + │ Execute │ │ + └──────┬──────┘ │ + │ │ + ┌──────┴──────┐ │ + │ │ │ + 计算指令 控制流/原生 │ + │ │ │ + PUSH → VM 直接处理 │ + op-plat (call/if/return) │ + │ + native eval │ + BLPOP │ │ + done │ │ + │ │ │ + ▼ ▼ │ + PC++ PC++ │ + │ │ │ + └──────┬──────┘ │ + │ vthread done │ + └─────────────────────────┘ +``` + +## 4. 核心执行循环 (engine.go) + +```go +func Execute(ctx, rdb, vtid) { + for { + s := state.Get(vtid) + if s.Status == "done" || "error" { return } + + inst := ir.Decode(vtid, pc) + switch { + case IsControlOp: dispatchControl(...) // call / return / if + case IsNativeOp: dispatch.Native(...) // + - * / == < && ! 
+ case IsLifecycleOp: dispatch.Lifecycle(...) // newtensor / deltensor + case IsFunctionCall: → convert to "call" internally + case IsComputeOp: dispatch.Compute(...) // → op-plat + } + } +} +``` + +### 函数调用检测 + +dx 源代码中调用另一函数无需 `call` 关键字,直接使用函数名: + +```dxlang +// caller.dx — 直接调用 callee +def caller(A:int, B:int) -> (C:int) { + callee(A, B) -> ./C // ✅ 无需 call() +} +``` + +引擎在 Execute 中自动检测:若 opcode 不是内置关键字(native/control/lifecycle), +则检查 `/src/func/` 或 `/op/*/func/` 是否存在 → 存在则视为函数调用, +内部转换为 `call(funcName, args...)` 格式执行。 + +## 5. Vthread 拾取 (picker.go) + +多 worker 并行竞争,Redis 事务保证每条 vthread 只被一个 worker 拾取: + +```go +func tryPick(vtid) bool { + rdb.Watch(func(tx) { + state := tx.Get(key) + if state.Status != "init" { return errSkip } + state.Status = "running" + tx.Set(key, state) + }) + return err == nil +} +``` + +- 无可用 vthread: BLPOP notify:vm (5s 超时) +- `errSkip` 哨兵: 非 init 状态 vthread 跳过 + +## 6. 指令解码 (ir/instruction.go) + +执行层坐标格式: `[addr0, addr1]`,addr1 负数为 reads,正数为 writes。 + +``` +/vthread//[3, 0] = "add" // opcode +/vthread//[3,-1] = "./a" // read[0] +/vthread//[3,-2] = "./b" // read[1] +/vthread//[3, 1] = "./c" // write[0] +``` + +子栈 PC 格式: `[parent_addr0,0]/[child_addr0,0]` + +**ParseDxlang** 支持三种格式: +- 前缀: `add(A, B) -> ./C` +- 中缀: `A + B -> ./C` +- C风格: `./C <- A + B` + +```go +type Instruction struct { + Opcode string // "+" | "call" | "return" | "matmul" | funcName ... + Reads []string + Writes []string + PC string +} +``` + +## 7. CALL Eager 翻译 (translate.go) + +CALL 时一次性将编译层 dxlang 翻译为执行层 `[i,j]` 坐标。后续逐条执行零解析开销。 + +### 翻译流程 + +``` +handleCall(funcName="add_test", args=["./a","./b"], out=["./c"]): + 1. 确定 backend (op-metal > op-cuda > op-cpu) + 2. 读取签名: GET /src/func/add_test → "def add_test(A:int, B:int) -> (C:int)" + 3. 解析形参: reads=["A","B"], writes=["C"] + 4. 建立绑定: {A:"./a", B:"./b", C:"./c"} + 5. MGET 函数体 (按数字后缀排序, 保证指令顺序): + /src/func/add_test/0 = "add(A, B) -> ./C" + 6. 
逐条翻译 → Pipeline SET: + /vthread/1/[0,0]/[0,0] = "add" + /vthread/1/[0,0]/[0,-1] = "./a" (A → 实参替换) + /vthread/1/[0,0]/[0,-2] = "./b" + /vthread/1/[0,0]/[0,1] = "./C" (输出槽位, ./ 保持) + 7. 追加隐式 return: return(./C) + 8. Pipeline Exec (1 次 Redis 往返) +``` + +### 隐式 RETURN + +每个函数体末尾自动追加 `return(./)` 指令,将最后一个输出形参的值回传给父栈。 + +### mgetAll 排序 + +`KEYS` 返回顺序不确定,使用 `strconv.Atoi` 按数字后缀排序后再 MGET。 + +## 8. RETURN 处理 + +``` +handleReturn(pc="[0,0]/[1,0]"): + 1. parentPC = "[0,0]" + 2. 当前指令 reads[0]="./C" → 解析为实际值: + GET /vthread//C → "15" + 3. 父 CALL 指令 writes[0]="./c" → 写入: + SET /vthread//c = "15" + 4. 删除子栈: KEYS /vthread//[0,0]/* → DEL ... + 5. 恢复 PC = NextPC("[0,0]") = "[1,0]" +``` + +## 9. 参数约定 + +### 参数 key 命名空间 + +``` +栈内变量 (vthread-local): + ./a, ./b, ./tmp 相对路径 → /vthread//a + A, B, X 普通变量名 → 由 replaceParams 替换为实参 + +堆对象 (全局): + /models/W, /heap/... 绝对路径 → 全局堆地址 + /data/weights.bin 跨 vthread 共享 +``` + +**规则:** +- `./` 开头的 key 是栈内变量,解析时补全为 `/vthread//` +- 普通变量名 (无 `./` 前缀) 是形参占位符,翻译时替换为调用者传入的实参 +- `/` 开头的 key 是堆对象全局路径,不做栈内补全 + +### 三类参数解析 + +| 参数格式 | 解析方式 | 示例 | +|---------|---------|------| +| `./name` 相对路径 | → `/vthread//name` | ./mm → /vthread/1/mm | +| 普通变量名 | → replaceParams 替换为实参 | A → ./a | +| `/heap/...` 绝对路径 | 直接使用 | /models/W | +| 立即数 | 直接使用 | `1.0`, `true` | + +## 10. 原生算子求值 (dispatch/native.go) + +18 个符号算子直接在 VM 内求值,不经过 op-plat: + +| 类别 | 算子 | +|------|------| +| 算术 | `+` `-` `*` `/` `%` | +| 比较 | `==` `!=` `<` `>` `<=` `>=` | +| 逻辑 | `&&` `\|\|` `!` | +| 位运算 | `&` `\|` `^` `<<` `>>` | + +类型感知求值: bool → int → float → string,自动提升。 +- 算术: 两边 int → int 结果;否则 → float +- 除法: 始终 float +- 比较: 数值优先,回退字符串 + +## 11. 算子路由 (route/router.go) + +```go +Select(opcode) → "metal:0": + 1. 扫描 /op/*/list, 找到支持该 opcode 的程序 + 2. 扫描 /sys/op-plat/*, 选该程序下负载最低的实例 + 3. 返回 instance 标识符, e.g., "metal:0" +``` + +## 12. 
编译器无关的函数调用 + +dx 源文件 (.dx) 以 `def` 关键字定义函数: + +```dxlang +# callee.dx — 被调用函数 +def callee(X:int, Y:int) -> (Z:int) { + X + Y -> ./Z +} + +# caller.dx — 直接调用 callee,无需 call 关键字 +def caller(A:int, B:int) -> (C:int) { + callee(A, B) -> ./C +} +``` + +引擎在运行时自动识别 `callee` 为已注册函数,将其转换为 CALL 指令。 + +## 13. 错误处理 + +``` +op-plat 返回 error: + → SET /vthread/ = {status:"error", error:{...}} + +超时: + BLPOP done: 超时 (30s) + → SET /vthread/ = {status:"error", error:{code:"TIMEOUT"}} + +解码/执行异常: + → setError() 标记 error 状态 +``` + +## 14. 编译与运行 + +```bash +cd executor/vm +go build -o ./bin/vm ./cmd/vm/ +./bin/vm +# 多实例: VM_ID=1 ./bin/vm +``` + +## 15. 验证 + +```bash +# 单元测试 +go test ./... -v + +# 集成测试 (需 Redis + mock op-plat) +go test -tags=integration -v -run 'TestIntegration' +``` diff --git a/docs/scheduler/scheduler.md b/doc/vm/scheduler.md similarity index 100% rename from docs/scheduler/scheduler.md rename to doc/vm/scheduler.md diff --git a/docs/deepxIR/deepxir.md b/docs/deepxIR/deepxir.md deleted file mode 100644 index 58e68725..00000000 --- a/docs/deepxIR/deepxir.md +++ /dev/null @@ -1,136 +0,0 @@ -# DeepX IR(deepxir)规范 - -## 1. 类型系统 - -### 基础数据类型 -``` -type f16, f32, f64, bf16, bf8 // 浮点类型 -type i8, i16, i32, i64, u8 // 整数类型 -type bool // 布尔类型 -``` - -### 动态长度类型 -``` -list // list 可以和以上基础类型组合 -``` - -### 类型约束 -``` -f32|f64 // 支持两种/多种 类型之一 -``` - -### Tensor 类型模板 -``` -type tensor -``` -- shape 格式:dim1xdim2x...xdimN,或使用 `?` 表示动态维度。 最后一个x后的是精度。 -- 示例:`tensor<10x20xf32>`, `tensor` - -tensor 也可以没有 shape 和 dtype 的约束,例如: -``` -deepxir addscalar(A:tensor, b:i8|i16|i32|i64) -> (c:tensor) { ... } -``` -表示任意 shape、任意 dtype 的 tensor 都可作为参数。 - -### 动态维度变量 -- `?` 任意数字 -- `?1` 动态维度变量 1 -- `?2` 动态维度变量 2(用于表示同名变量处维度需一致) -- 示例:`tensor` - -## 2. IR 定义格式 - -语法示例: -``` -deepxir ir_name(ro_p1:type1, ro_param2:type2, ...) -> (w_p1:type3, w_p2:type4, ...) 
-{ - // 函数体:IR 操作序列 - operation_name(ro_p1, ro_p2) -> w_p1 - operation_name(ro_p2, ro_p2) -> w_p2 -} -``` -- `deepxir` 为关键字,也可使用 `function`、`func` 等。 -- 参数遵循“左读右写”规则(无返回值;通过写入参数实现输出)。 -- 参数类型支持:`tensor`、`list`、基础类型,以及基础类型的 list。 - -## 3. 设计思考 -DeepX IR 采用简洁的文本格式表示张量类型约束、运算定义与运算体,便于阅读与解析。 -deepx不是ssa,调用时,依然遵循左读右写的参数列表原则,右写的参数列表支持多个。 - -## 4. 具体示例 - -### 示例 1:融合 Linear + 归一化 -``` -deepxir fused_linear_norm( - A: tensor, - W: tensor, - b: tensor, - axis: i32, - keepdims: bool -) -> (out: tensor) { - newtensor(?1x?3, f32)->(mm) - matmul(A, W)-> (mm) - newtensor(?1x?3, f32)-> bias - add(mm, b)-> bias - deltensor(mm)-> mm - newtensor(?1, f32)-> mean - sum(bias, axis, keepdims)-> mean - newtensor(?1x?3, f32)-> centered - sub(bias, mean)-> centered - deltensor(bias)-> bias - deltensor(mean)-> mean - newtensor(?1x?3, f32)-> sq - mul(centered, centered)-> sq - deltensor(centered)-> centered - newtensor(?1, f32)-> var - sum(sq, axis, keepdims)-> var - deltensor(sq)-> sq - constant(1e-5)-> eps - newtensor(?1, f32)-> var_eps - add(var, eps)-> var_eps - deltensor(var)-> var - deltensor(eps)-> eps - newtensor(?1, f32)-> std - sqrt(var_eps)-> std - deltensor(var_eps)-> var_eps - div(std, std)-> std - deltensor(std)-> std - div(centered, std)-> out -} -``` - -下面给出一个完整的 `deepxir` 调用示例:在一个 IR 中先构造输入张量和辅助参数,然后调用 `fused_linear_norm`,输出 `out`。 - -``` -deepxir example_use_fused_linear_norm() -> (out: tensor<2x3xf32>) { - newtensor([2,4], f32)-> A - newtensor([4,3], f32)-> W - newtensor([3], f32)-> b - fused_linear_norm(A, W, b, 1, false) -> out -} -``` - -该示例展示了如何在 IR 中构造必要的张量/参数并调用 `fused_linear_norm`,其中 `out` 的类型为 `tensor<2x3xf32>`,与 `W` 的列数和 `A` 的行数对应。 - -### 示例 2:融合 Attention score + Softmax -``` -deepxir fused_attention_scores( - Q: tensor, - K: tensor, - axis: list, - keepdims: bool, - shape_scores: list, - shape_sum: list -) -> (out: tensor) { - newtensor(shape_scores, f32)-> scores_tmp - matmul(Q, K)-> scores_tmp - newtensor(shape_scores, f32)-> exp_tmp - exp(scores_tmp)-> 
exp_tmp - deltensor(scores_tmp)-> scores_tmp - newtensor(shape_sum, f32)-> sum_tmp - sum(exp_tmp, axis, keepdims)-> sum_tmp - div(exp_tmp, sum_tmp)-> out - deltensor(exp_tmp)-> exp_tmp - deltensor(sum_tmp)-> sum_tmp -} -``` \ No newline at end of file diff --git a/example/dxlang/builtin/call/add_test.dx b/example/dxlang/builtin/call/add_test.dx new file mode 100644 index 00000000..b7b8ad58 --- /dev/null +++ b/example/dxlang/builtin/call/add_test.dx @@ -0,0 +1,8 @@ +# add_test: basic element-wise addition +def add_test(A:tensor, B:tensor) -> (C:tensor) { + add(A, B) -> "./C" +} + +# This top-level call requires booted services (heap-plat + op-metal) +# with tensors at /data/a and /data/b already created. +# add_test(/data/a, /data/b) -> "/data/c" diff --git a/example/dxlang/builtin/call/callee.dx b/example/dxlang/builtin/call/callee.dx new file mode 100644 index 00000000..070793a6 --- /dev/null +++ b/example/dxlang/builtin/call/callee.dx @@ -0,0 +1,7 @@ +# callee: a simple function called by caller +def callee(X:int, Y:int) -> (Z:int) { + X + Y -> "./Z" +} + +# Top-level call: callee(2, 3) -> "./out" +callee(2, 3) -> "./out" diff --git a/example/dxlang/builtin/call/caller.dx b/example/dxlang/builtin/call/caller.dx new file mode 100644 index 00000000..4b84bfad --- /dev/null +++ b/example/dxlang/builtin/call/caller.dx @@ -0,0 +1,7 @@ +# caller: calls callee to add two numbers +def caller(A:int, B:int) -> (C:int) { + callee(A, B) -> "./C" +} + +# Top-level call: caller(2, 3) -> "./out" (calls callee internally) +caller(2, 3) -> "./out" diff --git a/example/dxlang/builtin/call/cstyle_call.dx b/example/dxlang/builtin/call/cstyle_call.dx new file mode 100644 index 00000000..af824f3d --- /dev/null +++ b/example/dxlang/builtin/call/cstyle_call.dx @@ -0,0 +1,7 @@ +# cstyle_call: C-style assignment with function call +def cstyle_call(A:int, B:int) -> (C:int) { + "./C" <- add(A, B) +} + +# Top-level call: cstyle_call(2, 3) -> "./out" +cstyle_call(2, 3) -> "./out" diff --git 
a/example/dxlang/builtin/call/deep3.dx b/example/dxlang/builtin/call/deep3.dx new file mode 100644 index 00000000..7f7084c5 --- /dev/null +++ b/example/dxlang/builtin/call/deep3.dx @@ -0,0 +1,7 @@ +# deep3 calls middle, middle calls leaf — 3 levels deep +def deep3(X:int) -> (Y:int) { + middle(X) -> "./Y" +} + +# Top-level call: deep3(5) -> "./out" (3-level call chain) +deep3(5) -> "./out" diff --git a/example/dxlang/builtin/call/diamond.dx b/example/dxlang/builtin/call/diamond.dx new file mode 100644 index 00000000..f7b42ff9 --- /dev/null +++ b/example/dxlang/builtin/call/diamond.dx @@ -0,0 +1,9 @@ +# diamond: splits into double+triple, then sums results +def diamond(A:int) -> (R:int) { + double(A) -> "./d" + triple(A) -> "./t" + "./d" + "./t" -> "./R" +} + +# Top-level call: diamond(5) -> "./out" (double+triple→sum) +diamond(5) -> "./out" diff --git a/example/dxlang/builtin/call/double.dx b/example/dxlang/builtin/call/double.dx new file mode 100644 index 00000000..e9874a9e --- /dev/null +++ b/example/dxlang/builtin/call/double.dx @@ -0,0 +1,7 @@ +# double: multiply by 2 +def double(X:int) -> (Y:int) { + X * 2 -> "./Y" +} + +# Top-level call: double(5) -> "./out" +double(5) -> "./out" diff --git a/example/dxlang/builtin/call/middle.dx b/example/dxlang/builtin/call/middle.dx new file mode 100644 index 00000000..63e66106 --- /dev/null +++ b/example/dxlang/builtin/call/middle.dx @@ -0,0 +1,8 @@ +# middle: calls leaf then adds 1 +def middle(X:int) -> (Y:int) { + leaf(X) -> "./tmp" + "./tmp" + 1 -> "./Y" +} + +# Top-level call: middle(5) -> "./out" (calls leaf→double internally) +middle(5) -> "./out" diff --git a/example/dxlang/builtin/call/triple.dx b/example/dxlang/builtin/call/triple.dx new file mode 100644 index 00000000..40f2ce00 --- /dev/null +++ b/example/dxlang/builtin/call/triple.dx @@ -0,0 +1,7 @@ +# triple: multiply by 3 +def triple(X:int) -> (Y:int) { + X * 3 -> "./Y" +} + +# Top-level call: triple(5) -> "./out" +triple(5) -> "./out" diff --git 
a/example/dxlang/builtin/native/arith/abs.dx b/example/dxlang/builtin/native/arith/abs.dx new file mode 100644 index 00000000..516070ba --- /dev/null +++ b/example/dxlang/builtin/native/arith/abs.dx @@ -0,0 +1,6 @@ +def native_abs(A:int) -> (C:int) { + abs(A) -> "./C" +} + +# Top-level call: native_abs(-5) -> "./out" +native_abs(-5) -> "./out" diff --git a/example/dxlang/builtin/native/arith/add.dx b/example/dxlang/builtin/native/arith/add.dx new file mode 100644 index 00000000..2f1acbc9 --- /dev/null +++ b/example/dxlang/builtin/native/arith/add.dx @@ -0,0 +1,6 @@ +def native_arith(A:int, B:int) -> (C:int) { + A + B -> "./C" +} + +# Top-level call: native_arith(2,3) -> "./out" +native_arith(2,3) -> "./out" diff --git a/example/dxlang/builtin/native/arith/div.dx b/example/dxlang/builtin/native/arith/div.dx new file mode 100644 index 00000000..524e92df --- /dev/null +++ b/example/dxlang/builtin/native/arith/div.dx @@ -0,0 +1,6 @@ +def native_div(A:int, B:int) -> (C:float) { + A / B -> "./C" +} + +# Top-level call: native_div(15,2) -> "./out" +native_div(15,2) -> "./out" diff --git a/example/dxlang/builtin/native/arith/exp.dx b/example/dxlang/builtin/native/arith/exp.dx new file mode 100644 index 00000000..b10b15d4 --- /dev/null +++ b/example/dxlang/builtin/native/arith/exp.dx @@ -0,0 +1,6 @@ +def native_exp(A:int) -> (C:float) { + exp(A) -> "./C" +} + +# Top-level call: native_exp(1) -> "./out" +native_exp(1) -> "./out" diff --git a/example/dxlang/builtin/native/arith/log.dx b/example/dxlang/builtin/native/arith/log.dx new file mode 100644 index 00000000..0298dc39 --- /dev/null +++ b/example/dxlang/builtin/native/arith/log.dx @@ -0,0 +1,6 @@ +def native_log(A:int) -> (C:float) { + log(A) -> "./C" +} + +# Top-level call: native_log(10) -> "./out" +native_log(10) -> "./out" diff --git a/example/dxlang/builtin/native/arith/max.dx b/example/dxlang/builtin/native/arith/max.dx new file mode 100644 index 00000000..df3a9a93 --- /dev/null +++ 
b/example/dxlang/builtin/native/arith/max.dx @@ -0,0 +1,6 @@ +def native_max(A:int, B:int) -> (C:int) { + max(A, B) -> "./C" +} + +# Top-level call: native_max(7,3) -> "./out" +native_max(7,3) -> "./out" diff --git a/example/dxlang/builtin/native/arith/min.dx b/example/dxlang/builtin/native/arith/min.dx new file mode 100644 index 00000000..91067c7e --- /dev/null +++ b/example/dxlang/builtin/native/arith/min.dx @@ -0,0 +1,6 @@ +def native_min(A:int, B:int) -> (C:int) { + min(A, B) -> "./C" +} + +# Top-level call: native_min(-2,5) -> "./out" +native_min(-2,5) -> "./out" diff --git a/example/dxlang/builtin/native/arith/mul.dx b/example/dxlang/builtin/native/arith/mul.dx new file mode 100644 index 00000000..ac930a7d --- /dev/null +++ b/example/dxlang/builtin/native/arith/mul.dx @@ -0,0 +1,6 @@ +def native_mul(A:int, B:int) -> (C:int) { + A * B -> "./C" +} + +# Top-level call: native_mul(6,7) -> "./out" +native_mul(6,7) -> "./out" diff --git a/example/dxlang/builtin/native/arith/neg.dx b/example/dxlang/builtin/native/arith/neg.dx new file mode 100644 index 00000000..531c5683 --- /dev/null +++ b/example/dxlang/builtin/native/arith/neg.dx @@ -0,0 +1,6 @@ +def native_neg(A:int) -> (C:int) { + neg(A) -> "./C" +} + +# Top-level call: native_neg(5) -> "./out" +native_neg(5) -> "./out" diff --git a/example/dxlang/builtin/native/arith/pow.dx b/example/dxlang/builtin/native/arith/pow.dx new file mode 100644 index 00000000..84480e74 --- /dev/null +++ b/example/dxlang/builtin/native/arith/pow.dx @@ -0,0 +1,6 @@ +def native_pow(A:int, B:int) -> (C:float) { + pow(A, B) -> "./C" +} + +# Top-level call: native_pow(2,3) -> "./out" +native_pow(2,3) -> "./out" diff --git a/example/dxlang/builtin/native/arith/sign.dx b/example/dxlang/builtin/native/arith/sign.dx new file mode 100644 index 00000000..c310f9bb --- /dev/null +++ b/example/dxlang/builtin/native/arith/sign.dx @@ -0,0 +1,6 @@ +def native_sign(A:int) -> (C:int) { + sign(A) -> "./C" +} + +# Top-level call: native_sign(-8) -> 
"./out" +native_sign(-8) -> "./out" diff --git a/example/dxlang/builtin/native/arith/sqrt.dx b/example/dxlang/builtin/native/arith/sqrt.dx new file mode 100644 index 00000000..8eecd3f6 --- /dev/null +++ b/example/dxlang/builtin/native/arith/sqrt.dx @@ -0,0 +1,6 @@ +def native_sqrt(A:int) -> (C:float) { + sqrt(A) -> "./C" +} + +# Top-level call: native_sqrt(16) -> "./out" +native_sqrt(16) -> "./out" diff --git a/example/dxlang/builtin/native/arith/sub.dx b/example/dxlang/builtin/native/arith/sub.dx new file mode 100644 index 00000000..c8fdef57 --- /dev/null +++ b/example/dxlang/builtin/native/arith/sub.dx @@ -0,0 +1,6 @@ +def native_sub(A:int, B:int) -> (C:int) { + A - B -> "./C" +} + +# Top-level call: native_sub(10,3) -> "./out" +native_sub(10,3) -> "./out" diff --git a/example/dxlang/builtin/native/cast/float.dx b/example/dxlang/builtin/native/cast/float.dx new file mode 100644 index 00000000..7667bdf5 --- /dev/null +++ b/example/dxlang/builtin/native/cast/float.dx @@ -0,0 +1,6 @@ +def native_float(A:int) -> (C:float) { + float(A) -> "./C" +} + +# Top-level call: native_float(42) -> "./out" +native_float(42) -> "./out" diff --git a/example/dxlang/builtin/native/cast/int.dx b/example/dxlang/builtin/native/cast/int.dx new file mode 100644 index 00000000..358d2ccb --- /dev/null +++ b/example/dxlang/builtin/native/cast/int.dx @@ -0,0 +1,6 @@ +def native_int(A:float) -> (C:int) { + int(A) -> "./C" +} + +# Top-level call: native_int(3.7) -> "./out" +native_int(3.7) -> "./out" diff --git a/example/dxlang/builtin/native/chain/chain.dx b/example/dxlang/builtin/native/chain/chain.dx new file mode 100644 index 00000000..5bc66948 --- /dev/null +++ b/example/dxlang/builtin/native/chain/chain.dx @@ -0,0 +1,7 @@ +def native_chain(A:int, B:int, C:int) -> (D:int) { + A + B -> "./tmp" + "./tmp" * C -> "./D" +} + +# Top-level call: native_chain(2, 3, 4) -> "./out" +native_chain(2, 3, 4) -> "./out" diff --git a/example/dxlang/builtin/native/compare/eq.dx 
b/example/dxlang/builtin/native/compare/eq.dx new file mode 100644 index 00000000..c7eb9d2d --- /dev/null +++ b/example/dxlang/builtin/native/compare/eq.dx @@ -0,0 +1,6 @@ +def native_eq(A:int, B:int) -> (C:bool) { + A == B -> "./C" +} + +# Top-level call: native_eq(5, 5) -> "./out" +native_eq(5, 5) -> "./out" diff --git a/example/dxlang/builtin/native/compare/lt.dx b/example/dxlang/builtin/native/compare/lt.dx new file mode 100644 index 00000000..5f81650d --- /dev/null +++ b/example/dxlang/builtin/native/compare/lt.dx @@ -0,0 +1,6 @@ +def native_lt(A:int, B:int) -> (C:bool) { + A < B -> "./C" +} + +# Top-level call: native_lt(3, 7) -> "./out" +native_lt(3, 7) -> "./out" diff --git a/example/dxlang/builtin/native/logic/and.dx b/example/dxlang/builtin/native/logic/and.dx new file mode 100644 index 00000000..1f49be56 --- /dev/null +++ b/example/dxlang/builtin/native/logic/and.dx @@ -0,0 +1,6 @@ +def native_and(A:bool, B:bool) -> (C:bool) { + A && B -> "./C" +} + +# Top-level call: native_and(true, false) -> "./out" +native_and(true, false) -> "./out" diff --git a/example/dxlang/builtin/native/logic/bool.dx b/example/dxlang/builtin/native/logic/bool.dx new file mode 100644 index 00000000..7b56de09 --- /dev/null +++ b/example/dxlang/builtin/native/logic/bool.dx @@ -0,0 +1,6 @@ +def native_bool(A:int) -> (C:bool) { + bool(A) -> "./C" +} + +# Top-level call: native_bool(1) -> "./out" +native_bool(1) -> "./out" diff --git a/example/dxlang/builtin/native/logic/not.dx b/example/dxlang/builtin/native/logic/not.dx new file mode 100644 index 00000000..b2b243fe --- /dev/null +++ b/example/dxlang/builtin/native/logic/not.dx @@ -0,0 +1,6 @@ +def native_not(A:bool) -> (C:bool) { + !A -> "./C" +} + +# Top-level call: native_not(true) -> "./out" +native_not(true) -> "./out" diff --git a/example/dxlang/call/leaf.dx b/example/dxlang/call/leaf.dx new file mode 100644 index 00000000..a0b7c927 --- /dev/null +++ b/example/dxlang/call/leaf.dx @@ -0,0 +1,4 @@ +# leaf: multiplies input by 2 
+def leaf(X:int) -> (Y:int) { + X * 2 -> "./Y" +} diff --git a/example/dxlang/io/print_tensor.dx b/example/dxlang/io/print_tensor.dx new file mode 100644 index 00000000..2e284d2f --- /dev/null +++ b/example/dxlang/io/print_tensor.dx @@ -0,0 +1,6 @@ +# print_tensor: 验证 print 算子功能 +# 创建 tensor 后打印其内容 +def print_tensor() -> ("/data/x") { + newtensor("f32", "[10]") -> "/data/x" + print("/data/x") +} diff --git a/example/dxlang/lifecycle/full.dx b/example/dxlang/lifecycle/full.dx new file mode 100644 index 00000000..9d0d0510 --- /dev/null +++ b/example/dxlang/lifecycle/full.dx @@ -0,0 +1,9 @@ +# lifecycle_full: create tensors, compute, then cleanup +def lifecycle_full() -> ("/data/c") { + newtensor("f32", "[4]") -> "/data/a" + newtensor("f32", "[4]") -> "/data/b" + newtensor("f32", "[4]") -> "/data/c" + add("/data/a", "/data/b") -> "/data/c" + deltensor("/data/a") + deltensor("/data/b") +} diff --git a/example/dxlang/native/arith/cstyle_add.dx b/example/dxlang/native/arith/cstyle_add.dx new file mode 100644 index 00000000..15a6f1f8 --- /dev/null +++ b/example/dxlang/native/arith/cstyle_add.dx @@ -0,0 +1,4 @@ +# cstyle_add: C-style assignment with infix operator +def cstyle_add(A:int, B:int) -> (C:int) { + "./C" <- A + B +} diff --git a/example/dxlang/tensor/call/tensor_pipeline.dx b/example/dxlang/tensor/call/tensor_pipeline.dx new file mode 100644 index 00000000..49c15e37 --- /dev/null +++ b/example/dxlang/tensor/call/tensor_pipeline.dx @@ -0,0 +1,40 @@ +# tensor_pipeline: 多函数调用链,每层函数内部创建 tensor、计算、清理 +# 验证深度 CALL + tensor lifecycle 正确性 +# 调用链: producer → stage1 → stage2 → consumer + +# ── 最内层: 创建 tensor 并做 elementwise 计算 ── +def producer() -> ("/data/p") { + newtensor("f32", "[64]") -> "/data/p" + newtensor("f32", "[64]") -> "/data/t" + zeros() -> "/data/p" + add("/data/p", "/data/p") -> "/data/t" + mul("/data/t", "/data/p") -> "/data/p" + deltensor("/data/t") +} + +# ── 中间层 1: 调用 producer 并对结果做 relu ── +def stage1() -> ("/data/s1") { + producer() -> "/data/s1" + 
newtensor("f32", "[64]") -> "/data/relu_out" + relu("/data/s1") -> "/data/relu_out" + deltensor("/data/s1") + # 将 relu_out 重命名给调用方 (依赖 return 写入的 key) +} + +# ── 中间层 2: 调用 stage1 并对结果做 exp ── +def stage2() -> ("/data/s2") { + stage1() -> "/data/s2" + newtensor("f32", "[64]") -> "/data/exp_out" + exp("/data/s2") -> "/data/exp_out" + deltensor("/data/s2") +} + +# ── 最外层: 调用 stage2 并打印结果,最后清理 ── +def consumer() -> () { + stage2() -> "/data/final" + print("/data/final") + deltensor("/data/final") +} + +# Top-level call: consumer() -> () (full pipeline: producer→stage1→stage2→consumer) +consumer() -> () diff --git a/example/dxlang/tensor/io/io_pipeline.dx b/example/dxlang/tensor/io/io_pipeline.dx new file mode 100644 index 00000000..4b836b35 --- /dev/null +++ b/example/dxlang/tensor/io/io_pipeline.dx @@ -0,0 +1,25 @@ +# io_pipeline: 完整的 IO 流水线 — 创建→计算→保存→加载→打印→清理 +# 验证 VM ↔ heap-plat ↔ op-metal 三方联调 +def io_pipeline() -> ("/data/final") { + newtensor("f32", "[64]") -> "/data/x" + newtensor("f32", "[64]") -> "/data/y" + newtensor("f32", "[64]") -> "/data/z" + newtensor("f32", "[64]") -> "/data/temp" + zeros() -> "/data/x" + zeros() -> "/data/y" + zeros() -> "/data/z" + add("/data/x", "/data/y") -> "/data/temp" + mul("/data/temp", "/data/z") -> "/data/z" + print("/data/z") + save("/data/z", "/tmp/io_pipeline_z") + newtensor("f32", "[64]") -> "/data/final" + load("/tmp/io_pipeline_z") -> "/data/final" + print("/data/final") + deltensor("/data/x") + deltensor("/data/y") + deltensor("/data/z") + deltensor("/data/temp") +} + +# Top-level call: io_pipeline() -> "/data/final" +io_pipeline() -> "/data/final" diff --git a/example/dxlang/tensor/io/save_load.dx b/example/dxlang/tensor/io/save_load.dx new file mode 100644 index 00000000..a7e49beb --- /dev/null +++ b/example/dxlang/tensor/io/save_load.dx @@ -0,0 +1,14 @@ +# save_load: 验证 save 和 load 算子完整流程 +# 1. 创建 tensor → 2. save 到磁盘 → 3. 创建新 tensor → 4. 
load 回内存 +def save_load() -> ("/data/loaded") { + newtensor("f32", "[10]") -> "/data/src" + zeros() -> "/data/src" + save("/data/src", "/tmp/test_save_load") + newtensor("f32", "[10]") -> "/data/loaded" + load("/tmp/test_save_load") -> "/data/loaded" + print("/data/loaded") + deltensor("/data/src") +} + +# Top-level call: save_load() -> "/data/loaded" +save_load() -> "/data/loaded" diff --git a/example/dxlang/tensor/lifecycle/batch_ops.dx b/example/dxlang/tensor/lifecycle/batch_ops.dx new file mode 100644 index 00000000..8d49bdd5 --- /dev/null +++ b/example/dxlang/tensor/lifecycle/batch_ops.dx @@ -0,0 +1,33 @@ +# batch_ops: 创建 5 个 tensor,多步 GPU 计算,最后全部清理 +# 联调 VM → heap-plat → op-metal 完整链路 +def batch_ops() -> ("/data/final") { + newtensor("f32", "[256]") -> "/data/a" + newtensor("f32", "[256]") -> "/data/b" + newtensor("f32", "[256]") -> "/data/c" + newtensor("f32", "[256]") -> "/data/d" + newtensor("f32", "[256]") -> "/data/final" + newtensor("f32", "[256]") -> "/data/temp" + + # 初始化输入 tensor + zeros() -> "/data/a" + zeros() -> "/data/b" + zeros() -> "/data/c" + zeros() -> "/data/d" + + # Step 1: temp = a + b + add("/data/a", "/data/b") -> "/data/temp" + # Step 2: final = temp * c (使用 mul,GPU elementwise) + mul("/data/temp", "/data/c") -> "/data/final" + # Step 3: final = final + d + add("/data/final", "/data/d") -> "/data/final" + + # 清理中间和输入 tensor + deltensor("/data/temp") + deltensor("/data/a") + deltensor("/data/b") + deltensor("/data/c") + deltensor("/data/d") +} + +# Top-level call: batch_ops() -> "/data/final" +batch_ops() -> "/data/final" diff --git a/example/dxlang/tensor/lifecycle/clone_and_use.dx b/example/dxlang/tensor/lifecycle/clone_and_use.dx new file mode 100644 index 00000000..9cf080c7 --- /dev/null +++ b/example/dxlang/tensor/lifecycle/clone_and_use.dx @@ -0,0 +1,25 @@ +# clone_and_use: 创建 tensor → clone → 分别计算 → 合并结果 → 清理 +# 验证 heap-plat clonetensor + op-metal compute 协同 +def clone_and_use() -> ("/data/merged") { + newtensor("f32", "[128]") -> 
"/data/original" + zeros() -> "/data/original" + clonetensor("/data/original", "/data/copy") + + newtensor("f32", "[128]") -> "/data/doubled" + newtensor("f32", "[128]") -> "/data/merged" + + # 对 original 做 add,对 copy 做 mul + add("/data/original", "/data/original") -> "/data/doubled" + mul("/data/copy", "/data/copy") -> "/data/copy" + + # 合并两个分支结果 + add("/data/doubled", "/data/copy") -> "/data/merged" + + # 只保留 merged,其他全部清理 + deltensor("/data/original") + deltensor("/data/copy") + deltensor("/data/doubled") +} + +# Top-level call: clone_and_use() -> "/data/merged" +clone_and_use() -> "/data/merged" diff --git a/example/dxlang/tensor/lifecycle/compute.dx b/example/dxlang/tensor/lifecycle/compute.dx new file mode 100644 index 00000000..168ef255 --- /dev/null +++ b/example/dxlang/tensor/lifecycle/compute.dx @@ -0,0 +1,14 @@ +# compute: create 2 tensors, init with zeros, compute sum into a 3rd, cleanup +def compute() -> ("/data/c") { + newtensor("f32", "[8]") -> "/data/a" + newtensor("f32", "[8]") -> "/data/b" + newtensor("f32", "[8]") -> "/data/c" + zeros() -> "/data/a" + zeros() -> "/data/b" + add("/data/a", "/data/b") -> "/data/c" + deltensor("/data/a") + deltensor("/data/b") +} + +# Top-level call: compute() -> "/data/c" +compute() -> "/data/c" diff --git a/example/dxlang/tensor/lifecycle/del.dx b/example/dxlang/tensor/lifecycle/del.dx new file mode 100644 index 00000000..95d7c4eb --- /dev/null +++ b/example/dxlang/tensor/lifecycle/del.dx @@ -0,0 +1,8 @@ +# lifecycle_del: create then delete a heap tensor +def lifecycle_del() -> () { + newtensor("f32", "[8]") -> "/data/tmp" + deltensor("/data/tmp") +} + +# Top-level call: lifecycle_del() -> () +lifecycle_del() -> () diff --git a/example/dxlang/tensor/lifecycle/newtensor.dx b/example/dxlang/tensor/lifecycle/newtensor.dx new file mode 100644 index 00000000..04560d97 --- /dev/null +++ b/example/dxlang/tensor/lifecycle/newtensor.dx @@ -0,0 +1,7 @@ +# lifecycle_newtensor: create a heap tensor and store its reference 
+def lifecycle_newtensor() -> ("/data/x") { + newtensor("f32", "[16]") -> "/data/x" +} + +# Top-level call: lifecycle_newtensor() -> "/data/x" +lifecycle_newtensor() -> "/data/x" diff --git a/example/dxlang/tensor/math/dist2.dx b/example/dxlang/tensor/math/dist2.dx new file mode 100644 index 00000000..14f0ba1f --- /dev/null +++ b/example/dxlang/tensor/math/dist2.dx @@ -0,0 +1,23 @@ +# dist2: 逐元素平方差 (a - b)^2,用于距离计算的元素部分 +# 全程 GPU tensor: sub → mul +def dist2() -> ("/data/dist2") { + newtensor("f32", "[256]") -> "/data/a" + newtensor("f32", "[256]") -> "/data/b" + newtensor("f32", "[256]") -> "/data/diff" + newtensor("f32", "[256]") -> "/data/dist2" + + zeros() -> "/data/a" + zeros() -> "/data/b" + + # diff = a - b + sub("/data/a", "/data/b") -> "/data/diff" + # dist2 = diff * diff + mul("/data/diff", "/data/diff") -> "/data/dist2" + + deltensor("/data/a") + deltensor("/data/b") + deltensor("/data/diff") +} + +# Top-level call: dist2() -> "/data/dist2" +dist2() -> "/data/dist2" diff --git a/example/dxlang/tensor/math/hadamard3.dx b/example/dxlang/tensor/math/hadamard3.dx new file mode 100644 index 00000000..54482a31 --- /dev/null +++ b/example/dxlang/tensor/math/hadamard3.dx @@ -0,0 +1,26 @@ +# hadamard3: 三元 Hadamard 积 a * b * c +# mul 是 binary GPU op,分两步完成 +def hadamard3() -> ("/data/result") { + newtensor("f32", "[128]") -> "/data/a" + newtensor("f32", "[128]") -> "/data/b" + newtensor("f32", "[128]") -> "/data/c" + newtensor("f32", "[128]") -> "/data/temp" + newtensor("f32", "[128]") -> "/data/result" + + zeros() -> "/data/a" + zeros() -> "/data/b" + zeros() -> "/data/c" + + # temp = a * b + mul("/data/a", "/data/b") -> "/data/temp" + # result = temp * c + mul("/data/temp", "/data/c") -> "/data/result" + + deltensor("/data/a") + deltensor("/data/b") + deltensor("/data/c") + deltensor("/data/temp") +} + +# Top-level call: hadamard3() -> "/data/result" +hadamard3() -> "/data/result" diff --git a/example/dxlang/tensor/math/max_abs.dx 
b/example/dxlang/tensor/math/max_abs.dx new file mode 100644 index 00000000..56046837 --- /dev/null +++ b/example/dxlang/tensor/math/max_abs.dx @@ -0,0 +1,24 @@ +# max_abs: 对两个 tensor 逐元素取 max(|a|, |b|) +# 流程: abs(a) → abs(b) → max(abs_a, abs_b) +def max_abs() -> ("/data/result") { + newtensor("f32", "[128]") -> "/data/a" + newtensor("f32", "[128]") -> "/data/b" + newtensor("f32", "[128]") -> "/data/abs_a" + newtensor("f32", "[128]") -> "/data/abs_b" + newtensor("f32", "[128]") -> "/data/result" + + zeros() -> "/data/a" + zeros() -> "/data/b" + + abs("/data/a") -> "/data/abs_a" + abs("/data/b") -> "/data/abs_b" + max("/data/abs_a", "/data/abs_b") -> "/data/result" + + deltensor("/data/a") + deltensor("/data/b") + deltensor("/data/abs_a") + deltensor("/data/abs_b") +} + +# Top-level call: max_abs() -> "/data/result" +max_abs() -> "/data/result" diff --git a/example/dxlang/tensor/mixed/native_and_gpu.dx b/example/dxlang/tensor/mixed/native_and_gpu.dx new file mode 100644 index 00000000..88693f51 --- /dev/null +++ b/example/dxlang/tensor/mixed/native_and_gpu.dx @@ -0,0 +1,26 @@ +# native_and_gpu: VM 原生标量计算与 GPU tensor 操作交叉执行 +# 验证 VM native eval → heap-plat → op-metal 混合调度正确性 +def native_and_gpu(N:int) -> ("/data/final") { + # Step 1: VM 原生计算确定 tensor 大小 + N * 4 -> "./size" + + # Step 2: 基于标量结果创建 tensor(此处简化为固定大小) + newtensor("f32", "[256]") -> "/data/x" + newtensor("f32", "[256]") -> "/data/y" + newtensor("f32", "[256]") -> "/data/final" + + zeros() -> "/data/x" + zeros() -> "/data/y" + + # Step 3: GPU tensor 计算 + add("/data/x", "/data/y") -> "/data/final" + + # Step 4: VM 原生比较 (验证标量结果还在) + N == N -> "./valid" + + deltensor("/data/x") + deltensor("/data/y") +} + +# Top-level call: native_and_gpu(8) -> "/data/final" +native_and_gpu(8) -> "/data/final" diff --git a/example/dxlang/tensor/nn/elemwise_long.dx b/example/dxlang/tensor/nn/elemwise_long.dx new file mode 100644 index 00000000..36eaae4e --- /dev/null +++ b/example/dxlang/tensor/nn/elemwise_long.dx @@ -0,0 +1,32 
@@ +# elemwise_long: 长链 GPU elementwise 操作 pipeline +# 验证 op-metal 连续多步 kernel dispatch 的稳定性 +# 流程: relu → exp → log → abs → neg → sqrt +def elemwise_long() -> ("/data/out") { + newtensor("f32", "[512]") -> "/data/in" + newtensor("f32", "[512]") -> "/data/s1" + newtensor("f32", "[512]") -> "/data/s2" + newtensor("f32", "[512]") -> "/data/s3" + newtensor("f32", "[512]") -> "/data/s4" + newtensor("f32", "[512]") -> "/data/s5" + newtensor("f32", "[512]") -> "/data/out" + + zeros() -> "/data/in" + + # Pipeline: in → s1 → s2 → s3 → s4 → s5 → out + relu("/data/in") -> "/data/s1" + exp("/data/s1") -> "/data/s2" + log("/data/s2") -> "/data/s3" + abs("/data/s3") -> "/data/s4" + neg("/data/s4") -> "/data/s5" + sqrt("/data/s5") -> "/data/out" + + deltensor("/data/in") + deltensor("/data/s1") + deltensor("/data/s2") + deltensor("/data/s3") + deltensor("/data/s4") + deltensor("/data/s5") +} + +# Top-level call: elemwise_long() -> "/data/out" +elemwise_long() -> "/data/out" diff --git a/example/dxlang/tensor/nn/mlp_small.dx b/example/dxlang/tensor/nn/mlp_small.dx new file mode 100644 index 00000000..d5b99be3 --- /dev/null +++ b/example/dxlang/tensor/nn/mlp_small.dx @@ -0,0 +1,46 @@ +# mlp_small: 模拟单隐藏层 MLP forward pass(无控制流,逐 op 展开) +# 联调 VM + heap-plat (newtensor/deltensor) + op-metal (add/mul/relu) +# 数学: hidden = relu(X * W + B), output = hidden * W2 + B2 +def mlp_small() -> ("/data/output") { + # ── 分配所有权重/偏置/中间结果 ── + newtensor("f32", "[64]") -> "/data/X" + newtensor("f32", "[64]") -> "/data/W" + newtensor("f32", "[64]") -> "/data/B" + newtensor("f32", "[64]") -> "/data/W2" + newtensor("f32", "[64]") -> "/data/B2" + newtensor("f32", "[64]") -> "/data/t1" + newtensor("f32", "[64]") -> "/data/t2" + newtensor("f32", "[64]") -> "/data/hidden" + newtensor("f32", "[64]") -> "/data/t3" + newtensor("f32", "[64]") -> "/data/output" + + # 初始化权重和输入 + zeros() -> "/data/X" + zeros() -> "/data/W" + zeros() -> "/data/B" + zeros() -> "/data/W2" + zeros() -> "/data/B2" + + # ── Layer 1: 
hidden = relu(X * W + B) ── + mul("/data/X", "/data/W") -> "/data/t1" + add("/data/t1", "/data/B") -> "/data/t2" + relu("/data/t2") -> "/data/hidden" + + # ── Layer 2: output = hidden * W2 + B2 ── + mul("/data/hidden", "/data/W2") -> "/data/t3" + add("/data/t3", "/data/B2") -> "/data/output" + + # ── 清理中间结果,保留 output ── + deltensor("/data/t1") + deltensor("/data/t2") + deltensor("/data/t3") + deltensor("/data/hidden") + deltensor("/data/X") + deltensor("/data/W") + deltensor("/data/B") + deltensor("/data/W2") + deltensor("/data/B2") +} + +# Top-level call: mlp_small() -> "/data/output" +mlp_small() -> "/data/output" diff --git a/example/dxlang/tensor/nn/normalize.dx b/example/dxlang/tensor/nn/normalize.dx new file mode 100644 index 00000000..2d507fd4 --- /dev/null +++ b/example/dxlang/tensor/nn/normalize.dx @@ -0,0 +1,23 @@ +# normalize: 简单 "归一化" — center 后取绝对值 +# sub 和 abs 都是 op-metal GPU kernel +def normalize() -> ("/data/norm") { + newtensor("f32", "[256]") -> "/data/raw" + newtensor("f32", "[256]") -> "/data/center" + newtensor("f32", "[256]") -> "/data/centered" + newtensor("f32", "[256]") -> "/data/norm" + + zeros() -> "/data/raw" + zeros() -> "/data/center" + + # centered = raw - center + sub("/data/raw", "/data/center") -> "/data/centered" + # norm = abs(centered) + abs("/data/centered") -> "/data/norm" + + deltensor("/data/raw") + deltensor("/data/center") + deltensor("/data/centered") +} + +# Top-level call: normalize() -> "/data/norm" +normalize() -> "/data/norm" diff --git a/example/dxlang/tensor/nn/polynomial.dx b/example/dxlang/tensor/nn/polynomial.dx new file mode 100644 index 00000000..96ec6fbe --- /dev/null +++ b/example/dxlang/tensor/nn/polynomial.dx @@ -0,0 +1,39 @@ +# polynomial: 计算多项式 f(x) = a*x^2 + b*x + c 全在 GPU tensor 上 +# 使用 mul + add 组合,无 pow/sqrt 等 params 依赖的算子 +def polynomial() -> ("/data/result") { + newtensor("f32", "[128]") -> "/data/x" + newtensor("f32", "[128]") -> "/data/a" + newtensor("f32", "[128]") -> "/data/b" + 
newtensor("f32", "[128]") -> "/data/c" + newtensor("f32", "[128]") -> "/data/squared" + newtensor("f32", "[128]") -> "/data/term1" + newtensor("f32", "[128]") -> "/data/term2" + newtensor("f32", "[128]") -> "/data/result" + + zeros() -> "/data/x" + zeros() -> "/data/a" + zeros() -> "/data/b" + zeros() -> "/data/c" + + # x^2 + mul("/data/x", "/data/x") -> "/data/squared" + # a * x^2 + mul("/data/a", "/data/squared") -> "/data/term1" + # b * x + mul("/data/b", "/data/x") -> "/data/term2" + # term1 + term2 + add("/data/term1", "/data/term2") -> "/data/result" + # + c + add("/data/result", "/data/c") -> "/data/result" + + deltensor("/data/squared") + deltensor("/data/term1") + deltensor("/data/term2") + deltensor("/data/x") + deltensor("/data/a") + deltensor("/data/b") + deltensor("/data/c") +} + +# Top-level call: polynomial() -> "/data/result" +polynomial() -> "/data/result" diff --git a/executor/Makefile b/executor/Makefile new file mode 100644 index 00000000..4bc74f1f --- /dev/null +++ b/executor/Makefile @@ -0,0 +1,250 @@ +# ═══════════════════════════════════════════════════════════════ +# deepx Executor — Build & Integration Test Lifecycle Management +# ═══════════════════════════════════════════════════════════════ + +# ── Config ── +REDIS_ADDR ?= 127.0.0.1:16379 +REDIS_HOST := $(word 1,$(subst :, ,$(REDIS_ADDR))) +REDIS_PORT := $(word 2,$(subst :, ,$(REDIS_ADDR))) + +BUILD_DIR := /tmp/deepx/op-metal/build +HEAP_BUILD_DIR := /tmp/deepx/heap-metal/build +IO_BUILD_DIR := /tmp/deepx/io-metal/build +OP_BIN := $(BUILD_DIR)/deepx-op-metal +HEAP_BIN := $(HEAP_BUILD_DIR)/deepx-heap-metal +IO_BIN := $(IO_BUILD_DIR)/deepx-io-metal + +LOG_DIR := /tmp/deepx-logs +OP_LOG := $(LOG_DIR)/op-metal.log +HEAP_LOG := $(LOG_DIR)/heap-metal.log +IO_LOG := $(LOG_DIR)/io-metal.log +OP_PID_FILE := $(LOG_DIR)/op-metal.pid +HEAP_PID_FILE := $(LOG_DIR)/heap-metal.pid +IO_PID_FILE := $(LOG_DIR)/io-metal.pid + +.PHONY: all build build-op build-heap build-io \ + start-services start-op 
start-heap start-io \ + stop-services stop-op stop-heap stop-io \ + status check-services \ + test-unit test-integration \ + reset-redis clean clean-logs + +# ═══════════════════════════════════════════════════════════════ +# Build +# ═══════════════════════════════════════════════════════════════ + +all: build + +build: build-op build-heap build-io + +build-op: + @echo "=== Building op-metal ===" + bash op-metal/build.sh + @echo " → $(OP_BIN)" + +build-heap: + @echo "=== Building heap-metal ===" + bash heap-metal/build.sh + @echo " → $(HEAP_BIN)" + +build-io: + @echo "=== Building io-metal ===" + bash io-metal/build.sh + @echo " → $(IO_BIN)" + +# ═══════════════════════════════════════════════════════════════ +# Service Lifecycle (daemon mode) +# ═══════════════════════════════════════════════════════════════ + +start-services: start-op start-heap start-io + @echo "All services started." + +start-op: check-services + @if [ -f $(OP_PID_FILE) ] && kill -0 $$(cat $(OP_PID_FILE)) 2>/dev/null; then \ + echo "[op-metal] already running (pid=$$(cat $(OP_PID_FILE)))"; \ + else \ + mkdir -p $(LOG_DIR); \ + echo "Starting op-metal ($(REDIS_HOST):$(REDIS_PORT)) → $(OP_LOG)"; \ + $(OP_BIN) $(REDIS_HOST) $(REDIS_PORT) > $(OP_LOG) 2>&1 & \ + echo $$! > $(OP_PID_FILE); \ + sleep 1; \ + echo "[op-metal] started (pid=$$(cat $(OP_PID_FILE)))"; \ + fi + +start-heap: check-services + @if [ -f $(HEAP_PID_FILE) ] && kill -0 $$(cat $(HEAP_PID_FILE)) 2>/dev/null; then \ + echo "[heap-metal] already running (pid=$$(cat $(HEAP_PID_FILE)))"; \ + else \ + mkdir -p $(LOG_DIR); \ + echo "Starting heap-metal ($(REDIS_HOST):$(REDIS_PORT)) → $(HEAP_LOG)"; \ + $(HEAP_BIN) $(REDIS_HOST) $(REDIS_PORT) > $(HEAP_LOG) 2>&1 & \ + echo $$! > $(HEAP_PID_FILE); \ + sleep 1; \ + echo "[heap-metal] started (pid=$$(cat $(HEAP_PID_FILE)))"; \ + fi + +stop-services: stop-op stop-heap stop-io + @echo "All services stopped." 
+ +stop-op: + @if [ -f $(OP_PID_FILE) ]; then \ + pid=$$(cat $(OP_PID_FILE)); \ + if kill -0 $$pid 2>/dev/null; then \ + echo "Stopping op-metal (pid=$$pid)"; \ + kill $$pid 2>/dev/null || true; \ + sleep 1; \ + if kill -0 $$pid 2>/dev/null; then \ + echo "Force killing op-metal"; \ + kill -9 $$pid 2>/dev/null || true; \ + fi; \ + fi; \ + rm -f $(OP_PID_FILE); \ + fi + +stop-heap: + @if [ -f $(HEAP_PID_FILE) ]; then \ + pid=$$(cat $(HEAP_PID_FILE)); \ + if kill -0 $$pid 2>/dev/null; then \ + echo "Stopping heap-metal (pid=$$pid)"; \ + kill $$pid 2>/dev/null || true; \ + sleep 1; \ + if kill -0 $$pid 2>/dev/null; then \ + echo "Force killing heap-metal"; \ + kill -9 $$pid 2>/dev/null || true; \ + fi; \ + fi; \ + rm -f $(HEAP_PID_FILE); \ + fi + +start-io: check-services + @if [ -f $(IO_PID_FILE) ] && kill -0 $$(cat $(IO_PID_FILE)) 2>/dev/null; then \ + echo "[io-metal] already running (pid=$$(cat $(IO_PID_FILE)))"; \ + else \ + mkdir -p $(LOG_DIR); \ + echo "Starting io-metal ($(REDIS_HOST):$(REDIS_PORT)) → $(IO_LOG)"; \ + $(IO_BIN) $(REDIS_HOST) $(REDIS_PORT) > $(IO_LOG) 2>&1 & \ + echo $$! 
> $(IO_PID_FILE); \ + sleep 1; \ + echo "[io-metal] started (pid=$$(cat $(IO_PID_FILE)))"; \ + fi + +stop-io: + @if [ -f $(IO_PID_FILE) ]; then \ + pid=$$(cat $(IO_PID_FILE)); \ + if kill -0 $$pid 2>/dev/null; then \ + echo "Stopping io-metal (pid=$$pid)"; \ + kill $$pid 2>/dev/null || true; \ + sleep 1; \ + if kill -0 $$pid 2>/dev/null; then \ + echo "Force killing io-metal"; \ + kill -9 $$pid 2>/dev/null || true; \ + fi; \ + fi; \ + rm -f $(IO_PID_FILE); \ + fi + +# ── Status ── + +status: + @echo "=== deepx Executor Status ===" + @echo "Redis: $(REDIS_ADDR)" + @redis-cli -h $(REDIS_HOST) -p $(REDIS_PORT) PING 2>/dev/null || echo " → NOT REACHABLE" + @echo "" + @if [ -f $(OP_PID_FILE) ] && kill -0 $$(cat $(OP_PID_FILE)) 2>/dev/null; then \ + echo "op-metal: RUNNING (pid=$$(cat $(OP_PID_FILE)))"; \ + else \ + echo "op-metal: STOPPED"; \ + fi + @if [ -f $(HEAP_PID_FILE) ] && kill -0 $$(cat $(HEAP_PID_FILE)) 2>/dev/null; then \ + echo "heap-metal: RUNNING (pid=$$(cat $(HEAP_PID_FILE)))"; \ + else \ + echo "heap-metal: STOPPED"; \ + fi + @if [ -f $(IO_PID_FILE) ] && kill -0 $$(cat $(IO_PID_FILE)) 2>/dev/null; then \ + echo "io-metal: RUNNING (pid=$$(cat $(IO_PID_FILE)))"; \ + else \ + echo "io-metal: STOPPED"; \ + fi + +check-services: + @redis-cli -h $(REDIS_HOST) -p $(REDIS_PORT) PING > /dev/null 2>&1 || \ + (echo "ERROR: Redis not reachable at $(REDIS_ADDR)" && exit 1) + +# ═══════════════════════════════════════════════════════════════ +# Testing +# ═══════════════════════════════════════════════════════════════ + +# Unit tests (no Redis needed) +test-unit: + cd vm && go test -count=1 ./... + +# Integration tests (Redis needed, no plats) +test-integration: + cd vm && REDIS_ADDR=$(REDIS_ADDR) go test -tags=integration -count=1 -v -run 'TestIntegration' ./... 
+ +# ═══════════════════════════════════════════════════════════════ +# Full Pipeline +# ═══════════════════════════════════════════════════════════════ + +# all-in-one: build → start → reset → stop +pipeline: build start-services reset-redis stop-services + @echo "=== Pipeline complete ===" + +# ═══════════════════════════════════════════════════════════════ +# Utilities +# ═══════════════════════════════════════════════════════════════ + +reset-redis: + @echo "Resetting Redis at $(REDIS_ADDR)..." + redis-cli -h $(REDIS_HOST) -p $(REDIS_PORT) FLUSHDB + +clean: + rm -rf $(BUILD_DIR) $(HEAP_BUILD_DIR) $(IO_BUILD_DIR) + cd vm && go clean -testcache + +clean-logs: + rm -rf $(LOG_DIR) + +# Show logs +logs-op: + tail -f $(OP_LOG) + +logs-heap: + tail -f $(HEAP_LOG) + +logs-io: + tail -f $(IO_LOG) + +help: + @echo "deepx Executor Makefile" + @echo "" + @echo "BUILD:" + @echo " make build Build op-metal + heap-metal + io-metal" + @echo " make build-op Build op-metal only" + @echo " make build-heap Build heap-metal only" + @echo " make build-io Build io-metal only" + @echo "" + @echo "SERVICES (daemon):" + @echo " make start-services Start op-metal + heap-metal + io-metal in background" + @echo " make stop-services Stop all services" + @echo " make status Check service status" + @echo "" + @echo "TESTING:" + @echo " make test-unit Run VM unit tests (no Redis)" + @echo " make test-integration Run VM integration tests (Redis, no plats)" + @echo "" + @echo "PIPELINE:" + @echo " make pipeline Full cycle: build → start → reset → stop" + @echo "" + @echo "UTILS:" + @echo " make reset-redis FLUSHDB (clean test state)" + @echo " make clean Remove build artifacts" + @echo " make logs-op Tail op-metal logs" + @echo " make logs-heap Tail heap-metal logs" + @echo " make logs-io Tail io-metal logs" + @echo "" + @echo "Config via env:" + @echo " REDIS_ADDR=$(REDIS_ADDR)" + @echo " OP_BIN=$(OP_BIN)" + @echo " HEAP_BIN=$(HEAP_BIN)" + @echo " IO_BIN=$(IO_BIN)" diff --git 
a/executor/common-metal/CMakeLists.txt b/executor/common-metal/CMakeLists.txt new file mode 100644 index 00000000..9b21ea7f --- /dev/null +++ b/executor/common-metal/CMakeLists.txt @@ -0,0 +1,27 @@ +cmake_minimum_required(VERSION 3.15) +project(deepx-common-metal LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED True) + +include_directories(include) + +find_library(METAL Metal) +find_library(FOUNDATION Foundation) + +add_library(deepx_common_metal STATIC + src/shmem/shm_tensor.cpp + src/metal_device.cpp +) + +# metal_device.cpp 包含 ObjC Metal API, 需要 Objective-C++ 编译 +set_source_files_properties(src/metal_device.cpp PROPERTIES COMPILE_FLAGS "-x objective-c++") + +target_link_libraries(deepx_common_metal PUBLIC ${METAL} ${FOUNDATION}) +target_include_directories(deepx_common_metal PUBLIC + $ + $ +) +set_target_properties(deepx_common_metal PROPERTIES + XCODE_ATTRIBUTE_CLANG_ENABLE_OBJC_ARC YES +) diff --git a/executor/common-metal/build.sh b/executor/common-metal/build.sh new file mode 100644 index 00000000..c1cb9379 --- /dev/null +++ b/executor/common-metal/build.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +set -euo pipefail +DIR="$(cd "$(dirname "$0")" && pwd)" +BUILD_DIR="/tmp/deepx/common-metal/build" +mkdir -p "$BUILD_DIR" +cd "$BUILD_DIR" +cmake "$DIR" +cmake --build . 
-j$(sysctl -n hw.ncpu 2>/dev/null || nproc) diff --git a/executor/op-mem-mps/src/deepx/mps_device.hpp b/executor/common-metal/include/deepx/metal_device.hpp similarity index 68% rename from executor/op-mem-mps/src/deepx/mps_device.hpp rename to executor/common-metal/include/deepx/metal_device.hpp index 1235fbda..27a93874 100644 --- a/executor/op-mem-mps/src/deepx/mps_device.hpp +++ b/executor/common-metal/include/deepx/metal_device.hpp @@ -2,12 +2,12 @@ #include -namespace deepx::mps +namespace deepx::metal { struct DeviceInfo { std::string name; - bool supports_mps{false}; + bool supports_metal{false}; }; DeviceInfo get_default_device_info(); diff --git a/executor/common-metal/include/deepx/registry.h b/executor/common-metal/include/deepx/registry.h new file mode 100644 index 00000000..1ac0b499 --- /dev/null +++ b/executor/common-metal/include/deepx/registry.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include + +namespace deepx::heap { + +// 一个 tensor 的 Redis 元数据 +struct TensorMeta { + std::string name; // tensor name + std::string shm_name; // POSIX shm name (e.g. "/deepx_t_abc123") + std::string dtype; // "f32", "i32", etc. 
+ std::string shape; // "[2,3,4]" + int64_t device = 0; + int64_t byte_size = 0; + int64_t refcount = 0; + int64_t owner_pid = 0; + int64_t ctime = 0; + std::string state; // "ready", "deleted" +}; + +// Registry 接口 — 抽象 Redis 后端。 +// 当前实现可以是 Redis,后续可替换为 etcd/文件等。 +class Registry { +public: + virtual ~Registry() = default; + + // 创建或获取一个 tensor。返回 shm_name。 + // 如果 tensor 已存在,增加引用计数。 + virtual std::string create_or_get(const std::string &name, + const std::string &dtype, + const std::string &shape, + int64_t device, + int64_t byte_size, + int64_t pid, + const std::string &shm_name) = 0; + + // 引用计数 +1 + virtual int64_t ref_inc(const std::string &name) = 0; + + // 引用计数 -1;若为 0 则标记可回收 + virtual int64_t ref_dec(const std::string &name) = 0; + + // 获取 tensor 元数据 + virtual bool get_meta(const std::string &name, TensorMeta &out) = 0; +}; + +} // namespace deepx::heap diff --git a/executor/common-metal/include/deepx/shmem/shm_tensor.h b/executor/common-metal/include/deepx/shmem/shm_tensor.h new file mode 100644 index 00000000..1378bbbb --- /dev/null +++ b/executor/common-metal/include/deepx/shmem/shm_tensor.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include +#include + +namespace deepx::shmem { + +// POSIX shared memory tensor allocation. +// On Apple Silicon (UMA), this memory is directly GPU-accessible via +// MTLBuffer(newBufferWithBytesNoCopy). + +struct ShmTensor { + std::string shm_name; // e.g. "/deepx_t_" + void *addr = nullptr; + size_t byte_size = 0; + int fd = -1; + int refcount = 0; +}; + +// Create a new POSIX shm region for a tensor. +// Returns true on success, fills `out`. +bool shm_tensor_create(const std::string &shm_name, size_t byte_size, ShmTensor &out); + +// Open an existing shm region. Returns true on success. +bool shm_tensor_open(const std::string &shm_name, size_t byte_size, ShmTensor &out); + +// Close (unmap + close fd). Does NOT unlink. 
+void shm_tensor_close(ShmTensor &t); + +// Unlink the shm from the filesystem (after all users closed). +void shm_tensor_unlink(const std::string &shm_name); + +// Page-aligned size for the given byte count. +size_t shm_page_align(size_t byte_size); + +} // namespace deepx::shmem diff --git a/executor/op-mem-mps/src/deepx/mps_device.mm b/executor/common-metal/src/metal_device.cpp similarity index 65% rename from executor/op-mem-mps/src/deepx/mps_device.mm rename to executor/common-metal/src/metal_device.cpp index d67972d1..cdde7303 100644 --- a/executor/op-mem-mps/src/deepx/mps_device.mm +++ b/executor/common-metal/src/metal_device.cpp @@ -1,10 +1,9 @@ #import #import -#import -#include "deepx/mps_device.hpp" +#include "deepx/metal_device.hpp" -namespace deepx::mps +namespace deepx::metal { DeviceInfo get_default_device_info() { @@ -14,12 +13,12 @@ DeviceInfo get_default_device_info() if (!device) { info.name = "none"; - info.supports_mps = false; + info.supports_metal = false; return info; } info.name = std::string([[device name] UTF8String]); - info.supports_mps = true; + info.supports_metal = true; return info; } } diff --git a/executor/common-metal/src/shmem/shm_tensor.cpp b/executor/common-metal/src/shmem/shm_tensor.cpp new file mode 100644 index 00000000..7539f902 --- /dev/null +++ b/executor/common-metal/src/shmem/shm_tensor.cpp @@ -0,0 +1,94 @@ +#include "deepx/shmem/shm_tensor.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace deepx::shmem { + +size_t shm_page_align(size_t byte_size) { + static long ps = sysconf(_SC_PAGESIZE); + if (ps <= 0) ps = 16384; // Apple Silicon default + return (byte_size + ps - 1) & ~(ps - 1); +} + +bool shm_tensor_create(const std::string &shm_name, size_t byte_size, ShmTensor &out) { + size_t aligned = shm_page_align(byte_size); + + int fd = shm_open(shm_name.c_str(), O_CREAT | O_EXCL | O_RDWR, 0600); + if (fd < 0) { + fprintf(stderr, "shm_tensor_create: shm_open(%s) failed: %s\n", + 
shm_name.c_str(), strerror(errno)); + return false; + } + + if (ftruncate(fd, aligned) < 0) { + fprintf(stderr, "shm_tensor_create: ftruncate failed: %s\n", strerror(errno)); + close(fd); + shm_unlink(shm_name.c_str()); + return false; + } + + void *addr = mmap(nullptr, aligned, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (addr == MAP_FAILED) { + fprintf(stderr, "shm_tensor_create: mmap failed: %s\n", strerror(errno)); + close(fd); + shm_unlink(shm_name.c_str()); + return false; + } + + close(fd); + + out.shm_name = shm_name; + out.addr = addr; + out.byte_size = byte_size; // original requested size + out.fd = -1; // already closed after mmap + out.refcount = 1; + return true; +} + +bool shm_tensor_open(const std::string &shm_name, size_t byte_size, ShmTensor &out) { + size_t aligned = shm_page_align(byte_size); + + int fd = shm_open(shm_name.c_str(), O_RDWR, 0600); + if (fd < 0) { + fprintf(stderr, "shm_tensor_open: shm_open(%s) failed: %s\n", + shm_name.c_str(), strerror(errno)); + return false; + } + + void *addr = mmap(nullptr, aligned, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (addr == MAP_FAILED) { + fprintf(stderr, "shm_tensor_open: mmap failed: %s\n", strerror(errno)); + close(fd); + return false; + } + + close(fd); + + out.shm_name = shm_name; + out.addr = addr; + out.byte_size = byte_size; + out.fd = -1; + out.refcount = 1; + return true; +} + +void shm_tensor_close(ShmTensor &t) { + if (t.addr && t.byte_size > 0) { + munmap(t.addr, shm_page_align(t.byte_size)); + t.addr = nullptr; + } + t.byte_size = 0; + t.refcount = 0; +} + +void shm_tensor_unlink(const std::string &shm_name) { + shm_unlink(shm_name.c_str()); +} + +} // namespace deepx::shmem diff --git a/executor/deepxcore/README.md b/executor/deepxcore/README.md deleted file mode 100644 index 37391f77..00000000 --- a/executor/deepxcore/README.md +++ /dev/null @@ -1,72 +0,0 @@ -# deepxcore - -deepxcore 是 deepx 执行器层与统一存算面协议共享的 C++ 核心基础库。 - -它的目标是提供稳定、跨执行器可复用的数据模型与协议对象,避免把 CUDA/Metal/CPU 
等具体实现细节渗透到上层与其他组件,从而保证进程间与代码组件的隔离。 - -## 定位 -- 面向:执行器进程(heapmem-*、op-*)、统一存算面 SDK、调度/编译侧的 C++ 组件 -- 提供:dtype/shape/tensor 等基础数据结构、协议对象的结构化表达、配置与序列化基础设施 -- 不提供:具体硬件算子实现、显存/IPC 生命周期实现、调度编译逻辑 - -## 职责 - -### 1) 基础数据模型 -- `DType`:数据类型描述与大小/对齐等基础能力 -- `Shape`:维度/元素数量/bytes 计算、shape 合法性检查 -- `Tensor`:Tensor 元信息与句柄表达(不绑定具体设备实现) - -这些类型应作为所有执行器的共同语言,保证跨组件传递时语义一致。 - -### 2) 统一存算面协议对象 -用于在统一寻址空间(如 Redis KV)与执行器之间传递的数据结构,例如: -- tensor 元信息记录(name/key、dtype、shape、device、bytes、ctime 等) -- 生命周期指令(create/get/delete 等) - -deepxcore 只负责“结构化表达与编解码”,不负责“真正分配/回收/IPC 映射”。 - -### 3) 序列化/反序列化与配置 -- 将协议对象、元信息在 JSON/YAML/二进制之间进行编解码 -- 读取执行器/客户端的配置(例如地址、设备策略、协议版本等) - -目标是让其他组件不要各自实现一套解析与校验逻辑。 - -### 4) 通用基础设施 -- 轻量的错误与返回值表达(Status/Result) -- 字符串、文件系统等工具的薄封装 - -要求保持依赖尽量少、接口稳定、与具体硬件/运行时解耦。 - -## 非职责(边界) - -### 不做硬件绑定 -- 不直接依赖 CUDA/Metal/ROCm/NCCL 等 -- 不实现任何具体算子 kernel - -这些应由 `op-cuda`、`op-ompsimd`、`op-mem-mps` 等执行器承担。 - -### 不做堆 tensor 生命周期与 IPC -- 不管理持久堆 tensor 的分配/回收 -- 不负责 CUDA IPC handle 的创建/打开/关闭 - -这些应由 `heapmem-cuda` 这类“统一寻址空间的 tensor 具体实现”承担。 - -### 不做编译与调度 -- 不负责 deepxIR 的编译替换、fusion、分布式调度 - -这些属于中端编译器与调度器。 - -## 与其他组件的关系 - -- heapmem-*:owner 侧负责堆 tensor 生命周期与跨进程共享;deepxcore 提供 dtype/shape/协议对象 -- op-*:算子执行器负责栈 tensor(中间变量)与 kernel;deepxcore 提供基础数据模型与统一的元信息表达 -- 前端/SDK:通过统一协议把计算图与 tensor 元信息写入统一寻址空间;deepxcore 是 C++ 侧共用的协议层 - -## 目录 -- `src/`:核心库实现 -- `test/`:单元测试 - -## 构建 -本库通过 CMake 构建,并作为其他执行器目标的依赖被链接。 - -在上层执行器中使用时,通常只需要链接 `deepxcore` 目标,并包含对应头文件。 \ No newline at end of file diff --git a/executor/deepxcore/src/deepx/dtype.hpp b/executor/deepxcore/src/deepx/dtype.hpp deleted file mode 100644 index 0a2625aa..00000000 --- a/executor/deepxcore/src/deepx/dtype.hpp +++ /dev/null @@ -1,517 +0,0 @@ -#ifndef DEEPX_DTYPE_HPP -#define DEEPX_DTYPE_HPP - -#include -#include -#include -#include -#include "stdutil/string.hpp" -#include "stdutil/num.hpp" - -namespace deepx -{ - using namespace std; - - template - T to(const std::string &textvalue) - { - if constexpr (std::is_same_v) - { - return 
textvalue; - } - else if constexpr (std::is_arithmetic_v) - { - return static_cast(std::stof(textvalue)); - } - else - { - // 对于其他类型,尝试从字符串转换 - T value; - std::istringstream iss(textvalue); - iss >> value; - return value; - } - } - - enum class DataCategory : uint8_t - { - Unknown = 0, - Var = 1 << 0, // 变量类型 - Vector = 1 << 1, // 向量类型 - Tensor = 1 << 2, // 张量类型 - ListTensor = 1 << 3, // 张量列表类型 - // 4-15预留 - }; - - // 在DataCategory枚举定义后添加位运算操作符 - inline DataCategory operator|(DataCategory a, DataCategory b) - { - return static_cast( - static_cast(a) | static_cast(b)); - } - - inline DataCategory operator&(DataCategory a, DataCategory b) - { - return static_cast( - static_cast(a) & static_cast(b)); - } - - // 修改base_category_str函数以支持组合类型 - inline std::string base_category_str(DataCategory category) - { - std::vector types; - uint8_t value = static_cast(category); - - if (value & static_cast(DataCategory::Tensor)) - types.push_back("tensor"); - if (value & static_cast(DataCategory::Vector)) - types.push_back("vector"); - if (value & static_cast(DataCategory::Var)) - types.push_back("var"); - if (value & static_cast(DataCategory::ListTensor)) - types.push_back("listtensor"); - - if (types.empty()) - return "unknown"; - - std::string result = types[0]; - for (size_t i = 1; i < types.size(); i++) - { - result += "|" + types[i]; - } - return result; - } - - // 修改base_category函数以支持组合类型 - inline DataCategory base_category(const std::string &str) - { - if (str.find('|') == std::string::npos) - { - // 处理单一类型 - if (str == "tensor") - return DataCategory::Tensor; - else if (str == "vector") - return DataCategory::Vector; - else if (str == "var") - return DataCategory::Var; - else if (str == "listtensor") - return DataCategory::ListTensor; - return DataCategory::Unknown; - } - - // 处理组合类型 - DataCategory result = DataCategory::Unknown; - size_t start = 0; - size_t pos; - - while ((pos = str.find('|', start)) != std::string::npos) - { - std::string type = str.substr(start, pos - 
start); - result = result | base_category(type); - start = pos + 1; - } - - // 处理最后一个类型 - result = result | base_category(str.substr(start)); - return result; - } - - // 将Precision改为位图形式 - enum class Precision : uint16_t - { - // 浮点类型 (0-7位) - Float64 = 1 << 0, // 0000 0000 0000 0001 - Float32 = 1 << 1, // 0000 0000 0000 0010 - Float16 = 1 << 2, // 0000 0000 0000 0100 // E5M10B15 - BFloat16 = 1 << 3, // 0000 0000 0000 1000 // E8M7B127 - Float8E5M2 = 1 << 4, // 0000 0000 0001 0000 // E5M2B15 - Float8E4M3 = 1 << 5, // 0000 0000 0010 0000 // E4M3B7 - Float4E2M1 = 1 << 6, // 0000 0000 0100 0000 // E2M1B3 - - // 整型 (8-12位) - Int64 = 1 << 8, // 0000 0001 0000 0000 - Int32 = 1 << 9, // 0000 0010 0000 0000 - Int16 = 1 << 10, // 0000 0100 0000 0000 - Int8 = 1 << 11, // 0000 1000 0000 0000 - Int4 = 1 << 12, // 0001 0000 0000 0000 - - // 布尔类型 (13位) - Bool = 1 << 13, // 0010 0000 0000 0000 - String = 1 << 15, // 0100 0000 0000 0000 - // 常用组合 - Any = 0xFFFF, // 1111 1111 1111 1111 - Float = Float64 | Float32 | Float16 | BFloat16 | Float8E5M2 | Float8E4M3 | Float4E2M1, - Float8 = Float8E5M2 | Float8E4M3, // 所有FP8格式 - Int = Int64 | Int32 | Int16 | Int8 | Int4 - }; - - // 添加位运算操作符 - inline Precision operator|(Precision a, Precision b) - { - return static_cast( - static_cast(a) | static_cast(b)); - } - - inline Precision operator&(Precision a, Precision b) - { - return static_cast( - static_cast(a) & static_cast(b)); - } - // 在Precision枚举定义后添加位数获取函数 - inline constexpr int precision_bits(Precision p) - { - switch (p) - { - case Precision::Float64: - return 64; - case Precision::Float32: - return 32; - case Precision::Float16: - return 16; - case Precision::BFloat16: - return 16; - case Precision::Float8E5M2: - return 8; - case Precision::Float8E4M3: - return 8; - //TODO 需要根据平台支持 - // case Precision::Float4E2M1: - // return 4; - case Precision::Int64: - return 64; - case Precision::Int32: - return 32; - case Precision::Int16: - return 16; - case Precision::Int8: - return 8; - 
//TODO,int4 需要根据平台支持 - // case Precision::Int4: - // return 4; - case Precision::Bool: - return 8; - case Precision::String: - case Precision::Any: - default: - return 0; - } - } - - // 删除DataCategory,直接在DataType中使用BaseCategory - union TypeDef - { - struct - { - DataCategory category : 8; // 基础类型 - Precision precision : 16; // 精度类型 - uint8_t reserved : 8; // 保留位 - } parts; - uint32_t value; // 整体访问 - - // 构造函数 - constexpr TypeDef() : value(0) {} - - // 修改构造函数,使用初始化列表 - constexpr TypeDef(DataCategory c, Precision p) : value(0) - { - parts.category = c; - parts.precision = p; - } - - bool operator==(const TypeDef &other) const - { - return value == other.value; - } - - bool operator!=(const TypeDef &other) const - { - return value != other.value; - } - - // 判断other是否在当前类型的精度范围内 - bool match(const TypeDef &other) const - { - // 类型必须相同 - uint8_t this_cat = static_cast(parts.category); - uint8_t other_cat = static_cast(other.parts.category); - if ((this_cat & other_cat) != this_cat) - { - return false; - } - - // 使用位操作检查precision - // 检查this的precision位是否都在other的precision中 - uint16_t this_prec = static_cast(parts.precision); - uint16_t other_prec = static_cast(other.parts.precision); - return (this_prec & other_prec) == this_prec; - } - constexpr DataCategory category() const - { - return parts.category; - } - - constexpr Precision precision() const - { - return parts.precision; - } - }; - - // 辅助函数用于创建DataType - constexpr TypeDef make_dtype(DataCategory category, Precision precision) - { - return TypeDef(category, precision); - } - - // 修改precision_str函数以使用标准命名格式 - inline std::string precision_str(Precision p) - { - if (p == Precision::Any) - return "any"; - - std::vector types; - uint16_t value = static_cast(p); - - if (value & static_cast(Precision::Float64)) - types.push_back("float64"); - if (value & static_cast(Precision::Float32)) - types.push_back("float32"); - if (value & static_cast(Precision::Float16)) - types.push_back("float16"); // 改回float16 - if (value & 
static_cast(Precision::BFloat16)) - types.push_back("bfloat16"); // 改回bfloat16 - if (value & static_cast(Precision::Float8E5M2)) - types.push_back("float8e5m2"); - if (value & static_cast(Precision::Float8E4M3)) - types.push_back("float8e4m3"); - if (value & static_cast(Precision::Float4E2M1)) - types.push_back("float4e2m1"); - if (value & static_cast(Precision::Int64)) - types.push_back("int64"); - if (value & static_cast(Precision::Int32)) - types.push_back("int32"); - if (value & static_cast(Precision::Int16)) - types.push_back("int16"); - if (value & static_cast(Precision::Int8)) - types.push_back("int8"); - if (value & static_cast(Precision::Int4)) - types.push_back("int4"); - if (value & static_cast(Precision::Bool)) - types.push_back("bool"); - if (value & static_cast(Precision::String)) - types.push_back("string"); - if (types.empty()) - return "any"; - - std::string result = types[0]; - for (size_t i = 1; i < types.size(); i++) - { - result += "|" + types[i]; - } - return result; - } - - // 修改dtype_str函数 - inline std::string dtype_str(const TypeDef &dtype) - { - return base_category_str(dtype.parts.category) + - "<" + precision_str(dtype.parts.precision) + ">"; - } - - // 修改precision函数以匹配新的命名格式 - inline Precision precision(const std::string &str) - { - if (str == "any") - return Precision::Any; - else if (str == "float64") - return Precision::Float64; - else if (str == "float32") - return Precision::Float32; - else if (str == "float16") - return Precision::Float16; - else if (str == "bfloat16") - return Precision::BFloat16; - else if (str == "float8e5m2") - return Precision::Float8E5M2; - else if (str == "float8e4m3") - return Precision::Float8E4M3; - else if (str == "float4e2m1") - return Precision::Float4E2M1; - - // 添加组合类型支持 - else if (str == "int") - return Precision::Int; - else if (str == "float") - return Precision::Float; - else if (str == "float8") - return Precision::Float8; - - else if (str == "int64") - return Precision::Int64; - else if (str 
== "int32") - return Precision::Int32; - else if (str == "int16") - return Precision::Int16; - else if (str == "int8") - return Precision::Int8; - else if (str == "int4") - return Precision::Int4; - - else if (str == "bool") - return Precision::Bool; - else if (str == "string") - return Precision::String; - return Precision::Any; - } - - // 修改dtype函数,处理无精度标记的情况 - inline TypeDef dtype(const std::string &str) - { - size_t pos_start = str.find('<'); - size_t pos_end = str.find('>'); - - if (pos_start == std::string::npos || pos_end == std::string::npos) - { - // 无精度标记时,使用Any作为默认精度 - return make_dtype(base_category(str), Precision::Any); - } - - std::string category_str = str.substr(0, pos_start); - std::string precision_str = str.substr(pos_start + 1, pos_end - pos_start - 1); - - return make_dtype( - base_category(category_str), - precision(precision_str)); - } - - inline TypeDef autodtype(const std::string ¶m) - { - std::string type; - std::string textvalue; - std::vector vectorvalues; - bool vectorvalue = false; - if (param.back() == ']') - { - size_t bracket_start = param.find('['); - if (bracket_start != string::npos) - { - vectorvalue = true; - // 提取方括号内的内容作为textvalue - textvalue = param.substr(bracket_start + 1, param.length() - bracket_start - 2); - // 提取方括号前的内容作为type - type = param.substr(0, bracket_start); - // 去除type两端的空格 - stdutil::trim(type); - } - } - - if (!vectorvalue) - { - // 没有方括号,按空格分割 - stringstream ss(param); - string first, second; - ss >> first; - if (ss >> second) - { - // 如果能读取到两个部分 - type = first; - textvalue = second; - } - else - { - textvalue = first; - } - } - // 处理向量值 - if (vectorvalue) - { - // 分割字符串为向量 - stringstream ss(textvalue); - string item; - while (getline(ss, item, ' ')) - { - item.erase(0, item.find_first_not_of(" ")); - item.erase(item.find_last_not_of(" ") + 1); - if (!item.empty()) - { - vectorvalues.push_back(item); - } - } - } - - // 设置结果 - if (!type.empty()) - { - return dtype(type); - } - else - { - // 没有显式类型声明,根据值推断 - 
if (vectorvalue) - { - if (!vectorvalues.empty()) - { - if (is_integer(vectorvalues[0])) - { - return make_dtype(DataCategory::Vector, Precision::Int32); - } - else if (is_float(vectorvalues[0])) - { - return make_dtype(DataCategory::Vector, Precision::Float64); - } - else - { - return make_dtype(DataCategory::ListTensor, Precision::Any); - } - } - else - { - return make_dtype(DataCategory::Vector, Precision::Any); - } - } - else - { - return make_dtype(DataCategory::Var | DataCategory::Tensor, Precision::Any); - } - } - } - - template - struct PrecisionWrapper {}; - - template - struct to_tensor_type; - - template <> - struct to_tensor_type> { - using type = double; - }; - - template <> - struct to_tensor_type> { - using type = float; - }; - - template <> - struct to_tensor_type> { - using type = int64_t; - }; - - template <> - struct to_tensor_type> { - using type = int32_t; - }; - - template <> - struct to_tensor_type> { - using type = int16_t; - }; - - template <> - struct to_tensor_type> { - using type = int8_t; - }; - - template - using tensor_t = typename to_tensor_type>::type; -} // namespace deepx -#endif diff --git a/executor/deepxcore/test/0_dtypes.cpp b/executor/deepxcore/test/0_dtypes.cpp deleted file mode 100644 index 1761f010..00000000 --- a/executor/deepxcore/test/0_dtypes.cpp +++ /dev/null @@ -1,77 +0,0 @@ -#include "deepx/tf/tf.hpp" -#include "deepx/dtype.hpp" -#include -#include -using namespace std; -using namespace deepx::tf; -using namespace deepx; - -void test_1() { - unordered_map dtype_map = { - {"tensor", make_dtype(DataCategory::Tensor, Precision::Any)}, - {"tensor", make_dtype(DataCategory::Tensor, Precision::Int)}, - {"tensor", make_dtype(DataCategory::Tensor, Precision::Float64)}, - {"tensor", make_dtype(DataCategory::Tensor, Precision::Float32)}, - {"tensor", make_dtype(DataCategory::Tensor, Precision::Float16)}, - {"tensor", make_dtype(DataCategory::Tensor, Precision::BFloat16)}, - {"tensor", make_dtype(DataCategory::Tensor, 
Precision::Float8E5M2)}, - {"tensor", make_dtype(DataCategory::Tensor, Precision::Float8E4M3)}, - {"tensor", make_dtype(DataCategory::Tensor, Precision::Float4E2M1)}, - {"tensor", make_dtype(DataCategory::Tensor, Precision::Int32)}, - {"vector", make_dtype(DataCategory::Vector, Precision::Float64)}, - {"var", make_dtype(DataCategory::Var, Precision::Int32)}, - {"var", make_dtype(DataCategory::Var, Precision::Float32)}, - {"var", make_dtype(DataCategory::Var, Precision::Bool)}, - - {"tensor", make_dtype(DataCategory::Tensor, Precision::Any)}, - {"vector", make_dtype(DataCategory::Vector, Precision::Any)}, - {"var", make_dtype(DataCategory::Var, Precision::Any)}, - }; - - // 打印表头 - cout << string(80, '=') << endl; - cout << setw(25) << left << "Original Type" - << setw(15) << "Status" - << "Converted Back" << endl; - cout << string(80, '-') << endl; - - // 遍历所有类型进行测试 - for (const auto &[type_str1, dtype1] : dtype_map) - { - // 将type_str1转换为DataType - TypeDef converted1 = dtype(type_str1); - // 检查转换后的DataType是否与原始值相等 - bool equal1 = (converted1 == dtype1); - // 将转换后的DataType转回字符串 - string back_str1 = dtype_str(converted1); - - // 输出测试结果 - cout << setw(25) << left << type_str1 - << setw(15) << (equal1 ? 
"[MATCH]" : "[MISMATCH]") - << back_str1 << endl; - } - - cout << string(80, '=') << endl; -} - -// test to tensor type -void test_2() { - if (typeid(tensor_t)== typeid(double)) { - std::cout<<"it's ok"<)== typeid(float)) { - std::cout<<"it's ok"< +#include +#include +#include +namespace deepx +{ + using namespace std; + + enum class DataCategory : uint8_t + { + Unknown = 0, + Var = 1 << 0, // 变量类型 + Vector = 1 << 1, // 向量类型 + Tensor = 1 << 2, // 张量类型 + ListTensor = 1 << 3, // 张量列表类型 + // 4-15预留 + }; + + // 在DataCategory枚举定义后添加位运算操作符 + inline DataCategory operator|(DataCategory a, DataCategory b) + { + return static_cast( + static_cast(a) | static_cast(b)); + } + + inline DataCategory operator&(DataCategory a, DataCategory b) + { + return static_cast( + static_cast(a) & static_cast(b)); + } + + // 修改base_category_str函数以支持组合类型 + inline std::string base_category_to_string(DataCategory category) + { + std::vector types; + uint8_t value = static_cast(category); + + if (value & static_cast(DataCategory::Tensor)) + types.push_back("tensor"); + if (value & static_cast(DataCategory::Vector)) + types.push_back("vector"); + if (value & static_cast(DataCategory::Var)) + types.push_back("var"); + if (value & static_cast(DataCategory::ListTensor)) + types.push_back("listtensor"); + + if (types.empty()) + return "unknown"; + + std::string result = types[0]; + for (size_t i = 1; i < types.size(); i++) + { + result += "|" + types[i]; + } + return result; + } + + // 修改base_category函数以支持组合类型 + inline DataCategory base_category_from_string(const std::string &str) + { + if (str.find('|') == std::string::npos) + { + // 处理单一类型 + if (str == "tensor") + return DataCategory::Tensor; + else if (str == "vector") + return DataCategory::Vector; + else if (str == "var") + return DataCategory::Var; + else if (str == "listtensor") + return DataCategory::ListTensor; + return DataCategory::Unknown; + } + + // 处理组合类型 + DataCategory result = DataCategory::Unknown; + size_t start = 0; + size_t pos; + 
+ while ((pos = str.find('|', start)) != std::string::npos) + { + std::string type = str.substr(start, pos - start); + result = result | base_category_from_string(type); + start = pos + 1; + } + + // 处理最后一个类型 + result = result | base_category_from_string(str.substr(start)); + return result; + } +} + +#endif \ No newline at end of file diff --git a/executor/dxlang/src/deepx/dtype/precision.hpp b/executor/dxlang/src/deepx/dtype/precision.hpp new file mode 100644 index 00000000..b40ec091 --- /dev/null +++ b/executor/dxlang/src/deepx/dtype/precision.hpp @@ -0,0 +1,189 @@ + +#ifndef DEEPX_DTYPE_PRECISION_HPP +#define DEEPX_DTYPE_PRECISION_HPP + +#include +#include +#include +#include +namespace deepx +{ + // 将Precision改为位图形式 + enum class Precision : uint16_t + { + // 浮点类型 (0-7位) + Float64 = 1 << 0, // 0000 0000 0000 0001 + Float32 = 1 << 1, // 0000 0000 0000 0010 + Float16 = 1 << 2, // 0000 0000 0000 0100 // E5M10B15 + BFloat16 = 1 << 3, // 0000 0000 0000 1000 // E8M7B127 + Float8E5M2 = 1 << 4, // 0000 0000 0001 0000 // E5M2B15 + Float8E4M3 = 1 << 5, // 0000 0000 0010 0000 // E4M3B7 + Float4E2M1 = 1 << 6, // 0000 0000 0100 0000 // E2M1B3 + + // 整型 (8-12位) + Int64 = 1 << 8, // 0000 0001 0000 0000 + Int32 = 1 << 9, // 0000 0010 0000 0000 + Int16 = 1 << 10, // 0000 0100 0000 0000 + Int8 = 1 << 11, // 0000 1000 0000 0000 + Int4 = 1 << 12, // 0001 0000 0000 0000 + + // 布尔类型 (13位) + Bool = 1 << 13, // 0010 0000 0000 0000 + String = 1 << 15, // 0100 0000 0000 0000 + // 常用组合 + Any = 0xFFFF, // 1111 1111 1111 1111 + Float = Float64 | Float32 | Float16 | BFloat16 | Float8E5M2 | Float8E4M3 | Float4E2M1, + Float8 = Float8E5M2 | Float8E4M3, // 所有FP8格式 + Int = Int64 | Int32 | Int16 | Int8 | Int4 + }; + + // 添加位运算操作符 + inline Precision operator|(Precision a, Precision b) + { + return static_cast( + static_cast(a) | static_cast(b)); + } + + inline Precision operator&(Precision a, Precision b) + { + return static_cast( + static_cast(a) & static_cast(b)); + } + // 在Precision枚举定义后添加位数获取函数 
+ inline constexpr int precision_bits(Precision p) + { + switch (p) + { + case Precision::Float64: + return 64; + case Precision::Float32: + return 32; + case Precision::Float16: + return 16; + case Precision::BFloat16: + return 16; + case Precision::Float8E5M2: + return 8; + case Precision::Float8E4M3: + return 8; + // TODO 需要根据平台支持 + // case Precision::Float4E2M1: + // return 4; + case Precision::Int64: + return 64; + case Precision::Int32: + return 32; + case Precision::Int16: + return 16; + case Precision::Int8: + return 8; + // TODO,int4 需要根据平台支持 + // case Precision::Int4: + // return 4; + case Precision::Bool: + return 8; + case Precision::String: + case Precision::Any: + default: + return 0; + } + } + + // 修改precision函数以匹配新的命名格式 + inline Precision precision_from_string(const std::string &str) + { + if (str == "any") + return Precision::Any; + else if (str == "float64" || str == "f64") + return Precision::Float64; + else if (str == "float32" || str == "f32") + return Precision::Float32; + else if (str == "float16" || str == "f16") + return Precision::Float16; + else if (str == "bfloat16" || str == "bf16") + return Precision::BFloat16; + else if (str == "float8e5m2" || str == "f8e5m2") + return Precision::Float8E5M2; + else if (str == "float8e4m3" || str == "f8e4m3") + return Precision::Float8E4M3; + else if (str == "float4e2m1" || str == "f4e2m1") + return Precision::Float4E2M1; + + // 添加组合类型支持 + else if (str == "int") + return Precision::Int; + else if (str == "float") + return Precision::Float; + else if (str == "float8") + return Precision::Float8; + + else if (str == "int64" || str == "i64") + return Precision::Int64; + else if (str == "int32" || str == "i32") + return Precision::Int32; + else if (str == "int16" || str == "i16") + return Precision::Int16; + else if (str == "int8" || str == "i8") + return Precision::Int8; + else if (str == "int4" || str == "i4") + return Precision::Int4; + + else if (str == "bool") + return Precision::Bool; + else if (str 
== "string") + return Precision::String; + return Precision::Any; + } + + + + // 修改precision_str函数以使用标准命名格式 + inline std::string precision_to_string(Precision p) + { + if (p == Precision::Any) + return "any"; + + std::vector types; + uint16_t value = static_cast(p); + + if (value & static_cast(Precision::Float64)) + types.push_back("f64"); + if (value & static_cast(Precision::Float32)) + types.push_back("f32"); + if (value & static_cast(Precision::Float16)) + types.push_back("f16"); + if (value & static_cast(Precision::BFloat16)) + types.push_back("bf16"); + if (value & static_cast(Precision::Float8E5M2)) + types.push_back("f8e5m2"); + if (value & static_cast(Precision::Float8E4M3)) + types.push_back("f8e4m3"); + if (value & static_cast(Precision::Float4E2M1)) + types.push_back("f4e2m1"); + if (value & static_cast(Precision::Int64)) + types.push_back("i64"); + if (value & static_cast(Precision::Int32)) + types.push_back("i32"); + if (value & static_cast(Precision::Int16)) + types.push_back("i16"); + if (value & static_cast(Precision::Int8)) + types.push_back("i8"); + if (value & static_cast(Precision::Int4)) + types.push_back("i4"); + if (value & static_cast(Precision::Bool)) + types.push_back("bool"); + if (value & static_cast(Precision::String)) + types.push_back("string"); + if (types.empty()) + return "any"; + + std::string result = types[0]; + for (size_t i = 1; i < types.size(); i++) + { + result += "|" + types[i]; + } + return result; + } + +} +#endif \ No newline at end of file diff --git a/executor/dxlang/src/deepx/dtype/typespec.hpp b/executor/dxlang/src/deepx/dtype/typespec.hpp new file mode 100644 index 00000000..92dd9723 --- /dev/null +++ b/executor/dxlang/src/deepx/dtype/typespec.hpp @@ -0,0 +1,195 @@ +#ifndef DEEPX_DTYPE_TYPEDEF_HPP +#define DEEPX_DTYPE_TYPEDEF_HPP + +#include + +#include "stdutil/string.hpp" +#include "stdutil/num.hpp" + +#include "deepx/dtype/data_category.hpp" +#include "deepx/dtype/precision.hpp" + + +namespace deepx +{ + union 
TypeSpec + { + struct + { + DataCategory category : 8; // 基础类型 + Precision precision : 16; // 精度类型 + uint8_t reserved : 8; // 保留位 + } parts; + uint32_t value; // 整体访问 + + // 构造函数 + constexpr TypeSpec() : value(0) {} + + // 修改构造函数,使用初始化列表 + constexpr TypeSpec(DataCategory c, Precision p) : value(0) + { + parts.category = c; + parts.precision = p; + } + + bool operator==(const TypeSpec &other) const + { + return value == other.value; + } + + bool operator!=(const TypeSpec &other) const + { + return value != other.value; + } + + // 判断other是否在当前类型的精度范围内 + bool match(const TypeSpec &other) const + { + // 类型必须相同 + uint8_t this_cat = static_cast(parts.category); + uint8_t other_cat = static_cast(other.parts.category); + if ((this_cat & other_cat) != this_cat) + { + return false; + } + + // 使用位操作检查precision + // 检查this的precision位是否都在other的precision中 + uint16_t this_prec = static_cast(parts.precision); + uint16_t other_prec = static_cast(other.parts.precision); + return (this_prec & other_prec) == this_prec; + } + constexpr DataCategory category() const + { + return parts.category; + } + + constexpr Precision precision() const + { + return parts.precision; + } + string to_string() const + { + return base_category_to_string(parts.category) + + "<" + precision_to_string(parts.precision) + ">"; + } + void from_string(const string &str) + { + size_t pos_start = str.find('<'); + size_t pos_end = str.find('>'); + + if (pos_start == std::string::npos || pos_end == std::string::npos) + { + parts.category = base_category_from_string(str); + parts.precision = Precision::Any; + return; + } + + std::string category_str = str.substr(0, pos_start); + std::string precision_str = str.substr(pos_start + 1, pos_end - pos_start - 1); + + parts.category = base_category_from_string(category_str); + parts.precision = precision_from_string(precision_str); + } + }; + + inline TypeSpec typespec_from_string(const std::string &str){ + TypeSpec ts; + ts.from_string(str); + return ts; + } + + inline 
TypeSpec autodtype(const std::string ¶m) + { + std::string type; + std::string textvalue; + std::vector vectorvalues; + bool vectorvalue = false; + if (param.back() == ']') + { + size_t bracket_start = param.find('['); + if (bracket_start != string::npos) + { + vectorvalue = true; + // 提取方括号内的内容作为textvalue + textvalue = param.substr(bracket_start + 1, param.length() - bracket_start - 2); + // 提取方括号前的内容作为type + type = param.substr(0, bracket_start); + // 去除type两端的空格 + stdutil::trim(type); + } + } + + if (!vectorvalue) + { + // 没有方括号,按空格分割 + stringstream ss(param); + string first, second; + ss >> first; + if (ss >> second) + { + // 如果能读取到两个部分 + type = first; + textvalue = second; + } + else + { + textvalue = first; + } + } + // 处理向量值 + if (vectorvalue) + { + // 分割字符串为向量 + stringstream ss(textvalue); + string item; + while (getline(ss, item, ' ')) + { + item.erase(0, item.find_first_not_of(" ")); + item.erase(item.find_last_not_of(" ") + 1); + if (!item.empty()) + { + vectorvalues.push_back(item); + } + } + } + + // 设置结果 + if (!type.empty()) + { + return typespec_from_string(type); + } + else + { + // 没有显式类型声明,根据值推断 + if (vectorvalue) + { + if (!vectorvalues.empty()) + { + if (is_integer(vectorvalues[0])) + { + return TypeSpec(DataCategory::Vector, Precision::Int32); + } + else if (is_float(vectorvalues[0])) + { + return TypeSpec(DataCategory::Vector, Precision::Float64); + } + else + { + return TypeSpec(DataCategory::ListTensor, Precision::Any); + } + } + else + { + return TypeSpec(DataCategory::Vector, Precision::Any); + } + } + else + { + return TypeSpec(DataCategory::Var | DataCategory::Tensor, Precision::Any); + } + } + } + +} +#endif \ No newline at end of file diff --git a/executor/deepxcore/src/stdutil/error.hpp b/executor/dxlang/src/stdutil/error.hpp similarity index 100% rename from executor/deepxcore/src/stdutil/error.hpp rename to executor/dxlang/src/stdutil/error.hpp diff --git a/executor/deepxcore/src/stdutil/fs.cpp b/executor/dxlang/src/stdutil/fs.cpp 
similarity index 100% rename from executor/deepxcore/src/stdutil/fs.cpp rename to executor/dxlang/src/stdutil/fs.cpp diff --git a/executor/deepxcore/src/stdutil/fs.hpp b/executor/dxlang/src/stdutil/fs.hpp similarity index 100% rename from executor/deepxcore/src/stdutil/fs.hpp rename to executor/dxlang/src/stdutil/fs.hpp diff --git a/executor/deepxcore/src/stdutil/num.cpp b/executor/dxlang/src/stdutil/num.cpp similarity index 100% rename from executor/deepxcore/src/stdutil/num.cpp rename to executor/dxlang/src/stdutil/num.cpp diff --git a/executor/deepxcore/src/stdutil/num.hpp b/executor/dxlang/src/stdutil/num.hpp similarity index 100% rename from executor/deepxcore/src/stdutil/num.hpp rename to executor/dxlang/src/stdutil/num.hpp diff --git a/executor/deepxcore/src/stdutil/print.hpp b/executor/dxlang/src/stdutil/print.hpp similarity index 100% rename from executor/deepxcore/src/stdutil/print.hpp rename to executor/dxlang/src/stdutil/print.hpp diff --git a/executor/deepxcore/src/stdutil/string.cpp b/executor/dxlang/src/stdutil/string.cpp similarity index 100% rename from executor/deepxcore/src/stdutil/string.cpp rename to executor/dxlang/src/stdutil/string.cpp diff --git a/executor/deepxcore/src/stdutil/string.hpp b/executor/dxlang/src/stdutil/string.hpp similarity index 100% rename from executor/deepxcore/src/stdutil/string.hpp rename to executor/dxlang/src/stdutil/string.hpp diff --git a/executor/deepxcore/src/stdutil/time.hpp b/executor/dxlang/src/stdutil/time.hpp similarity index 100% rename from executor/deepxcore/src/stdutil/time.hpp rename to executor/dxlang/src/stdutil/time.hpp diff --git a/executor/deepxcore/src/stdutil/vector.hpp b/executor/dxlang/src/stdutil/vector.hpp similarity index 100% rename from executor/deepxcore/src/stdutil/vector.hpp rename to executor/dxlang/src/stdutil/vector.hpp diff --git a/executor/deepxcore/test/1_tf.cpp b/executor/dxlang/test/1_tf.cpp similarity index 100% rename from executor/deepxcore/test/1_tf.cpp rename to 
executor/dxlang/test/1_tf.cpp diff --git a/executor/deepxcore/test/1_tfcheck.cpp b/executor/dxlang/test/1_tfcheck.cpp similarity index 100% rename from executor/deepxcore/test/1_tfcheck.cpp rename to executor/dxlang/test/1_tfcheck.cpp diff --git a/executor/deepxcore/test/2_saveload.cpp b/executor/dxlang/test/2_saveload.cpp similarity index 100% rename from executor/deepxcore/test/2_saveload.cpp rename to executor/dxlang/test/2_saveload.cpp diff --git a/executor/deepxcore/test/CMakeLists.txt b/executor/dxlang/test/CMakeLists.txt similarity index 75% rename from executor/deepxcore/test/CMakeLists.txt rename to executor/dxlang/test/CMakeLists.txt index 28cb4906..122528d0 100644 --- a/executor/deepxcore/test/CMakeLists.txt +++ b/executor/dxlang/test/CMakeLists.txt @@ -1,7 +1,4 @@ -add_executable(test_dtypes 0_dtypes.cpp) -target_link_libraries(test_dtypes deepxcore) - add_executable(test_tf 1_tf.cpp) target_link_libraries(test_tf deepxcore) diff --git a/executor/heapmem-cuda/CMakeLists.txt b/executor/heap-cuda/CMakeLists.txt similarity index 100% rename from executor/heapmem-cuda/CMakeLists.txt rename to executor/heap-cuda/CMakeLists.txt diff --git a/executor/heapmem-cuda/README.md b/executor/heap-cuda/README.md similarity index 98% rename from executor/heapmem-cuda/README.md rename to executor/heap-cuda/README.md index 3b45efc6..bdb544b1 100644 --- a/executor/heapmem-cuda/README.md +++ b/executor/heap-cuda/README.md @@ -1,4 +1,4 @@ -# heapmem-cuda 方案草案 +# heap-cuda 方案草案 本目录用于设计/实现单机多进程的 GPU Tensor 统一存储面(CUDA IPC),并通过 Redis 做 name → IPC handle 的集中注册与控制。 @@ -89,7 +89,7 @@ mem-cuda/ graph LR subgraph 单机 RM["Redis (元数据 + 指令队列)"] - HMC["heapmem-cuda 进程"] + HMC["heap-cuda 进程"] CP["计算进程 (多进程)"] GPU["GPU"] end diff --git a/executor/heapmem-cuda/src/CMakeLists.txt b/executor/heap-cuda/src/CMakeLists.txt similarity index 100% rename from executor/heapmem-cuda/src/CMakeLists.txt rename to executor/heap-cuda/src/CMakeLists.txt diff --git 
a/executor/heapmem-cuda/src/registry/CMakeLists.txt b/executor/heap-cuda/src/registry/CMakeLists.txt similarity index 100% rename from executor/heapmem-cuda/src/registry/CMakeLists.txt rename to executor/heap-cuda/src/registry/CMakeLists.txt diff --git a/executor/heapmem-cuda/src/registry/registry.cpp b/executor/heap-cuda/src/registry/registry.cpp similarity index 100% rename from executor/heapmem-cuda/src/registry/registry.cpp rename to executor/heap-cuda/src/registry/registry.cpp diff --git a/executor/heapmem-cuda/src/registry/registry.h b/executor/heap-cuda/src/registry/registry.h similarity index 100% rename from executor/heapmem-cuda/src/registry/registry.h rename to executor/heap-cuda/src/registry/registry.h diff --git a/executor/heapmem-cuda/src/runtime/lifecycle.cpp b/executor/heap-cuda/src/runtime/lifecycle.cpp similarity index 86% rename from executor/heapmem-cuda/src/runtime/lifecycle.cpp rename to executor/heap-cuda/src/runtime/lifecycle.cpp index 52d4c3ff..1041581f 100644 --- a/executor/heapmem-cuda/src/runtime/lifecycle.cpp +++ b/executor/heap-cuda/src/runtime/lifecycle.cpp @@ -1,6 +1,7 @@ #include "lifecycle.h" #include +#include namespace memcuda { @@ -26,7 +27,12 @@ static std::string ExtractString(const std::string& json, const std::string& key static long long ExtractInt(const std::string& json, const std::string& key) { auto s = ExtractString(json, key); if (s.empty()) return 0; - return std::stoll(s); + try { + return std::stoll(s); + } catch (...) 
{ + std::cerr << "[lifecycle] ExtractInt failed for key=" << key << " value=" << s << "\n"; + return 0; + } } LifecycleWorker::LifecycleWorker(Registry* registry, const std::string& queue_key) @@ -44,15 +50,15 @@ bool LifecycleWorker::Parse(const std::string& json, LifecycleCommand& out) cons } void LifecycleWorker::Handle(const LifecycleCommand& cmd) { - if (cmd.op == "create") { + if (cmd.op == "newtensor") { registry_->CreateOrGet(cmd.name, cmd.dtype, cmd.shape, cmd.device, 0, cmd.node, cmd.pid, 0, ""); return; } - if (cmd.op == "get") { + if (cmd.op == "gettensor") { registry_->RefInc(cmd.name); return; } - if (cmd.op == "delete") { + if (cmd.op == "deltensor") { registry_->RefDec(cmd.name); return; } @@ -68,6 +74,7 @@ void LifecycleWorker::RunOnce(int timeout_seconds) { } LifecycleCommand cmd; if (!Parse(msg, cmd)) { + std::cerr << "[lifecycle] failed to parse command: " << msg << "\n"; return; } Handle(cmd); diff --git a/executor/heapmem-cuda/src/runtime/lifecycle.h b/executor/heap-cuda/src/runtime/lifecycle.h similarity index 100% rename from executor/heapmem-cuda/src/runtime/lifecycle.h rename to executor/heap-cuda/src/runtime/lifecycle.h diff --git a/executor/heapmem-cuda/src/runtime/sync.cpp b/executor/heap-cuda/src/runtime/sync.cpp similarity index 100% rename from executor/heapmem-cuda/src/runtime/sync.cpp rename to executor/heap-cuda/src/runtime/sync.cpp diff --git a/executor/heapmem-cuda/src/runtime/sync.h b/executor/heap-cuda/src/runtime/sync.h similarity index 100% rename from executor/heapmem-cuda/src/runtime/sync.h rename to executor/heap-cuda/src/runtime/sync.h diff --git a/executor/heapmem-cuda/test/CMakeLists.txt b/executor/heap-cuda/test/CMakeLists.txt similarity index 100% rename from executor/heapmem-cuda/test/CMakeLists.txt rename to executor/heap-cuda/test/CMakeLists.txt diff --git a/executor/heap-metal/CMakeLists.txt b/executor/heap-metal/CMakeLists.txt new file mode 100644 index 00000000..3dfee87a --- /dev/null +++ 
b/executor/heap-metal/CMakeLists.txt @@ -0,0 +1,26 @@ +cmake_minimum_required(VERSION 3.15) +project(deepx-heap-metal LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED True) +set(CMAKE_BUILD_TYPE Debug) + +include_directories(src) +include_directories(../common-metal/include) + +# hiredis (Redis C client) +include_directories(/opt/homebrew/opt/hiredis/include) +link_directories(/opt/homebrew/opt/hiredis/lib) + +# nlohmann/json (header-only JSON parser) +include_directories(/opt/homebrew/opt/nlohmann-json/include) + +# 依赖 common-metal 公共库 +if(NOT TARGET deepx_common_metal) + add_subdirectory(../common-metal common-metal) +endif() + +file(GLOB_RECURSE SOURCES "src/*.cpp") + +add_executable(${PROJECT_NAME} ${SOURCES}) +target_link_libraries(${PROJECT_NAME} PRIVATE deepx_common_metal hiredis) diff --git a/executor/heap-metal/build.sh b/executor/heap-metal/build.sh new file mode 100644 index 00000000..0c0b1e8e --- /dev/null +++ b/executor/heap-metal/build.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +set -euo pipefail +DIR="$(cd "$(dirname "$0")" && pwd)" +BUILD_DIR="/tmp/deepx/heap-metal/build" +mkdir -p "$BUILD_DIR" +cd "$BUILD_DIR" +cmake "$DIR" +cmake --build . 
-j$(sysctl -n hw.ncpu 2>/dev/null || nproc) +# Copy runtime dependencies (rpath) +cp -f common-metal/libdeepx_common_metal.a "$BUILD_DIR/" 2>/dev/null || true +echo "Built: $BUILD_DIR/deepx-heap-metal" diff --git a/executor/heap-metal/src/lifecycle/lifecycle.cpp b/executor/heap-metal/src/lifecycle/lifecycle.cpp new file mode 100644 index 00000000..e528f593 --- /dev/null +++ b/executor/heap-metal/src/lifecycle/lifecycle.cpp @@ -0,0 +1,159 @@ +#include "lifecycle.h" +#include +#include +#include +#include + +namespace deepx::heap { + +int64_t parse_shape_size(const std::string &shape_str) { + // "[2,3,4]" → 2*3*4 = 24 + int64_t total = 1; + int64_t cur = 0; + bool in_num = false; + for (char c : shape_str) { + if (c >= '0' && c <= '9') { + cur = cur * 10 + (c - '0'); + in_num = true; + } else { + if (in_num) { + total *= cur; + cur = 0; + in_num = false; + } + } + } + if (in_num) total *= cur; + return total; +} + +LifecycleManager::LifecycleManager(Registry *registry) + : registry_(registry) {} + +std::string LifecycleManager::generate_shm_name() const { + auto now = std::chrono::steady_clock::now().time_since_epoch().count(); + uint64_t id = counter_.fetch_add(1); + std::ostringstream oss; + oss << "/deepx_t_" << std::hex << now << "_" << id; + return oss.str(); +} + +std::string LifecycleManager::handle(const LifecycleCommand &cmd, std::string &error) { + if (cmd.op == "newtensor") { + // 计算 byte_size + int64_t element_count = cmd.element_count > 0 + ? 
cmd.element_count + : parse_shape_size(cmd.shape); + + // 确定每个元素的字节数 + int elem_bytes = 4; // 默认 f32 + if (cmd.dtype == "f64") elem_bytes = 8; + else if (cmd.dtype == "f32") elem_bytes = 4; + else if (cmd.dtype == "f16" || cmd.dtype == "bf16") elem_bytes = 2; + else if (cmd.dtype == "i64") elem_bytes = 8; + else if (cmd.dtype == "i32") elem_bytes = 4; + else if (cmd.dtype == "i16") elem_bytes = 2; + else if (cmd.dtype == "i8" || cmd.dtype == "bool") elem_bytes = 1; + + int64_t total_bytes = element_count * elem_bytes; + + // 检查是否已存在 + TensorMeta existing; + if (registry_->get_meta(cmd.name, existing)) { + // Tensor 已存在 → 打开已有 shm,ref_inc + registry_->ref_inc(cmd.name); + std::lock_guard lock(mutex_); + auto it = open_tensors_.find(existing.shm_name); + if (it == open_tensors_.end()) { + deepx::shmem::ShmTensor st; + if (!deepx::shmem::shm_tensor_open(existing.shm_name, existing.byte_size, st)) { + error = "failed to open existing shm: " + existing.shm_name; + return ""; + } + open_tensors_[existing.shm_name] = st; + } + return existing.shm_name; + } + + // 创建新的 shm tensor + std::string shm_name = generate_shm_name(); + deepx::shmem::ShmTensor st; + if (!deepx::shmem::shm_tensor_create(shm_name, total_bytes, st)) { + error = "shm_tensor_create failed for " + shm_name; + return ""; + } + + // 注册到 registry + registry_->create_or_get(cmd.name, cmd.dtype, cmd.shape, + cmd.device, total_bytes, cmd.pid, shm_name); + + { + std::lock_guard lock(mutex_); + open_tensors_[shm_name] = st; + } + + printf("[heap] created tensor '%s' → shm=%s bytes=%lld\n", + cmd.name.c_str(), shm_name.c_str(), total_bytes); + return shm_name; + } + else if (cmd.op == "gettensor") { + TensorMeta meta; + if (!registry_->get_meta(cmd.name, meta)) { + error = "tensor not found: " + cmd.name; + return ""; + } + registry_->ref_inc(cmd.name); + + // 确保已打开 + std::lock_guard lock(mutex_); + auto it = open_tensors_.find(meta.shm_name); + if (it == open_tensors_.end()) { + deepx::shmem::ShmTensor st; + if 
(!deepx::shmem::shm_tensor_open(meta.shm_name, meta.byte_size, st)) { + error = "failed to open shm: " + meta.shm_name; + return ""; + } + open_tensors_[meta.shm_name] = st; + } + return meta.shm_name; + } + else if (cmd.op == "deltensor") { + int64_t ref = registry_->ref_dec(cmd.name); + printf("[heap] delete '%s' → refcount=%lld\n", cmd.name.c_str(), ref); + + if (ref <= 0) { + TensorMeta meta; + if (registry_->get_meta(cmd.name, meta)) { + std::lock_guard lock(mutex_); + auto it = open_tensors_.find(meta.shm_name); + if (it != open_tensors_.end()) { + deepx::shmem::shm_tensor_close(it->second); + deepx::shmem::shm_tensor_unlink(it->second.shm_name); + open_tensors_.erase(it); + } + } + } + return ""; + } + error = "unknown op: " + cmd.op; + return ""; +} + +void *LifecycleManager::get_addr(const std::string &shm_name) const { + std::lock_guard lock(mutex_); + auto it = open_tensors_.find(shm_name); + if (it != open_tensors_.end()) { + return it->second.addr; + } + return nullptr; +} + +void LifecycleManager::shutdown() { + std::lock_guard lock(mutex_); + for (auto &kv : open_tensors_) { + deepx::shmem::shm_tensor_close(kv.second); + } + open_tensors_.clear(); +} + +} // namespace deepx::heap diff --git a/executor/heap-metal/src/lifecycle/lifecycle.h b/executor/heap-metal/src/lifecycle/lifecycle.h new file mode 100644 index 00000000..a5458d7a --- /dev/null +++ b/executor/heap-metal/src/lifecycle/lifecycle.h @@ -0,0 +1,48 @@ +#pragma once + +#include "deepx/registry.h" +#include "deepx/shmem/shm_tensor.h" +#include +#include +#include +#include + +namespace deepx::heap { + +// 从 "[10,20,30]" 解析 element count +int64_t parse_shape_size(const std::string &shape_str); + +struct LifecycleCommand { + std::string op; // "newtensor", "gettensor", "deltensor" + std::string name; // tensor name + std::string dtype; + std::string shape; + int64_t device = 0; + int64_t byte_size = 0; + int64_t pid = 0; + int64_t element_count = 0; +}; + +class LifecycleManager { +public: + 
LifecycleManager(Registry *registry); + + // 处理一条指令,返回 shm_name(create/get 时有效) + std::string handle(const LifecycleCommand &cmd, std::string &error); + + // 获取已打开的 shm tensor 的地址(供本地访问) + void *get_addr(const std::string &shm_name) const; + + // 关闭所有已打开的 tensor + void shutdown(); + +private: + std::string generate_shm_name() const; + + Registry *registry_; + std::unordered_map open_tensors_; + mutable std::mutex mutex_; + mutable std::atomic counter_{0}; +}; + +} // namespace deepx::heap diff --git a/executor/heap-metal/src/main.cpp b/executor/heap-metal/src/main.cpp new file mode 100644 index 00000000..83f3f2ee --- /dev/null +++ b/executor/heap-metal/src/main.cpp @@ -0,0 +1,376 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "registry/registry_file.h" +#include "lifecycle/lifecycle.h" + +using namespace deepx::heap; +using json = nlohmann::json; + +static const char *HEAP_QUEUE = "cmd:heap-metal:0"; +static const char *SYS_QUEUE = "sys:cmd:heap-metal:0"; +static const char *INSTANCE_KEY = "/sys/heap-plat/heap-metal:0"; +static const char *HEARTBEAT_KEY = "/sys/heartbeat/heap-metal:0"; +static const int BLOCK_TIMEOUT_SEC = 5; +static const int HEARTBEAT_INTERVAL_SEC = 2; + +// ── Redis helpers ── + +static redisContext* connect_redis(const char *addr, int port) { + struct timeval tv = {2, 0}; + redisContext *c = redisConnectWithTimeout(addr, port, tv); + if (!c || c->err) { + std::cerr << "Redis connect failed: " << (c ? c->errstr : "null") << "\n"; + if (c) redisFree(c); + return nullptr; + } + return c; +} + +static redisReply* redis_cmd(redisContext *c, const char *fmt, ...) 
{ + va_list ap; + va_start(ap, fmt); + redisReply *r = (redisReply *)redisvCommand(c, fmt, ap); + va_end(ap); + return r; +} + +#define REDIS_FREE(r) do { if (r) freeReplyObject(r); } while(0) + +static bool redis_set(redisContext *c, const std::string &key, const std::string &val) { + redisReply *r = redis_cmd(c, "SET %s %s", key.c_str(), val.c_str()); + bool ok = r && r->type == REDIS_REPLY_STATUS; + REDIS_FREE(r); + return ok; +} + +static void update_heartbeat(redisContext *c, const std::string &status) { + json hb; + hb["ts"] = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()).count(); + hb["status"] = status; + hb["pid"] = getpid(); + redis_set(c, HEARTBEAT_KEY, hb.dump()); +} + +static void register_instance(redisContext *c) { + json reg; + reg["program"] = "heap-metal"; + reg["device"] = "gpu0"; + reg["status"] = "running"; + reg["pid"] = getpid(); + reg["started_at"] = std::chrono::system_clock::now().time_since_epoch().count(); + redis_set(c, INSTANCE_KEY, reg.dump()); + std::cout << "[heap] registered at " << INSTANCE_KEY << "\n"; +} + +static void notify_done(redisContext *c, const std::string &vtid, const std::string &pc, + const std::string &status, const std::string &error_msg = "") { + json done; + done["pc"] = pc; + done["status"] = status; + if (!error_msg.empty()) { + done["error"] = {{"code", "HEAP_ERROR"}, {"message", error_msg}}; + } + std::string key = "done:" + vtid; + redisReply *r = redis_cmd(c, "LPUSH %s %s", key.c_str(), done.dump().c_str()); + if (!r || r->type == REDIS_REPLY_ERROR) { + std::cerr << "[heap] notify_done LPUSH failed for " << vtid << ": " << (r ? 
r->str : "NULL") << "\n"; + } + REDIS_FREE(r); + std::cout << "[heap] done " << vtid << " pc=" << pc << " status=" << status << "\n"; +} + +// ── JSON → LifecycleCommand ── + +static LifecycleCommand parse_command(const json &j, std::string &error) { + LifecycleCommand cmd; + if (!j.contains("op")) { + error = "missing 'op' field"; + return cmd; + } + cmd.op = j["op"].get(); + + if (j.contains("key")) { + cmd.name = j["key"].get(); + } else if (j.contains("src")) { + cmd.name = j["src"].get(); + } else { + error = "missing tensor key"; + return cmd; + } + + if (j.contains("dtype")) cmd.dtype = j["dtype"].get(); + if (j.contains("shape")) { + if (j["shape"].is_array()) { + std::ostringstream oss; + oss << "["; + bool first = true; + for (const auto &d : j["shape"]) { + if (!first) oss << ","; + oss << d.get(); + first = false; + } + oss << "]"; + cmd.shape = oss.str(); + + int64_t total = 1; + for (const auto &d : j["shape"]) total *= d.get(); + cmd.element_count = total; + } else if (j["shape"].is_string()) { + // shape 也可能以字符串形式传入: "[10,10]" + cmd.shape = j["shape"].get(); + cmd.element_count = parse_shape_size(cmd.shape); + } + } + cmd.device = 0; + cmd.pid = getpid(); + return cmd; +} + +// ── Op dispatch ── + +static void handle_newtensor(LifecycleManager &mgr, const LifecycleCommand &cmd, + redisContext *redis, const json &task) { + std::string error; + std::string shm_name = mgr.handle(cmd, error); + + if (!error.empty()) { + std::cerr << "[heap] newtensor error: " << error << "\n"; + notify_done(redis, task["vtid"], task["pc"], "error", error); + return; + } + + // 写入 tensor 元信息到 Redis + json meta; + meta["dtype"] = cmd.dtype; + meta["shape"] = json::parse(cmd.shape); + meta["byte_size"] = cmd.element_count * 4; // default f32 + meta["device"] = "gpu0"; + meta["address"]["type"] = "shm"; + meta["address"]["shm_name"] = shm_name; + meta["address"]["node"] = "n1"; + redis_set(redis, cmd.name, meta.dump()); + + std::cout << "[heap] newtensor '" << cmd.name << 
"' → shm=" << shm_name << "\n"; + notify_done(redis, task["vtid"], task["pc"], "ok"); +} + +static void handle_deltensor(LifecycleManager &mgr, const LifecycleCommand &cmd, + redisContext *redis, const json &task) { + std::string error; + mgr.handle(cmd, error); + + if (!error.empty()) { + std::cerr << "[heap] deltensor error: " << error << "\n"; + } + + // 删除 Redis key + redisReply *r = redis_cmd(redis, "DEL %s", cmd.name.c_str()); + REDIS_FREE(r); + + std::cout << "[heap] deltensor '" << cmd.name << "'\n"; + notify_done(redis, task["vtid"], task["pc"], error.empty() ? "ok" : "error", error); +} + +static void handle_clonetensor(LifecycleManager &mgr, const json &task, redisContext *redis) { + std::string src = task["src"]; + std::string dst = task["dst"]; + + // GET src → tensor meta, 然后在 dest 创建同名 shm 并拷贝数据 + // (简化实现: 重用 newtensor—如果不存在则创建,存在则 ref_inc) + redisReply *r = redis_cmd(redis, "GET %s", src.c_str()); + std::string error; + if (!r || r->type != REDIS_REPLY_STRING) { + error = "clone: source tensor not found: " + src; + } else { + json src_meta = json::parse(r->str); + + LifecycleCommand cmd; + cmd.op = "newtensor"; + cmd.name = dst; + cmd.dtype = src_meta["dtype"]; + cmd.shape = src_meta["shape"].dump(); + cmd.device = 0; + cmd.pid = getpid(); + + int64_t total = 1; + for (const auto &d : src_meta["shape"]) total *= d.get(); + cmd.element_count = total; + + std::string shm_name = mgr.handle(cmd, error); + if (error.empty()) { + json dst_meta = src_meta; + dst_meta["address"]["shm_name"] = shm_name; + redis_set(redis, dst, dst_meta.dump()); + + // 拷贝数据 (如果源 shm 也在这台机器上) + void *src_addr = mgr.get_addr(src_meta["address"]["shm_name"]); + void *dst_addr = mgr.get_addr(shm_name); + if (src_addr && dst_addr) { + memcpy(dst_addr, src_addr, static_cast(src_meta["byte_size"].get())); + } + } + } + REDIS_FREE(r); + + if (!error.empty()) { + notify_done(redis, task["vtid"], task["pc"], "error", error); + } else { + notify_done(redis, task["vtid"], task["pc"], 
"ok"); + } +} + +// ── Main ── + +int main(int argc, char **argv) { + const char *redis_addr = "127.0.0.1"; + int redis_port = 6379; + if (argc > 1) redis_addr = argv[1]; + if (argc > 2) redis_port = atoi(argv[2]); + + // 清理上一轮进程残留的过期 registry(堆的生死 = 进程的生死) + const char *registry_path = "/tmp/deepx_heap_registry.txt"; + unlink(registry_path); + + // 连接 Redis(无限重试,不自退——heap-plat 由元程控制退出) + redisContext *redis = nullptr; + while (!redis) { + redis = connect_redis(redis_addr, redis_port); + if (!redis) { + std::cerr << "[heap] Redis not available, retrying in 1s...\n"; + sleep(1); + } + } + std::cout << "[heap] connected to Redis " << redis_addr << ":" << redis_port << "\n"; + + // 注册实例 + register_instance(redis); + + // 初始化 LifecycleManager + FileRegistry reg(registry_path); + LifecycleManager mgr(®); + + std::cout << "[heap] listening on " << HEAP_QUEUE << " + " << SYS_QUEUE << "\n"; + std::cout << "[heap] heartbeat → " << HEARTBEAT_KEY << " (every " << HEARTBEAT_INTERVAL_SEC << "s)\n"; + + // 初始心跳 + update_heartbeat(redis, "running"); + + // 消费循环 (同时监听业务队列和系统命令队列) + std::atomic running{true}; + auto last_heartbeat = std::chrono::steady_clock::now(); + while (running) { + redisReply *r = redis_cmd(redis, "BLPOP %s %s %d", HEAP_QUEUE, SYS_QUEUE, BLOCK_TIMEOUT_SEC); + if (!r) { + // Redis 断连 → 无限重连(不自退,heap-plat 由元程控制退出) + std::cerr << "[heap] Redis disconnected, reconnecting...\n"; + redisFree(redis); + redis = nullptr; + while (!redis) { + sleep(1); + redis = connect_redis(redis_addr, redis_port); + if (!redis) { + std::cerr << "[heap] Redis still not available, retrying...\n"; + } + } + register_instance(redis); + last_heartbeat = std::chrono::steady_clock::now(); + update_heartbeat(redis, "running"); + continue; + } + + // ── 心跳上报 ── + auto now = std::chrono::steady_clock::now(); + if (std::chrono::duration_cast(now - last_heartbeat).count() >= HEARTBEAT_INTERVAL_SEC) { + update_heartbeat(redis, "running"); + last_heartbeat = now; + } + + if (r->type == 
REDIS_REPLY_NIL) { + // BLPOP timeout — no tasks + REDIS_FREE(r); + continue; + } + + if (r->type != REDIS_REPLY_ARRAY || r->elements < 2) { + REDIS_FREE(r); + continue; + } + + std::string queue_name(r->element[0]->str); + std::string payload(r->element[1]->str); + REDIS_FREE(r); + + // ── 系统命令处理 ── + if (queue_name == SYS_QUEUE) { + try { + json sys_cmd = json::parse(payload); + std::string cmd = sys_cmd.value("cmd", ""); + if (cmd == "shutdown") { + std::cout << "[heap] received sys shutdown command, exiting...\n"; + running = false; + } else { + std::cerr << "[heap] unknown sys command: " << cmd << "\n"; + } + } catch (const std::exception &e) { + std::cerr << "[heap] sys cmd JSON parse error: " << e.what() << "\n"; + } + continue; + } + + // ── 业务命令处理 ── + // 解析 JSON + json task; + try { + task = json::parse(payload); + } catch (const std::exception &e) { + std::cerr << "[heap] JSON parse error: " << e.what() << "\n"; + continue; + } + + std::string op = task.value("op", ""); + std::string vtid = task.value("vtid", ""); + std::string pc = task.value("pc", ""); + + std::cout << "[heap] received op=" << op << " vtid=" << vtid << " pc=" << pc << "\n"; + + std::string parse_err; + LifecycleCommand cmd = parse_command(task, parse_err); + + if (!parse_err.empty()) { + notify_done(redis, vtid, pc, "error", parse_err); + continue; + } + + if (op == "newtensor") { + handle_newtensor(mgr, cmd, redis, task); + } else if (op == "deltensor") { + handle_deltensor(mgr, cmd, redis, task); + } else if (op == "clonetensor") { + handle_clonetensor(mgr, task, redis); + } else { + notify_done(redis, vtid, pc, "error", "unknown op: " + op); + } + } + + mgr.shutdown(); + // 上报 stopped 心跳,然后注销 + if (redis) { + update_heartbeat(redis, "stopped"); + std::cout << "[heap] final heartbeat: stopped\n"; + redis_cmd(redis, "DEL %s", INSTANCE_KEY); + redisFree(redis); + } + std::cout << "[heap] shutdown complete.\n"; + return 0; +} diff --git a/executor/heap-metal/src/main.mm 
b/executor/heap-metal/src/main.mm new file mode 100644 index 00000000..5ad31b75 --- /dev/null +++ b/executor/heap-metal/src/main.mm @@ -0,0 +1,315 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "registry/registry_file.h" +#include "lifecycle/lifecycle.h" + +using namespace deepx::heap; +using json = nlohmann::json; + +static const char *HEAP_QUEUE = "cmd:heap-metal:0"; +static const char *INSTANCE_KEY = "/sys/heap-plat/heap-metal:0"; +static const int BLOCK_TIMEOUT_SEC = 5; + +// ── Redis helpers ── + +static redisContext* connect_redis(const char *addr, int port) { + struct timeval tv = {2, 0}; + redisContext *c = redisConnectWithTimeout(addr, port, tv); + if (!c || c->err) { + std::cerr << "Redis connect failed: " << (c ? c->errstr : "null") << "\n"; + if (c) redisFree(c); + return nullptr; + } + return c; +} + +static redisReply* redis_cmd(redisContext *c, const char *fmt, ...) { + va_list ap; + va_start(ap, fmt); + redisReply *r = (redisReply *)redisvCommand(c, fmt, ap); + va_end(ap); + return r; +} + +#define REDIS_FREE(r) do { if (r) freeReplyObject(r); } while(0) + +static bool redis_set(redisContext *c, const std::string &key, const std::string &val) { + redisReply *r = redis_cmd(c, "SET %s %s", key.c_str(), val.c_str()); + bool ok = r && r->type == REDIS_REPLY_STATUS; + REDIS_FREE(r); + return ok; +} + +static void register_instance(redisContext *c) { + json reg; + reg["program"] = "heap-metal"; + reg["device"] = "gpu0"; + reg["status"] = "running"; + reg["pid"] = getpid(); + reg["started_at"] = std::chrono::system_clock::now().time_since_epoch().count(); + redis_set(c, INSTANCE_KEY, reg.dump()); + std::cout << "[heap] registered at " << INSTANCE_KEY << "\n"; +} + +static void notify_done(redisContext *c, const std::string &vtid, const std::string &pc, + const std::string &status, const std::string &error_msg = "") { + json done; + done["pc"] = pc; + done["status"] = status; + if 
(!error_msg.empty()) { + done["error"] = {{"code", "HEAP_ERROR"}, {"message", error_msg}}; + } + std::string key = "done:" + vtid; + redisReply *r = redis_cmd(c, "LPUSH %s %s", key.c_str(), done.dump().c_str()); + if (!r || r->type == REDIS_REPLY_ERROR) { + std::cerr << "[heap] notify_done LPUSH failed for " << vtid << ": " << (r ? r->str : "NULL") << "\n"; + } + REDIS_FREE(r); + std::cout << "[heap] done " << vtid << " pc=" << pc << " status=" << status << "\n"; +} + +// ── JSON → LifecycleCommand ── + +static LifecycleCommand parse_command(const json &j, std::string &error) { + LifecycleCommand cmd; + if (!j.contains("op")) { + error = "missing 'op' field"; + return cmd; + } + cmd.op = j["op"].get(); + + if (j.contains("key")) { + cmd.name = j["key"].get(); + } else if (j.contains("src")) { + cmd.name = j["src"].get(); + } else { + error = "missing tensor key"; + return cmd; + } + + if (j.contains("dtype")) cmd.dtype = j["dtype"].get(); + if (j.contains("shape")) { + if (j["shape"].is_array()) { + std::ostringstream oss; + oss << "["; + bool first = true; + for (const auto &d : j["shape"]) { + if (!first) oss << ","; + oss << d.get(); + first = false; + } + oss << "]"; + cmd.shape = oss.str(); + + int64_t total = 1; + for (const auto &d : j["shape"]) total *= d.get(); + cmd.element_count = total; + } else if (j["shape"].is_string()) { + // shape 也可能以字符串形式传入: "[10,10]" + cmd.shape = j["shape"].get(); + cmd.element_count = parse_shape_size(cmd.shape); + } + } + cmd.device = 0; + cmd.pid = getpid(); + return cmd; +} + +// ── Op dispatch ── + +static void handle_newtensor(LifecycleManager &mgr, const LifecycleCommand &cmd, + redisContext *redis, const json &task) { + std::string error; + std::string shm_name = mgr.handle(cmd, error); + + if (!error.empty()) { + std::cerr << "[heap] newtensor error: " << error << "\n"; + notify_done(redis, task["vtid"], task["pc"], "error", error); + return; + } + + // 写入 tensor 元信息到 Redis + json meta; + meta["dtype"] = cmd.dtype; + 
meta["shape"] = json::parse(cmd.shape); + meta["byte_size"] = cmd.element_count * 4; // default f32 + meta["device"] = "gpu0"; + meta["address"]["type"] = "shm"; + meta["address"]["shm_name"] = shm_name; + meta["address"]["node"] = "n1"; + redis_set(redis, cmd.name, meta.dump()); + + std::cout << "[heap] newtensor '" << cmd.name << "' → shm=" << shm_name << "\n"; + notify_done(redis, task["vtid"], task["pc"], "ok"); +} + +static void handle_deltensor(LifecycleManager &mgr, const LifecycleCommand &cmd, + redisContext *redis, const json &task) { + std::string error; + mgr.handle(cmd, error); + + if (!error.empty()) { + std::cerr << "[heap] deltensor error: " << error << "\n"; + } + + // 删除 Redis key + redisReply *r = redis_cmd(redis, "DEL %s", cmd.name.c_str()); + REDIS_FREE(r); + + std::cout << "[heap] deltensor '" << cmd.name << "'\n"; + notify_done(redis, task["vtid"], task["pc"], error.empty() ? "ok" : "error", error); +} + +static void handle_clonetensor(LifecycleManager &mgr, const json &task, redisContext *redis) { + std::string src = task["src"]; + std::string dst = task["dst"]; + + // GET src → tensor meta, 然后在 dest 创建同名 shm 并拷贝数据 + // (简化实现: 重用 newtensor—如果不存在则创建,存在则 ref_inc) + redisReply *r = redis_cmd(redis, "GET %s", src.c_str()); + std::string error; + if (!r || r->type != REDIS_REPLY_STRING) { + error = "clone: source tensor not found: " + src; + } else { + json src_meta = json::parse(r->str); + + LifecycleCommand cmd; + cmd.op = "newtensor"; + cmd.name = dst; + cmd.dtype = src_meta["dtype"]; + cmd.shape = src_meta["shape"].dump(); + cmd.device = 0; + cmd.pid = getpid(); + + int64_t total = 1; + for (const auto &d : src_meta["shape"]) total *= d.get(); + cmd.element_count = total; + + std::string shm_name = mgr.handle(cmd, error); + if (error.empty()) { + json dst_meta = src_meta; + dst_meta["address"]["shm_name"] = shm_name; + redis_set(redis, dst, dst_meta.dump()); + + // 拷贝数据 (如果源 shm 也在这台机器上) + void *src_addr = 
mgr.get_addr(src_meta["address"]["shm_name"]); + void *dst_addr = mgr.get_addr(shm_name); + if (src_addr && dst_addr) { + memcpy(dst_addr, src_addr, static_cast(src_meta["byte_size"].get())); + } + } + } + REDIS_FREE(r); + + if (!error.empty()) { + notify_done(redis, task["vtid"], task["pc"], "error", error); + } else { + notify_done(redis, task["vtid"], task["pc"], "ok"); + } +} + +// ── Main ── + +int main(int argc, char **argv) { + const char *redis_addr = "127.0.0.1"; + int redis_port = 6379; + if (argc > 1) redis_addr = argv[1]; + if (argc > 2) redis_port = atoi(argv[2]); + + const char *registry_path = "/tmp/deepx_heap_registry.txt"; + + // 连接 Redis + redisContext *redis = connect_redis(redis_addr, redis_port); + if (!redis) return 1; + + std::cout << "[heap] connected to Redis " << redis_addr << ":" << redis_port << "\n"; + + // 注册实例 + register_instance(redis); + + // 初始化 LifecycleManager + FileRegistry reg(registry_path); + LifecycleManager mgr(®); + + std::cout << "[heap] listening on " << HEAP_QUEUE << "\n"; + + // 消费循环 + while (true) { + redisReply *r = redis_cmd(redis, "BLPOP %s %d", HEAP_QUEUE, BLOCK_TIMEOUT_SEC); + if (!r) { + // Redis 断连 → 重连 + std::cerr << "[heap] Redis disconnected, reconnecting...\n"; + redisFree(redis); + sleep(1); + redis = connect_redis(redis_addr, redis_port); + if (!redis) break; + register_instance(redis); + continue; + } + + if (r->type == REDIS_REPLY_NIL) { + // BLPOP timeout — no tasks + REDIS_FREE(r); + continue; + } + + if (r->type != REDIS_REPLY_ARRAY || r->elements < 2) { + REDIS_FREE(r); + continue; + } + + std::string payload(r->element[1]->str); + REDIS_FREE(r); + + // 解析 JSON + json task; + try { + task = json::parse(payload); + } catch (const std::exception &e) { + std::cerr << "[heap] JSON parse error: " << e.what() << "\n"; + continue; + } + + std::string op = task.value("op", ""); + std::string vtid = task.value("vtid", ""); + std::string pc = task.value("pc", ""); + + std::cout << "[heap] received op=" << op 
<< " vtid=" << vtid << " pc=" << pc << "\n"; + + std::string parse_err; + LifecycleCommand cmd = parse_command(task, parse_err); + + if (!parse_err.empty()) { + notify_done(redis, vtid, pc, "error", parse_err); + continue; + } + + if (op == "newtensor") { + handle_newtensor(mgr, cmd, redis, task); + } else if (op == "deltensor") { + handle_deltensor(mgr, cmd, redis, task); + } else if (op == "clonetensor") { + handle_clonetensor(mgr, task, redis); + } else { + notify_done(redis, vtid, pc, "error", "unknown op: " + op); + } + } + + mgr.shutdown(); + if (redis) { + redis_cmd(redis, "DEL %s", INSTANCE_KEY); + redisFree(redis); + } + std::cout << "[heap] shutdown complete.\n"; + return 0; +} diff --git a/executor/heap-metal/src/registry/registry_file.h b/executor/heap-metal/src/registry/registry_file.h new file mode 100644 index 00000000..6ac4e33b --- /dev/null +++ b/executor/heap-metal/src/registry/registry_file.h @@ -0,0 +1,121 @@ +#pragma once + +#include "deepx/registry.h" +#include +#include +#include +#include + +namespace deepx::heap { + +// 基于文件的简单 Registry 实现(验证用) +// 生产环境替换为 RedisRegistry +class FileRegistry : public Registry { +public: + explicit FileRegistry(const std::string &path) : path_(path) { + load(); + } + + ~FileRegistry() override { save(); } + + std::string create_or_get(const std::string &name, + const std::string &dtype, + const std::string &shape, + int64_t device, + int64_t byte_size, + int64_t pid, + const std::string &shm_name) override { + std::lock_guard lock(mutex_); + auto it = store_.find(name); + if (it != store_.end()) { + it->second.refcount++; + save(); + return it->second.shm_name; + } + TensorMeta meta; + meta.name = name; + meta.shm_name = shm_name; + meta.dtype = dtype; + meta.shape = shape; + meta.device = device; + meta.byte_size = byte_size; + meta.owner_pid = pid; + meta.refcount = 1; + meta.ctime = time(nullptr); + meta.state = "ready"; + store_[name] = meta; + save(); + return shm_name; + } + + int64_t ref_inc(const 
std::string &name) override { + std::lock_guard lock(mutex_); + auto it = store_.find(name); + if (it == store_.end()) return -1; + it->second.refcount++; + save(); + return it->second.refcount; + } + + int64_t ref_dec(const std::string &name) override { + std::lock_guard lock(mutex_); + auto it = store_.find(name); + if (it == store_.end()) return -1; + it->second.refcount--; + if (it->second.refcount <= 0) { + it->second.state = "deleted"; + } + save(); + return it->second.refcount; + } + + bool get_meta(const std::string &name, TensorMeta &out) override { + std::lock_guard lock(mutex_); + auto it = store_.find(name); + if (it == store_.end()) return false; + out = it->second; + return true; + } + +private: + void load() { + std::ifstream f(path_); + if (!f) return; + std::string line; + while (std::getline(f, line)) { + auto pos = line.find(' '); + if (pos == std::string::npos) continue; + std::string key = line.substr(0, pos); + std::string val = line.substr(pos + 1); + if (key == "tensor") { + TensorMeta meta; + // simple format: tensor name shm_name dtype shape device bytes refcount pid ctime state + std::istringstream iss(val); + iss >> meta.name >> meta.shm_name >> meta.dtype >> meta.shape + >> meta.device >> meta.byte_size >> meta.refcount + >> meta.owner_pid >> meta.ctime >> meta.state; + if (!meta.name.empty()) { + store_[meta.name] = meta; + } + } + } + } + + void save() { + std::ofstream f(path_, std::ios::trunc); + if (!f) return; + for (auto &kv : store_) { + auto &m = kv.second; + f << "tensor " << m.name << " " << m.shm_name << " " << m.dtype << " " + << m.shape << " " << m.device << " " << m.byte_size << " " + << m.refcount << " " << m.owner_pid << " " << m.ctime << " " + << m.state << "\n"; + } + } + + std::string path_; + std::unordered_map store_; + std::mutex mutex_; +}; + +} // namespace deepx::heap diff --git a/executor/io-metal/CLAUDE.md b/executor/io-metal/CLAUDE.md new file mode 100644 index 00000000..b46f350b --- /dev/null +++ 
b/executor/io-metal/CLAUDE.md @@ -0,0 +1,141 @@ +# io-metal 开发约束 + +> io-metal 的职责边界。哪些能做,哪些**绝对不能碰**。 + +--- + +## 1. io-metal 是什么 + +在 DeepX 元程架构中,io-metal 是 **I/O 平面**——负责 tensor 与文件系统、进程管道、网络的读写。 + +它只做一件事:**被动消费 I/O 指令 → 读写数据 → 通知完成**。 + +io-metal 是"无状态的 I/O 执行器"——它不关心 tensor 的计算语义,只关心如何把数据持久化或传输。 + +--- + +## 2. 为什么 I/O 要与 GPU 计算分离 + +| 维度 | op-metal (GPU 计算) | io-metal (I/O) | +|------|---------------------|----------------| +| 硬件依赖 | Metal GPU 必须 | 仅需 CPU | +| 操作延迟 | ~μs (kernel launch) | ~ms-s (disk/network) | +| 阻塞风险 | 无 (GPU 异步) | **高** (磁盘满、网络超时) | +| 失败影响 | GPU OOM / Metal 错误 | 磁盘满 / 网络断开 | + +**如果合并在同一个进程**:磁盘 I/O 阻塞会拖死整个 GPU 计算管线。 + +--- + +## 3. 允许做的事(白名单) + +| 操作 | 允许 | 说明 | +|------|------|------| +| BLPOP `cmd:io-metal:*` | ✅ | 消费 VM 发来的 I/O 指令 | +| GET Redis 获取 tensor 元信息 | ✅ | 仅限 inputs/outputs 的 dtype/shape/shm_name/byte_size | +| shm_open + mmap 映射 tensor 内存 | ✅ | 根据 shm_name 获取 CPU 指针,读/写 tensor 数据 | +| 写文件 (save) | ✅ | 将 tensor shape + data 持久化到文件系统 | +| 读文件 (load) | ✅ | 从文件系统读取 tensor shape + data 到 shm | +| 输出到 stdout (print) | ✅ | 格式化打印 tensor 数据 | +| LPUSH `done:` 通知完成 | ✅ | 格式: {pc, status:"ok"\|"error", error?} | +| SET `/sys/io-plat/io-metal:0` | ✅ | 启动时注册进程状态 | +| DEL `/sys/io-plat/io-metal:0` | ✅ | 退出时注销 | + +--- + +## 4. 
禁止做的事(黑名单) + +### 4.1 绝对禁止:GPU 计算 + +| 操作 | 禁止 | 原因 | +|------|------|------| +| Metal GPU kernel 调用 | ❌ | I/O 平面不需要 GPU | +| MTLBuffer / MTLDevice 操作 | ❌ | 这是 op-metal 的职责 | +| 修改 tensor 数据(计算) | ❌ | io-metal 只搬运数据,不改变数据 | + +### 4.2 禁止:越权修改其他组件 + +| 操作 | 禁止 | 原因 | +|------|------|------| +| 修改 VM 的 vthread 状态 | ❌ | `/vthread/*` 是 VM 的私有空间 | +| 修改 heap-plat 的分配记录 | ❌ | heap-plat 管理 `/heap/*` | +| 消费 `done:*` 队列 | ❌ | done 是 VM 消费的 | +| 生产 `cmd:op-metal:*` / `cmd:heap-metal:*` | ❌ | 只消费 `cmd:io-metal:*` | +| 创建/删除 tensor shm | ❌ | 这是 heap-plat 的职责 | + +### 4.3 禁止:修改 tensor 语义 + +| 操作 | 禁止 | 原因 | +|------|------|------| +| 修改 dtype | ❌ | 只原样读写 | +| 修改 shape | ❌ | 只原样读写 | +| 类型转换 / cast | ❌ | 这是 op-plat 的职责 | + +--- + +## 5. io-metal 的通信边界 + +``` +VM io-metal heap-metal + │ │ │ + │── PUSH cmd:io-metal:0 ──→ │ │ + │ │── GET /data/x ────→ Redis + │ │←── {shm_name,...} ── Redis + │ │── shm_open("/deepx_t_xxx") → read data + │ │── write to file / stdout + │ │── LPUSH done:1 ────→ Redis + │←── BLPOP done:1 ───────── Redis + │── PC++ 继续 │ +``` + +**io-metal 的边界:** +- 入: `cmd:io-metal:*` 队列 + Redis GET(tensor 元信息) +- 出: `done:` 队列 +- 内部: shm 映射 → 读/写数据 → 返回 + +--- + +## 6. 支持的操作 + +| opcode | 参数 | 输入 | 输出 | 说明 | +|--------|------|------|------|------| +| `print` | format (可选) | tensor | — | 格式化输出到 stdout | +| `save` | arg0=文件路径 | tensor | — | 持久化到文件系统 (path.shape + path.data) | +| `load` | arg0=文件路径 | — | tensor | 从文件系统读取到 shm | + +### 文件格式 + +**save** 产生两个文件: +- `.shape` — JSON: `{"dtype":"f32","shape":[N,M],"size":K}` +- `.data` — 原始二进制 (tensor 数据) + +**load** 读取这两个文件,将数据写入目标 tensor 的 shm 区域。 + +--- + +## 7. 允许的日志输出 + +| 场景 | 允许 | +|------|------| +| 启动: Redis 地址、监听队列 | ✅ 一次性 `std::cout` | +| 启动: CWD、进程 PID | ✅ 一次性 | +| 致命错误: Redis 连接失败 | ✅ `std::cerr` + 重试 | +| 退出: shutdown complete | ✅ 一次性 | +| save/load: 文件路径 + dtype + 元素数 | ✅ 一次性(低频操作) | +| print: 格式化 tensor 数据 | ✅ 这是 print 的职责本身 | +| 每次指令的低级诊断 | ❌ 高频冗余 | +| shm 操作成功日志 | ❌ 成功不输出,失败走 error 通知 | + +--- + +## 8. 
与 deepxctl 的关系 + +| 谁做什么 | deepxctl | io-metal | +|---------|----------|----------| +| 启动 io-metal 进程 | ✅ | — | +| 注册 I/O 算子 | — | ✅ | +| 连接 Redis | — | ✅ | +| 消费 I/O 指令 | — | ✅ | +| 执行 I/O 操作 | — | ✅ | +| 验证输出结果 | ✅ (通过轮询 vthread status) | ❌ | +| 清理子进程 | ✅ | — | diff --git a/executor/io-metal/CMakeLists.txt b/executor/io-metal/CMakeLists.txt new file mode 100644 index 00000000..e1dc5e29 --- /dev/null +++ b/executor/io-metal/CMakeLists.txt @@ -0,0 +1,27 @@ +cmake_minimum_required(VERSION 3.15) +project(deepx-io-metal LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED True) +set(CMAKE_BUILD_TYPE Debug) +set(CMAKE_OSX_DEPLOYMENT_TARGET "12.0") + +include_directories(src) +include_directories(../common-metal/include) + +# hiredis (Redis C client) +include_directories(/opt/homebrew/opt/hiredis/include) +link_directories(/opt/homebrew/opt/hiredis/lib) + +# nlohmann/json (header-only JSON parser) +include_directories(/opt/homebrew/opt/nlohmann-json/include) + +# 依赖 common-metal 公共库 (shm_tensor) +if(NOT TARGET deepx_common_metal) + add_subdirectory(../common-metal common-metal) +endif() + +file(GLOB_RECURSE SOURCES "src/*.cpp") + +add_executable(${PROJECT_NAME} ${SOURCES}) +target_link_libraries(${PROJECT_NAME} PRIVATE deepx_common_metal hiredis) diff --git a/executor/io-metal/build.sh b/executor/io-metal/build.sh new file mode 100644 index 00000000..cbb0c479 --- /dev/null +++ b/executor/io-metal/build.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash +set -euo pipefail +DIR="$(cd "$(dirname "$0")" && pwd)" +BUILD_DIR="/tmp/deepx/io-metal/build" +mkdir -p "$BUILD_DIR" +cd "$BUILD_DIR" +cmake "$DIR" +cmake --build . 
-j$(sysctl -n hw.ncpu 2>/dev/null || nproc) +echo "Built: $BUILD_DIR/deepx-io-metal" diff --git a/executor/io-metal/src/main.cpp b/executor/io-metal/src/main.cpp new file mode 100644 index 00000000..dbe4954e --- /dev/null +++ b/executor/io-metal/src/main.cpp @@ -0,0 +1,584 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "deepx/shmem/shm_tensor.h" + +using json = nlohmann::json; + +static const char *IO_QUEUE = "cmd:io-metal:0"; +static const char *SYS_QUEUE = "sys:cmd:io-metal:0"; +static const char *INSTANCE_KEY = "/sys/io-plat/io-metal:0"; +static const int BLOCK_TIMEOUT_SEC = 5; + +// ═══════════════════════════════════════════════════════════ +// Redis helpers +// ═══════════════════════════════════════════════════════════ + +static redisContext* connect_redis(const char *addr, int port) { + struct timeval tv = {2, 0}; + redisContext *c = redisConnectWithTimeout(addr, port, tv); + if (!c || c->err) { + std::cerr << "[io-metal] Redis connect failed: " << (c ? c->errstr : "null") << "\n"; + if (c) redisFree(c); + return nullptr; + } + return c; +} + +static redisReply* redis_cmd(redisContext *c, const char *fmt, ...) 
{ + va_list ap; + va_start(ap, fmt); + redisReply *r = (redisReply *)redisvCommand(c, fmt, ap); + va_end(ap); + return r; +} + +#define REDIS_FREE(r) do { if (r) freeReplyObject(r); } while(0) + +static bool redis_set(redisContext *c, const std::string &key, const std::string &val) { + redisReply *r = redis_cmd(c, "SET %s %s", key.c_str(), val.c_str()); + bool ok = r && r->type == REDIS_REPLY_STATUS; + REDIS_FREE(r); + return ok; +} + +static void register_instance(redisContext *c) { + json reg; + reg["program"] = "io-metal"; + reg["device"] = "cpu"; + reg["status"] = "running"; + reg["load"] = 0.0; + reg["pid"] = getpid(); + reg["started_at"] = std::chrono::system_clock::now().time_since_epoch().count(); + redis_set(c, INSTANCE_KEY, reg.dump()); + std::cout << "[io-metal] registered at " << INSTANCE_KEY << "\n"; + + // ── 注册支持的 I/O 算子列表 ── + redisReply *r = redis_cmd(c, "DEL %s", "/op/io-metal/list"); + REDIS_FREE(r); + + redis_cmd(c, "RPUSH %s %s %s %s", + "/op/io-metal/list", + "print", "save", "load"); + + std::cout << "[io-metal] registered I/O ops: print save load\n"; +} + +static void notify_done(redisContext *c, const std::string &vtid, + const std::string &pc, const std::string &status, + const std::string &error_msg = "") { + json done; + done["pc"] = pc; + done["status"] = status; + if (!error_msg.empty()) { + done["error"] = {{"code", "IO_ERROR"}, {"message", error_msg}}; + } + std::string key = "done:" + vtid; + redisReply *r = redis_cmd(c, "LPUSH %s %s", key.c_str(), done.dump().c_str()); + if (!r || r->type == REDIS_REPLY_ERROR) { + std::cerr << "[io-metal] notify_done LPUSH failed for " << vtid << ": " << (r ? 
r->str : "NULL") << "\n"; + } + REDIS_FREE(r); + std::cout << "[io-metal] done " << vtid << " pc=" << pc << " status=" << status << "\n"; +} + +// ═══════════════════════════════════════════════════════════ +// shm helpers (reuses common-metal ShmTensor utilities) +// ═══════════════════════════════════════════════════════════ + +struct ShmMapping { + std::string shm_name; + void *addr = nullptr; + size_t byte_size = 0; +}; + +static bool shm_open_readwrite(const std::string &name, size_t byte_size, ShmMapping &out) { + out.shm_name = name; + out.byte_size = byte_size; + + int fd = shm_open(name.c_str(), O_RDWR, 0600); + if (fd < 0) { + std::cerr << "[io-metal] shm_open failed: " << name << " (" << strerror(errno) << ")\n"; + return false; + } + + size_t aligned = deepx::shmem::shm_page_align(byte_size); + void *addr = mmap(nullptr, aligned, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + close(fd); + if (addr == MAP_FAILED) { + std::cerr << "[io-metal] mmap failed: " << name << " (" << strerror(errno) << ")\n"; + return false; + } + + out.addr = addr; + return true; +} + +static void shm_close(ShmMapping &m) { + if (m.addr) { + munmap(m.addr, deepx::shmem::shm_page_align(m.byte_size)); + m.addr = nullptr; + } +} + +// ═══════════════════════════════════════════════════════════ +// Tensor metadata from Redis +// ═══════════════════════════════════════════════════════════ + +struct TensorMeta { + std::string key; + std::string dtype; + std::vector shape_data; + std::string shm_name; + size_t byte_size = 0; + bool valid = false; +}; + +static TensorMeta fetch_tensor_meta(redisContext *c, const std::string &key) { + TensorMeta m; + m.key = key; + + redisReply *r = redis_cmd(c, "GET %s", key.c_str()); + if (!r || r->type != REDIS_REPLY_STRING) { + REDIS_FREE(r); + return m; + } + + try { + json meta = json::parse(r->str); + REDIS_FREE(r); + + if (meta.contains("dtype")) m.dtype = meta["dtype"].get(); + if (meta.contains("shape") && meta["shape"].is_array()) { + for 
(const auto &d : meta["shape"]) { + m.shape_data.push_back(d.get()); + } + } + if (meta.contains("byte_size")) m.byte_size = meta["byte_size"].get(); + if (meta.contains("address") && meta["address"].contains("shm_name")) { + m.shm_name = meta["address"]["shm_name"].get(); + } + m.valid = true; + } catch (const std::exception &e) { + REDIS_FREE(r); + std::cerr << "[io-metal] JSON parse error for tensor " << key << ": " << e.what() << "\n"; + } + + return m; +} + +static inline int64_t element_count(const std::vector &shape) { + int64_t n = 1; + for (auto d : shape) n *= d; + return n; +} + +// ═══════════════════════════════════════════════════════════ +// Print helpers +// ═══════════════════════════════════════════════════════════ + +template +struct type_tag { using type = T; }; + +#define DISPATCH_BY_DTYPE(dtype, Fn) \ + do { \ + if (dtype == "f32" || dtype == "float32") Fn(type_tag{}); \ + else if (dtype == "i64" || dtype == "int64") Fn(type_tag{}); \ + else if (dtype == "i32" || dtype == "int32") Fn(type_tag{}); \ + else if (dtype == "i16" || dtype == "int16") Fn(type_tag{}); \ + else if (dtype == "i8" || dtype == "int8") Fn(type_tag{}); \ + else if (dtype == "bool") Fn(type_tag{}); \ + else { error = "unsupported dtype: " + dtype; return; } \ + } while(0) + +template +static void io_print_data(const void *data, int64_t n, const std::string &format) { + const T *ptr = static_cast(data); + std::cout << "["; + for (int64_t i = 0; i < n; ++i) { + if (i > 0) std::cout << " "; + if (format.empty()) { + std::cout << ptr[i]; + } else { + printf(format.c_str(), ptr[i]); + } + if (i > 0 && i % 32 == 31 && i < n - 1) std::cout << "\n "; + } + std::cout << "]" << std::endl; +} + +static void io_print(const std::string &dtype, const void *data, int64_t n, const std::string &format, std::string &error) { + auto fn = [&](auto tag) { + using T = typename decltype(tag)::type; + io_print_data(data, n, format); + (void)tag; + }; + DISPATCH_BY_DTYPE(dtype, fn); +} + +// 
═══════════════════════════════════════════════════════════ +// I/O operations: print, save, load +// ═══════════════════════════════════════════════════════════ + +static size_t dtype_byte_size(const std::string &dtype) { + if (dtype == "f64" || dtype == "float64" || dtype == "i64" || dtype == "int64") return 8; + if (dtype == "f32" || dtype == "float32" || dtype == "i32" || dtype == "int32") return 4; + if (dtype == "f16" || dtype == "float16" || dtype == "i16" || dtype == "int16") return 2; + if (dtype == "i8" || dtype == "int8" || dtype == "bool") return 1; + return 4; // default f32 +} + +// save: persist tensor shape + data to disk +static bool io_save(redisContext *redis, const TensorMeta &meta, const void *data, + const std::string &filepath, std::string &error) { + // Write .shape as JSON + json shape_json; + shape_json["dtype"] = meta.dtype; + shape_json["shape"] = meta.shape_data; + shape_json["size"] = element_count(meta.shape_data); + std::string shape_str = shape_json.dump(); + + std::ofstream shape_fs(filepath + ".shape", std::ios::binary); + if (!shape_fs.is_open()) { + error = "save: cannot open " + filepath + ".shape"; + return false; + } + shape_fs.write(shape_str.c_str(), shape_str.size()); + shape_fs.close(); + + // Write .data as raw binary + std::ofstream data_fs(filepath + ".data", std::ios::binary); + if (!data_fs.is_open()) { + error = "save: cannot open " + filepath + ".data"; + return false; + } + data_fs.write(static_cast(data), meta.byte_size); + data_fs.close(); + + std::cout << "[io-metal] saved tensor to " << filepath << " (dtype=" << meta.dtype + << ", elems=" << element_count(meta.shape_data) << ")\n"; + return true; +} + +// load: read tensor shape + data from disk +static bool io_load(redisContext *redis, const std::string &filepath, + const TensorMeta &out_meta, const std::string &out_key, + void *out_data, std::string &error) { + // Read .shape + std::ifstream shape_fs(filepath + ".shape", std::ios::binary); + if 
(!shape_fs.is_open()) { + error = "load: cannot open " + filepath + ".shape"; + return false; + } + std::string shape_str((std::istreambuf_iterator(shape_fs)), std::istreambuf_iterator()); + shape_fs.close(); + + json shape_json; + try { + shape_json = json::parse(shape_str); + } catch (const std::exception &e) { + error = std::string("load: shape JSON parse error: ") + e.what(); + return false; + } + + std::string loaded_dtype = shape_json.value("dtype", ""); + int64_t loaded_elems = shape_json.value("size", 0); + int64_t loaded_bytes = loaded_elems * dtype_byte_size(loaded_dtype); + + // Read .data into output SHM + std::ifstream data_fs(filepath + ".data", std::ios::binary); + if (!data_fs.is_open()) { + error = "load: cannot open " + filepath + ".data"; + return false; + } + + size_t actual_read = loaded_bytes; + if (loaded_bytes > out_meta.byte_size) { + actual_read = out_meta.byte_size; + std::cerr << "[io-metal] load: truncating " << loaded_bytes + << " → " << out_meta.byte_size << " bytes (output tensor smaller)\n"; + } + data_fs.read(static_cast(out_data), actual_read); + data_fs.close(); + + // Update tensor metadata in Redis + try { + json updated_meta; + updated_meta["dtype"] = loaded_dtype; + updated_meta["shape"] = shape_json["shape"]; + updated_meta["byte_size"] = loaded_bytes; + if (!out_meta.shm_name.empty()) { + updated_meta["address"]["shm_name"] = out_meta.shm_name; + updated_meta["address"]["node"] = "n1"; + updated_meta["address"]["type"] = "shm"; + } + updated_meta["device"] = "cpu"; + redis_set(redis, out_key, updated_meta.dump()); + } catch (...) 
{ + // metadata update is best-effort + } + + std::cout << "[io-metal] loaded tensor from " << filepath << " (dtype=" << loaded_dtype + << ", elems=" << loaded_elems << ", bytes=" << actual_read << ")\n"; + return true; +} + +// ═══════════════════════════════════════════════════════════ +// Task execution +// ═══════════════════════════════════════════════════════════ + +static void execute_task(redisContext *redis, const json &task) { + std::string vtid = task.value("vtid", ""); + std::string pc = task.value("pc", ""); + std::string opcode = task.value("opcode", ""); + json params = task.value("params", json::object()); + + if (opcode != "load" && (!task.contains("inputs") || task["inputs"].empty())) { + notify_done(redis, vtid, pc, "error", "missing inputs for " + opcode); + return; + } + + // ── Resolve input tensors ── + const auto &inputs = task.contains("inputs") ? task["inputs"] : json::array(); + std::vector input_metas; + std::vector input_shms; + std::vector input_ptrs; + + for (const auto &in : inputs) { + std::string key = in.value("key", ""); + if (key.empty()) { + notify_done(redis, vtid, pc, "error", "input missing key"); + return; + } + + TensorMeta meta = fetch_tensor_meta(redis, key); + if (!meta.valid) { + notify_done(redis, vtid, pc, "error", "input tensor not found: " + key); + return; + } + + ShmMapping shm; + if (!meta.shm_name.empty()) { + if (!shm_open_readwrite(meta.shm_name, meta.byte_size, shm)) { + notify_done(redis, vtid, pc, "error", "shm open failed: " + meta.shm_name); + return; + } + input_ptrs.push_back(shm.addr); + } else { + notify_done(redis, vtid, pc, "error", "input has no shm address: " + key); + return; + } + + input_metas.push_back(meta); + input_shms.push_back(shm); + } + + // ── Resolve output tensor (required for load) ── + const auto &outputs = task.contains("outputs") ? 
task["outputs"] : json::array(); + std::string out_key; + TensorMeta out_meta; + ShmMapping out_shm; + bool has_output = !outputs.empty(); + + if (has_output) { + const auto &out = outputs[0]; + out_key = out.value("key", ""); + out_meta = fetch_tensor_meta(redis, out_key); + if (!out_meta.valid && opcode == "load") { + notify_done(redis, vtid, pc, "error", "output tensor not found: " + out_key); + for (auto &s : input_shms) shm_close(s); + return; + } + if (out_meta.valid && !out_meta.shm_name.empty()) { + if (!shm_open_readwrite(out_meta.shm_name, out_meta.byte_size, out_shm)) { + notify_done(redis, vtid, pc, "error", "output shm open failed: " + out_meta.shm_name); + for (auto &s : input_shms) shm_close(s); + return; + } + } + } + + std::string dtype = input_metas.empty() ? "f32" : input_metas[0].dtype; + std::string error; + bool ok = false; + + // ── print ── + if (opcode == "print") { + std::string format = params.value("format", ""); + int64_t nelem = element_count(input_metas[0].shape_data); + io_print(dtype, input_ptrs[0], nelem, format, error); + ok = true; + } + // ── save ── + else if (opcode == "save") { + std::string filepath = params.value("arg0", ""); + if (filepath.empty()) { + error = "save: missing file path (arg0)"; + } else { + ok = io_save(redis, input_metas[0], input_ptrs[0], filepath, error); + } + } + // ── load ── + else if (opcode == "load") { + std::string filepath = params.value("arg0", ""); + if (filepath.empty()) { + error = "load: missing file path (arg0)"; + } else if (!has_output || !out_meta.valid) { + error = "load: missing output tensor"; + } else { + ok = io_load(redis, filepath, out_meta, out_key, out_shm.addr, error); + } + } + else { + notify_done(redis, vtid, pc, "error", + "unsupported io opcode: " + opcode); + for (auto &s : input_shms) shm_close(s); + if (has_output) shm_close(out_shm); + return; + } + + // ── Cleanup ── + for (auto &s : input_shms) shm_close(s); + if (has_output) shm_close(out_shm); + + if (ok) { + 
notify_done(redis, vtid, pc, "ok"); + } else { + if (error.empty()) error = "io dispatch failed for " + opcode; + notify_done(redis, vtid, pc, "error", error); + } +} + +// ═══════════════════════════════════════════════════════════ +// Main +// ═══════════════════════════════════════════════════════════ + +int main(int argc, char **argv) { + const char *redis_addr = "127.0.0.1"; + int redis_port = 6379; + if (argc > 1) redis_addr = argv[1]; + if (argc > 2) redis_port = atoi(argv[2]); + + // Force unbuffered output for diagnostics + std::cout << std::unitbuf; + std::cerr << std::unitbuf; + + std::cout << "[io-metal] I/O plane starting\n"; + { + char cwd[4096]; + if (getcwd(cwd, sizeof(cwd))) { + std::cout << "[io-metal] CWD: " << cwd << "\n"; + } + } + + // 连接 Redis(无限重试,不自退——io-plat 由元程控制退出) + redisContext *redis = nullptr; + while (!redis) { + redis = connect_redis(redis_addr, redis_port); + if (!redis) { + std::cerr << "[io-metal] Redis not available, retrying in 1s...\n"; + sleep(1); + } + } + std::cout << "[io-metal] connected to Redis " << redis_addr << ":" << redis_port << "\n"; + + // 注册实例和算子 + register_instance(redis); + + std::cout << "[io-metal] listening on " << IO_QUEUE << " + " << SYS_QUEUE << "\n"; + + // ── 消费循环 ── + std::atomic running{true}; + while (running) { + redisReply *r = redis_cmd(redis, "BLPOP %s %s %d", IO_QUEUE, SYS_QUEUE, BLOCK_TIMEOUT_SEC); + if (!r) { + // Redis 断连 → 无限重连 + std::cerr << "[io-metal] Redis disconnected, reconnecting...\n"; + redisFree(redis); + redis = nullptr; + while (!redis) { + sleep(1); + redis = connect_redis(redis_addr, redis_port); + if (!redis) { + std::cerr << "[io-metal] Redis still not available, retrying...\n"; + } + } + register_instance(redis); + continue; + } + + if (r->type == REDIS_REPLY_NIL) { + REDIS_FREE(r); + continue; + } + + if (r->type != REDIS_REPLY_ARRAY || r->elements < 2) { + REDIS_FREE(r); + continue; + } + + std::string queue_name(r->element[0]->str); + std::string 
payload(r->element[1]->str); + REDIS_FREE(r); + + // ── 系统命令处理 ── + if (queue_name == SYS_QUEUE) { + try { + json sys_cmd = json::parse(payload); + std::string cmd = sys_cmd.value("cmd", ""); + if (cmd == "shutdown") { + std::cout << "[io-metal] received sys shutdown command, exiting...\n"; + running = false; + } else { + std::cerr << "[io-metal] unknown sys command: " << cmd << "\n"; + } + } catch (const std::exception &e) { + std::cerr << "[io-metal] sys cmd JSON parse error: " << e.what() << "\n"; + } + continue; + } + + // ── I/O 命令处理 ── + json task; + try { + task = json::parse(payload); + } catch (const std::exception &e) { + std::cerr << "[io-metal] JSON parse error: " << e.what() << "\n"; + continue; + } + + try { + execute_task(redis, task); + } catch (const std::exception &e) { + std::string vtid = task.value("vtid", ""); + std::string pc = task.value("pc", ""); + std::cerr << "[io-metal] task exception: " << e.what() << "\n"; + if (!vtid.empty()) { + notify_done(redis, vtid, pc, "error", e.what()); + } + } + } + + if (redis) { + redis_cmd(redis, "DEL %s", INSTANCE_KEY); + redisFree(redis); + } + std::cout << "[io-metal] shutdown complete.\n"; + return 0; +} diff --git a/executor/deepxcore/src/client/udpserver.cpp b/executor/old-cppcommon/client/udpserver.cpp similarity index 100% rename from executor/deepxcore/src/client/udpserver.cpp rename to executor/old-cppcommon/client/udpserver.cpp diff --git a/executor/deepxcore/src/client/udpserver.hpp b/executor/old-cppcommon/client/udpserver.hpp similarity index 96% rename from executor/deepxcore/src/client/udpserver.hpp rename to executor/old-cppcommon/client/udpserver.hpp index f32fe89a..731b7e42 100644 --- a/executor/deepxcore/src/client/udpserver.hpp +++ b/executor/old-cppcommon/client/udpserver.hpp @@ -9,7 +9,7 @@ #include #include -#include "deepx/tf/tf.hpp" +#include "tf.hpp" namespace client{ using namespace std; class udpserver diff --git a/executor/deepxcore/src/client/unixsocketserver.cpp 
b/executor/old-cppcommon/client/unixsocketserver.cpp similarity index 100% rename from executor/deepxcore/src/client/unixsocketserver.cpp rename to executor/old-cppcommon/client/unixsocketserver.cpp diff --git a/executor/deepxcore/src/client/unixsocketserver.hpp b/executor/old-cppcommon/client/unixsocketserver.hpp similarity index 100% rename from executor/deepxcore/src/client/unixsocketserver.hpp rename to executor/old-cppcommon/client/unixsocketserver.hpp diff --git a/executor/deepxcore/src/client/worker.hpp b/executor/old-cppcommon/client/worker.hpp similarity index 100% rename from executor/deepxcore/src/client/worker.hpp rename to executor/old-cppcommon/client/worker.hpp diff --git a/executor/old-cppcommon/dtype.hpp b/executor/old-cppcommon/dtype.hpp new file mode 100644 index 00000000..46cbf054 --- /dev/null +++ b/executor/old-cppcommon/dtype.hpp @@ -0,0 +1,24 @@ +#ifndef DEEPX_DTYPE_HPP +#define DEEPX_DTYPE_HPP + +#include +#include +#include +#include +#include "stdutil/string.hpp" +#include "stdutil/num.hpp" + +namespace deepx +{ + using namespace std; + + + + + + + + + +} // namespace deepx +#endif diff --git a/executor/deepxcore/src/deepx/mem/mem.hpp b/executor/old-cppcommon/mem/mem.hpp similarity index 99% rename from executor/deepxcore/src/deepx/mem/mem.hpp rename to executor/old-cppcommon/mem/mem.hpp index 710c3773..dd73459f 100644 --- a/executor/deepxcore/src/deepx/mem/mem.hpp +++ b/executor/old-cppcommon/mem/mem.hpp @@ -7,7 +7,7 @@ #include #include "iostream" -#include "deepx/tensor.hpp" +#include "../tensor.hpp" namespace deepx::mem { using namespace std; diff --git a/executor/deepxcore/src/deepx/shape.cpp b/executor/old-cppcommon/shape.cpp similarity index 96% rename from executor/deepxcore/src/deepx/shape.cpp rename to executor/old-cppcommon/shape.cpp index cc802d90..d12ab602 100644 --- a/executor/deepxcore/src/deepx/shape.cpp +++ b/executor/old-cppcommon/shape.cpp @@ -5,7 +5,7 @@ #include #include "tensor.hpp" -#include "deepx/dtype.hpp" 
+#include "dtype.hpp" namespace deepx { Shape::Shape(const int *shape, int dim) @@ -92,7 +92,7 @@ namespace deepx std::string Shape::toYaml() const{ YAML::Node node; - node["dtype"] = precision_str(dtype); + node["dtype"] = precision_to_string(dtype); node["dim"] = dim(); node["shape"] = shape; node["stride"] = strides; @@ -101,7 +101,7 @@ namespace deepx } void Shape::fromYaml(const std::string &yaml){ YAML::Node node = YAML::Load(yaml); - dtype = precision(node["dtype"].as()); + dtype = precision_from_string(node["dtype"].as()); shape = node["shape"].as>(); strides=node["stride"].as>(); size=node["size"].as(); diff --git a/executor/deepxcore/src/deepx/shape.hpp b/executor/old-cppcommon/shape.hpp similarity index 98% rename from executor/deepxcore/src/deepx/shape.hpp rename to executor/old-cppcommon/shape.hpp index 482142cd..07251f38 100644 --- a/executor/deepxcore/src/deepx/shape.hpp +++ b/executor/old-cppcommon/shape.hpp @@ -8,7 +8,8 @@ #include #include "stdutil/fs.hpp" -#include "deepx/dtype.hpp" +#include "deepx/dtype/precision.hpp" +#include "dtype.hpp" namespace deepx { // omp内线程局部变量 diff --git a/executor/deepxcore/src/deepx/shape_changeshape.cpp b/executor/old-cppcommon/shape_changeshape.cpp similarity index 99% rename from executor/deepxcore/src/deepx/shape_changeshape.cpp rename to executor/old-cppcommon/shape_changeshape.cpp index c02d2f5c..0cbd0201 100644 --- a/executor/deepxcore/src/deepx/shape_changeshape.cpp +++ b/executor/old-cppcommon/shape_changeshape.cpp @@ -1,7 +1,7 @@ #include #include -#include "deepx/shape_changeshape.hpp" +#include "shape_changeshape.hpp" namespace deepx { diff --git a/executor/deepxcore/src/deepx/shape_changeshape.hpp b/executor/old-cppcommon/shape_changeshape.hpp similarity index 97% rename from executor/deepxcore/src/deepx/shape_changeshape.hpp rename to executor/old-cppcommon/shape_changeshape.hpp index 65a36670..6adb04be 100644 --- a/executor/deepxcore/src/deepx/shape_changeshape.hpp +++ 
b/executor/old-cppcommon/shape_changeshape.hpp @@ -5,8 +5,8 @@ #include #include #include -#include "deepx/tensor.hpp" -#include "deepx/shape.hpp" +#include "tensor.hpp" +#include "shape.hpp" #include "stdutil/error.hpp" namespace deepx diff --git a/executor/deepxcore/src/deepx/shape_matmul.cpp b/executor/old-cppcommon/shape_matmul.cpp similarity index 94% rename from executor/deepxcore/src/deepx/shape_matmul.cpp rename to executor/old-cppcommon/shape_matmul.cpp index 46247c70..8614fd2b 100644 --- a/executor/deepxcore/src/deepx/shape_matmul.cpp +++ b/executor/old-cppcommon/shape_matmul.cpp @@ -1,6 +1,6 @@ #include -#include "deepx/shape_matmul.hpp" +#include "shape_matmul.hpp" namespace deepx { diff --git a/executor/deepxcore/src/deepx/shape_matmul.hpp b/executor/old-cppcommon/shape_matmul.hpp similarity index 86% rename from executor/deepxcore/src/deepx/shape_matmul.hpp rename to executor/old-cppcommon/shape_matmul.hpp index d4985e32..32e16e97 100644 --- a/executor/deepxcore/src/deepx/shape_matmul.hpp +++ b/executor/old-cppcommon/shape_matmul.hpp @@ -1,7 +1,7 @@ #ifndef DEEPX_SHAPE_MATMUL_HPP #define DEEPX_SHAPE_MATMUL_HPP -#include "deepx/shape.hpp" +#include "shape.hpp" namespace deepx { diff --git a/executor/deepxcore/src/deepx/shape_range.cpp b/executor/old-cppcommon/shape_range.cpp similarity index 99% rename from executor/deepxcore/src/deepx/shape_range.cpp rename to executor/old-cppcommon/shape_range.cpp index bd5e2885..ebb4df9b 100644 --- a/executor/deepxcore/src/deepx/shape_range.cpp +++ b/executor/old-cppcommon/shape_range.cpp @@ -4,7 +4,7 @@ #include #include -#include "deepx/shape.hpp" +#include "shape.hpp" namespace deepx { static int checkdim(int dimCount, int dim) diff --git a/executor/deepxcore/src/deepx/shape_reduce.cpp b/executor/old-cppcommon/shape_reduce.cpp similarity index 98% rename from executor/deepxcore/src/deepx/shape_reduce.cpp rename to executor/old-cppcommon/shape_reduce.cpp index eeb427a2..d9449db2 100644 --- 
a/executor/deepxcore/src/deepx/shape_reduce.cpp +++ b/executor/old-cppcommon/shape_reduce.cpp @@ -4,7 +4,7 @@ #include #include "stdutil/error.hpp" -#include "deepx/shape_reduce.hpp" +#include "shape_reduce.hpp" namespace deepx { diff --git a/executor/deepxcore/src/deepx/shape_reduce.hpp b/executor/old-cppcommon/shape_reduce.hpp similarity index 95% rename from executor/deepxcore/src/deepx/shape_reduce.hpp rename to executor/old-cppcommon/shape_reduce.hpp index ac15d4e7..426e02df 100644 --- a/executor/deepxcore/src/deepx/shape_reduce.hpp +++ b/executor/old-cppcommon/shape_reduce.hpp @@ -1,7 +1,7 @@ #ifndef DEEPX_SHAPE_SUM_HPP #define DEEPX_SHAPE_SUM_HPP -#include "deepx/shape.hpp" +#include "shape.hpp" namespace deepx { diff --git a/executor/deepxcore/src/deepx/shape_tensorinit.cpp b/executor/old-cppcommon/shape_tensorinit.cpp similarity index 95% rename from executor/deepxcore/src/deepx/shape_tensorinit.cpp rename to executor/old-cppcommon/shape_tensorinit.cpp index dd93e798..a3fa1e14 100644 --- a/executor/deepxcore/src/deepx/shape_tensorinit.cpp +++ b/executor/old-cppcommon/shape_tensorinit.cpp @@ -1,4 +1,4 @@ -#include "deepx/shape_tensorinit.hpp" +#include "shape_tensorinit.hpp" namespace deepx { diff --git a/executor/deepxcore/src/deepx/shape_tensorinit.hpp b/executor/old-cppcommon/shape_tensorinit.hpp similarity index 88% rename from executor/deepxcore/src/deepx/shape_tensorinit.hpp rename to executor/old-cppcommon/shape_tensorinit.hpp index 5bd5629a..00f2bc1d 100644 --- a/executor/deepxcore/src/deepx/shape_tensorinit.hpp +++ b/executor/old-cppcommon/shape_tensorinit.hpp @@ -1,7 +1,7 @@ #ifndef DEEPX_SHAPE_TENSORINIT_HPP #define DEEPX_SHAPE_TENSORINIT_HPP -#include "deepx/shape.hpp" +#include "shape.hpp" namespace deepx { diff --git a/executor/deepxcore/src/deepx/tensor.hpp b/executor/old-cppcommon/tensor.hpp similarity index 98% rename from executor/deepxcore/src/deepx/tensor.hpp rename to executor/old-cppcommon/tensor.hpp index 9e46b222..43ba8565 100644 --- 
a/executor/deepxcore/src/deepx/tensor.hpp +++ b/executor/old-cppcommon/tensor.hpp @@ -6,9 +6,9 @@ #include #include -#include "deepx/shape.hpp" -#include "deepx/dtype.hpp" -#include "deepx/tensorbase.hpp" +#include "shape.hpp" +#include "dtype.hpp" +#include "tensorbase.hpp" namespace deepx { diff --git a/executor/deepxcore/src/deepx/tensorbase.hpp b/executor/old-cppcommon/tensorbase.hpp similarity index 97% rename from executor/deepxcore/src/deepx/tensorbase.hpp rename to executor/old-cppcommon/tensorbase.hpp index 6e8806e0..42338458 100644 --- a/executor/deepxcore/src/deepx/tensorbase.hpp +++ b/executor/old-cppcommon/tensorbase.hpp @@ -1,7 +1,7 @@ #ifndef DEEPX_TENSORBASE_HPP #define DEEPX_TENSORBASE_HPP -#include "deepx/shape.hpp" +#include "shape.hpp" namespace deepx { diff --git a/executor/deepxcore/src/deepx/tensorfunc/authors.hpp b/executor/old-cppcommon/tensorfunc/authors.hpp similarity index 100% rename from executor/deepxcore/src/deepx/tensorfunc/authors.hpp rename to executor/old-cppcommon/tensorfunc/authors.hpp diff --git a/executor/deepxcore/src/deepx/tensorfunc/changeshape.hpp b/executor/old-cppcommon/tensorfunc/changeshape.hpp similarity index 99% rename from executor/deepxcore/src/deepx/tensorfunc/changeshape.hpp rename to executor/old-cppcommon/tensorfunc/changeshape.hpp index c0eb4306..9736eaff 100644 --- a/executor/deepxcore/src/deepx/tensorfunc/changeshape.hpp +++ b/executor/old-cppcommon/tensorfunc/changeshape.hpp @@ -2,7 +2,7 @@ #define DEEPX_TENSORFUNC_CHANGESHAPE_HPP #include -#include "deepx/tensor.hpp" +#include "../tensor.hpp" #include "stdutil/error.hpp" namespace deepx::tensorfunc diff --git a/executor/deepxcore/src/deepx/tensorfunc/elementwise.hpp b/executor/old-cppcommon/tensorfunc/elementwise.hpp similarity index 99% rename from executor/deepxcore/src/deepx/tensorfunc/elementwise.hpp rename to executor/old-cppcommon/tensorfunc/elementwise.hpp index a708268b..bdba435a 100644 --- 
a/executor/deepxcore/src/deepx/tensorfunc/elementwise.hpp +++ b/executor/old-cppcommon/tensorfunc/elementwise.hpp @@ -1,7 +1,7 @@ #ifndef DEEPX_TENSORFUNC_ELEMENTWISE_HPP #define DEEPX_TENSORFUNC_ELEMENTWISE_HPP -#include "deepx/tensor.hpp" +#include "../tensor.hpp" #include "stdutil/error.hpp" namespace deepx::tensorfunc diff --git a/executor/deepxcore/src/deepx/tensorfunc/init.hpp b/executor/old-cppcommon/tensorfunc/init.hpp similarity index 98% rename from executor/deepxcore/src/deepx/tensorfunc/init.hpp rename to executor/old-cppcommon/tensorfunc/init.hpp index 934e25f1..0b3a21b3 100644 --- a/executor/deepxcore/src/deepx/tensorfunc/init.hpp +++ b/executor/old-cppcommon/tensorfunc/init.hpp @@ -1,7 +1,7 @@ #ifndef DEEPX_TENSORFUNC_INIT_HPP #define DEEPX_TENSORFUNC_INIT_HPP -#include "deepx/tensor.hpp" +#include "../tensor.hpp" #include "stdutil/error.hpp" namespace deepx::tensorfunc diff --git a/executor/deepxcore/src/deepx/tensorfunc/io.hpp b/executor/old-cppcommon/tensorfunc/io.hpp similarity index 94% rename from executor/deepxcore/src/deepx/tensorfunc/io.hpp rename to executor/old-cppcommon/tensorfunc/io.hpp index 59a3606e..d267d5ca 100644 --- a/executor/deepxcore/src/deepx/tensorfunc/io.hpp +++ b/executor/old-cppcommon/tensorfunc/io.hpp @@ -1,7 +1,7 @@ #ifndef DEEPX_TENSORFUNC_IO_HPP #define DEEPX_TENSORFUNC_IO_HPP -#include "deepx/tensor.hpp" +#include "../tensor.hpp" #include "stdutil/fs.hpp" namespace deepx::tensorfunc{ diff --git a/executor/deepxcore/src/deepx/tensorfunc/matmul.hpp b/executor/old-cppcommon/tensorfunc/matmul.hpp similarity index 93% rename from executor/deepxcore/src/deepx/tensorfunc/matmul.hpp rename to executor/old-cppcommon/tensorfunc/matmul.hpp index 35a50114..7152b2cb 100644 --- a/executor/deepxcore/src/deepx/tensorfunc/matmul.hpp +++ b/executor/old-cppcommon/tensorfunc/matmul.hpp @@ -1,8 +1,8 @@ #ifndef DEEPX_TENSORFUNC_MATMUL_HPP #define DEEPX_TENSORFUNC_MATMUL_HPP -#include "deepx/tensor.hpp" -#include 
"deepx/tensorfunc/authors.hpp" +#include "../tensor.hpp" +#include "authors.hpp" #include "stdutil/error.hpp" namespace deepx::tensorfunc { diff --git a/executor/deepxcore/src/deepx/tensorfunc/reduce.hpp b/executor/old-cppcommon/tensorfunc/reduce.hpp similarity index 96% rename from executor/deepxcore/src/deepx/tensorfunc/reduce.hpp rename to executor/old-cppcommon/tensorfunc/reduce.hpp index a94c1908..602a385c 100644 --- a/executor/deepxcore/src/deepx/tensorfunc/reduce.hpp +++ b/executor/old-cppcommon/tensorfunc/reduce.hpp @@ -1,8 +1,8 @@ #ifndef DEEPX_TENSORFUNC_REDUCE_HPP #define DEEPX_TENSORFUNC_REDUCE_HPP - #include "deepx/tensor.hpp" -#include "deepx/tensorfunc/authors.hpp" + #include "../tensor.hpp" +#include "authors.hpp" #include "stdutil/error.hpp" namespace deepx::tensorfunc diff --git a/executor/deepxcore/src/deepx/tensorfunc/tensorlife.hpp b/executor/old-cppcommon/tensorfunc/tensorlife.hpp similarity index 94% rename from executor/deepxcore/src/deepx/tensorfunc/tensorlife.hpp rename to executor/old-cppcommon/tensorfunc/tensorlife.hpp index cc06c69d..65db609a 100644 --- a/executor/deepxcore/src/deepx/tensorfunc/tensorlife.hpp +++ b/executor/old-cppcommon/tensorfunc/tensorlife.hpp @@ -1,7 +1,7 @@ #ifndef DEEPX_TENSORFUNC_TENSORLIFE_HPP #define DEEPX_TENSORFUNC_TENSORLIFE_HPP -#include "deepx/tensor.hpp" +#include "../tensor.hpp" namespace deepx::tensorfunc { diff --git a/executor/deepxcore/src/deepx/tf/tf.cpp b/executor/old-cppcommon/tf/tf.cpp similarity index 95% rename from executor/deepxcore/src/deepx/tf/tf.cpp rename to executor/old-cppcommon/tf/tf.cpp index 6c52b3e4..e2269542 100644 --- a/executor/deepxcore/src/deepx/tf/tf.cpp +++ b/executor/old-cppcommon/tf/tf.cpp @@ -3,7 +3,7 @@ #include #include -#include "deepx/tf/tf.hpp" +#include "tf.hpp" #include "stdutil/time.hpp" #include "stdutil/string.hpp" namespace deepx::tf @@ -28,18 +28,18 @@ namespace deepx::tf } if (!type.empty()) { - this->dtype = deepx::dtype(type); + 
this->dtype.from_string(type); this->textvalue = textvalue; } else { - this->dtype = deepx::dtype(textvalue); + this->dtype.from_string(textvalue); this->textvalue = textvalue; } } string Param::to_string() const { - return dtype_str(dtype) + ":" + textvalue; + return dtype.to_string() + ":" + textvalue; } string TFMetadata::to_string() const { @@ -192,7 +192,7 @@ namespace deepx::tf } // 解析单个值为具体C++类型 - any parse_single_value(const string &value_str, const TypeDef &dtype) + any parse_single_value(const string &value_str, const TypeSpec &dtype) { // 如果是字符串类型,直接返回 if (dtype.precision() == Precision::String) @@ -261,9 +261,7 @@ namespace deepx::tf } catch (const std::exception &e) { - throw runtime_error("Failed to parse value '" + value_str + "' as " + - base_category_str(dtype.category()) + "<" + - precision_str(dtype.precision()) + ">"); + throw runtime_error("Failed to parse value '" + value_str + "' as " +dtype.to_string()); } return value_str; // 默认作为字符串处理 @@ -401,8 +399,8 @@ namespace deepx::tf for (size_t i = 0; i < args.size(); ++i) { // 当前TF的类型可能包含多个选项 - TypeDef dtype = args[i].dtype; - TypeDef other_dtype = other.args[i].dtype; + TypeSpec dtype = args[i].dtype; + TypeSpec other_dtype = other.args[i].dtype; // TODO } diff --git a/executor/deepxcore/src/deepx/tf/tf.hpp b/executor/old-cppcommon/tf/tf.hpp similarity index 96% rename from executor/deepxcore/src/deepx/tf/tf.hpp rename to executor/old-cppcommon/tf/tf.hpp index f5cf55c1..397d001f 100644 --- a/executor/deepxcore/src/deepx/tf/tf.hpp +++ b/executor/old-cppcommon/tf/tf.hpp @@ -9,9 +9,9 @@ #include #include -#include "deepx/tensor.hpp" -#include "deepx/mem/mem.hpp" -#include "deepx/dtype.hpp" +#include "../tensor.hpp" +#include "../mem/mem.hpp" +#include "deepx/dtype/typespec.hpp" #include "stdutil/error.hpp" #include "stdutil/num.hpp" @@ -24,11 +24,11 @@ namespace deepx::tf struct Param { - TypeDef dtype; + TypeSpec dtype; string textvalue; Param(const string &textvalue = "", const DataCategory &dt = 
DataCategory::Unknown, const Precision &prec = Precision::Any) - : textvalue(textvalue), dtype(make_dtype(dt, prec)) {} + : textvalue(textvalue), dtype(TypeSpec(dt, prec)) {} void parse(const string ¶m); string to_string() const; @@ -152,7 +152,7 @@ namespace deepx::tf std::string item; while (std::getline(ss, item, ' ')) { - result.push_back(to(item)); + result.push_back(item); } return result; } diff --git a/executor/deepxcore/src/deepx/tf/tffactory.cpp b/executor/old-cppcommon/tf/tffactory.cpp similarity index 93% rename from executor/deepxcore/src/deepx/tf/tffactory.cpp rename to executor/old-cppcommon/tf/tffactory.cpp index 7ddbf555..1acaec95 100644 --- a/executor/deepxcore/src/deepx/tf/tffactory.cpp +++ b/executor/old-cppcommon/tf/tffactory.cpp @@ -4,8 +4,8 @@ #include #include -#include "deepx/tf/tffactory.hpp" -#include "deepx/dtype.hpp" +#include "tffactory.hpp" +#include "../dtype.hpp" namespace deepx::tf { @@ -37,13 +37,13 @@ namespace deepx::tf } // 提取参数和返回值类型 - vector arg_types; + vector arg_types; for (const auto &arg : other.args) { arg_types.push_back(arg.dtype); } - vector return_types; + vector return_types; for (const auto &ret : other.returns) { return_types.push_back(ret.dtype); @@ -63,14 +63,14 @@ namespace deepx::tf { if (i > 0) cerr << ", "; - cerr << dtype_str(registered_tf->args[i].dtype); + cerr << registered_tf->args[i].dtype.to_string(); } cerr << ")->("; for (size_t i = 0; i < registered_tf->returns.size(); i++) { if (i > 0) cerr << ", "; - cerr << dtype_str(registered_tf->returns[i].dtype); + cerr << registered_tf->returns[i].dtype.to_string(); } cerr << ")" << endl; } diff --git a/executor/deepxcore/src/deepx/tf/tffactory.hpp b/executor/old-cppcommon/tf/tffactory.hpp similarity index 85% rename from executor/deepxcore/src/deepx/tf/tffactory.hpp rename to executor/old-cppcommon/tf/tffactory.hpp index 6f8e97dc..969c2895 100644 --- a/executor/deepxcore/src/deepx/tf/tffactory.hpp +++ b/executor/old-cppcommon/tf/tffactory.hpp @@ -1,15 
+1,15 @@ #ifndef DEEPX_TF_TFFACTORY_HPP #define DEEPX_TF_TFFACTORY_HPP -#include "deepx/tf/tf.hpp" +#include "tf.hpp" namespace deepx::tf { struct TypeSignature { - vector args; - vector returns; + vector args; + vector returns; bool is_compatible(const TypeSignature &other) const { @@ -18,7 +18,7 @@ namespace deepx::tf } private: - static bool is_compatible_types(const vector &def, const vector &other) + static bool is_compatible_types(const vector &def, const vector &other) { if (def.size() != other.size()) return false; @@ -36,20 +36,20 @@ namespace deepx::tf vector> tfs; // 获取匹配的TF实现 - std::shared_ptr get_matching_tf(const vector &arg_types, - const vector &return_types) const + std::shared_ptr get_matching_tf(const vector &arg_types, + const vector &return_types) const { TypeSignature target{arg_types, return_types}; for (const auto &tf : tfs) { - vector tf_arg_types; + vector tf_arg_types; for (const auto &arg : tf->args) { tf_arg_types.push_back(arg.dtype); } - vector tf_return_types; + vector tf_return_types; for (const auto &ret : tf->returns) { tf_return_types.push_back(ret.dtype); diff --git a/executor/deepxcore/src/deepx/vector_combination.cpp b/executor/old-cppcommon/vector_combination.cpp similarity index 96% rename from executor/deepxcore/src/deepx/vector_combination.cpp rename to executor/old-cppcommon/vector_combination.cpp index fb05e6ce..2cce2dca 100644 --- a/executor/deepxcore/src/deepx/vector_combination.cpp +++ b/executor/old-cppcommon/vector_combination.cpp @@ -1,6 +1,6 @@ #include -#include "deepx/vector_combination.hpp" +#include "vector_combination.hpp" namespace deepx { diff --git a/executor/deepxcore/src/deepx/vector_combination.hpp b/executor/old-cppcommon/vector_combination.hpp similarity index 100% rename from executor/deepxcore/src/deepx/vector_combination.hpp rename to executor/old-cppcommon/vector_combination.hpp diff --git a/executor/op-cuda/build.sh b/executor/op-cuda/build.sh index 6b33ddb8..d5c21fb8 100644 --- 
a/executor/op-cuda/build.sh +++ b/executor/op-cuda/build.sh @@ -1,4 +1,8 @@ -mkdir -p build && cd build -rm -rf build/* -cmake .. +#!/usr/bin/env bash +set -euo pipefail + +mkdir -p /tmp/deepx/executor/op-cuda/build +cd /tmp/deepx/executor/op-cuda/build +rm -rf ./* +cmake "$(cd "$(dirname "$0")" && pwd)" make -j$(nproc) diff --git a/executor/op-cuda/src/deepx/tf/arg.hpp b/executor/op-cuda/src/deepx/tf/arg.hpp index 4b33c457..05d0caf8 100644 --- a/executor/op-cuda/src/deepx/tf/arg.hpp +++ b/executor/op-cuda/src/deepx/tf/arg.hpp @@ -31,7 +31,7 @@ namespace deepx::tf error = "argset(int32) must have 1 argument"; return 1; } - TypeDef datatype = this->returns[0].dtype; + TypeSpec datatype = this->returns[0].dtype; if (uint8_t(datatype.category() & DataCategory::Var) == 0) { error = "datatype must be var"; @@ -87,7 +87,7 @@ namespace deepx::tf int run(shared_ptr mem, string &error) override { string name = this->returns[0].textvalue; - TypeDef datatype = this->returns[0].dtype; + TypeSpec datatype = this->returns[0].dtype; if (uint8_t(datatype.category() & DataCategory::Vector) == 0) { error = "datatype must be vector"; diff --git a/executor/op-cuda/src/deepx/tf/tensorlife.hpp b/executor/op-cuda/src/deepx/tf/tensorlife.hpp index 0db28933..8500285b 100644 --- a/executor/op-cuda/src/deepx/tf/tensorlife.hpp +++ b/executor/op-cuda/src/deepx/tf/tensorlife.hpp @@ -24,7 +24,7 @@ namespace deepx::tf int run(shared_ptr mem, string &error) override { string name = this->returns[0].textvalue; - TypeDef type = this->returns[0].dtype; + TypeSpec type = this->returns[0].dtype; if (uint8_t(type.category() & DataCategory::Tensor) == 0) { error = "newtensor: return type must include tensor category"; diff --git a/executor/op-mem-mps/.gitignore b/executor/op-mem-mps/.gitignore deleted file mode 100644 index 93ad0a0e..00000000 --- a/executor/op-mem-mps/.gitignore +++ /dev/null @@ -1,5 +0,0 @@ -build/ -.DS_Store -*.xcuserdata/ -*.xcworkspace/ -*.xcodeproj/ diff --git 
a/executor/op-mem-mps/CMakeLists.txt b/executor/op-mem-mps/CMakeLists.txt deleted file mode 100644 index 9b361f5a..00000000 --- a/executor/op-mem-mps/CMakeLists.txt +++ /dev/null @@ -1,102 +0,0 @@ -cmake_minimum_required(VERSION 3.15...3.29) -project(deepx-executor-mps LANGUAGES CXX OBJCXX) - -# 设置 C++ 标准 -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_STANDARD_REQUIRED True) -set(CMAKE_BUILD_TYPE Debug) - -# 仅支持 macOS -if(NOT APPLE) - message(FATAL_ERROR "op-mem-mps 仅支持 macOS") -endif() - -# 包含头文件目录 -include_directories(src) - -add_subdirectory(../cpp-common common) - -# 源文件 -file(GLOB_RECURSE DEEPX_SOURCES "src/deepx/*.cpp" "src/deepx/*.mm" "src/deepx/*.hpp" "src/deepx/*.h") -file(GLOB_RECURSE CLIENT_SOURCES "src/client/*.cpp" "src/client/*.mm") - -find_library(METAL_FRAMEWORK Metal REQUIRED) -find_library(MPS_FRAMEWORK MetalPerformanceShaders REQUIRED) -find_library(FOUNDATION_FRAMEWORK Foundation REQUIRED) - -find_package(yaml-cpp REQUIRED) -set(YAMLCPP_LIB "") -if (TARGET yaml-cpp::yaml-cpp) - set(YAMLCPP_LIB yaml-cpp::yaml-cpp) -else() - set(YAMLCPP_LIB yaml-cpp) -endif() - -add_library(deepx_mps SHARED - ${DEEPX_SOURCES} -) - -# --- Offline compile Metal shaders into default.metallib --- -set(METAL_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src/deepx/tensorfunc) -set(METAL_SOURCES - ${METAL_SRC_DIR}/elementwise_miaobyte.metal -) - -set(METAL_AIR_DIR ${CMAKE_CURRENT_BINARY_DIR}/metal_air) -set(METAL_LIB ${CMAKE_CURRENT_BINARY_DIR}/default.metallib) - -option(DEEPX_MPS_OFFLINE_METAL "Build .metallib at build time (requires Xcode/metal tools)" ON) - -if (DEEPX_MPS_OFFLINE_METAL) - execute_process( - COMMAND xcrun -sdk macosx -find metal - OUTPUT_VARIABLE METALC - OUTPUT_STRIP_TRAILING_WHITESPACE - RESULT_VARIABLE METALC_RV - ERROR_QUIET - ) - execute_process( - COMMAND xcrun -sdk macosx -find metallib - OUTPUT_VARIABLE METALLIB - OUTPUT_STRIP_TRAILING_WHITESPACE - RESULT_VARIABLE METALLIB_RV - ERROR_QUIET - ) - - if (METALC_RV EQUAL 0 AND METALLIB_RV EQUAL 0) - 
add_custom_command( - OUTPUT ${METAL_LIB} - COMMAND ${CMAKE_COMMAND} -E make_directory ${METAL_AIR_DIR} - COMMAND ${METALC} -c ${METAL_SRC_DIR}/elementwise_miaobyte.metal -o ${METAL_AIR_DIR}/elementwise_miaobyte.air - COMMAND ${METALLIB} ${METAL_AIR_DIR}/elementwise_miaobyte.air -o ${METAL_LIB} - DEPENDS ${METAL_SOURCES} - VERBATIM - ) - - add_custom_target(deepx_metal_kernels ALL DEPENDS ${METAL_LIB}) - add_dependencies(deepx_mps deepx_metal_kernels) - else() - message(WARNING "Metal offline tools not found (xcrun metal/metallib). Install Xcode or CLT, or set -DDEEPX_MPS_OFFLINE_METAL=OFF.") - endif() -endif() -# ---------------------------------------------------------- - -target_link_libraries(deepx_mps - PUBLIC - deepx_common - ${YAMLCPP_LIB} - ${METAL_FRAMEWORK} - ${MPS_FRAMEWORK} - ${FOUNDATION_FRAMEWORK} -) - -add_executable(${PROJECT_NAME} ${CLIENT_SOURCES}) - -target_link_libraries(${PROJECT_NAME} - PRIVATE - deepx_mps -) - -# 测试 -add_subdirectory(test/tensorfunc) - diff --git a/executor/op-mem-mps/README.md b/executor/op-mem-mps/README.md deleted file mode 100644 index 39df974a..00000000 --- a/executor/op-mem-mps/README.md +++ /dev/null @@ -1,21 +0,0 @@ -# op-mem-mps - -基于 Apple Metal / MPS 的 DeepX 执行器原型(macOS)。 - -## 依赖 - -- macOS -- Xcode Command Line Tools -- CMake 3.15+ - -## 构建 - -```bash -./build.sh -``` - -## 运行 - -```bash -./build/deepx-executor-mps -``` diff --git a/executor/op-mem-mps/build.sh b/executor/op-mem-mps/build.sh deleted file mode 100755 index 415fdc3a..00000000 --- a/executor/op-mem-mps/build.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -mkdir -p build -cd build -rm -rf ./* -cmake .. -cmake --build . 
-j$(sysctl -n hw.ncpu) diff --git a/executor/op-mem-mps/src/client/main.mm b/executor/op-mem-mps/src/client/main.mm deleted file mode 100644 index 2a16c768..00000000 --- a/executor/op-mem-mps/src/client/main.mm +++ /dev/null @@ -1,15 +0,0 @@ -#include -#include "deepx/mps_context.hpp" -#include "deepx/mps_device.hpp" - -int main() -{ - deepx::mps::MpsContext ctx; - auto info = deepx::mps::get_default_device_info(); - - std::cout << "MPS device: " << info.name - << " (available=" << (info.supports_mps ? "true" : "false") << ")\n"; - std::cout << "Context valid: " << (ctx.is_valid() ? "true" : "false") << "\n"; - - return ctx.is_valid() ? 0 : 1; -} diff --git a/executor/op-mem-mps/src/deepx/tensorfunc/elementwise_common.hpp b/executor/op-mem-mps/src/deepx/tensorfunc/elementwise_common.hpp deleted file mode 100644 index ce03f0b8..00000000 --- a/executor/op-mem-mps/src/deepx/tensorfunc/elementwise_common.hpp +++ /dev/null @@ -1,94 +0,0 @@ -#ifndef DEEPX_TENSORFUNC_ELEMENTWISE_COMMON_HPP -#define DEEPX_TENSORFUNC_ELEMENTWISE_COMMON_HPP - -#if defined(__APPLE__) - #include -#endif - -#include -#include -#include - -#include "deepx/tensor.hpp" -#include "deepx/tensorfunc/mps_common.hpp" - -namespace deepx::tensorfunc::detail -{ - template - inline void assert_same_shape(const Tensor &A, const Tensor &B, const Tensor &C) - { - if (A.shape.size != B.shape.size || A.shape.size != C.shape.size || - A.shape.shape != B.shape.shape || A.shape.shape != C.shape.shape) - { - throw std::invalid_argument("shape mismatch"); - } - } - - template - inline void add_cpu(const Tensor &A, const Tensor &B, Tensor &C) - { - for (int64_t i = 0; i < A.shape.size; ++i) - { - C.data[i] = A.data[i] + B.data[i]; - } - } -} - -namespace deepx::mps::kernels -{ -#if defined(__APPLE__) && TARGET_OS_OSX && defined(__OBJC__) - inline deepx::mps::common::MetalKernelRuntime &elementwise_runtime() - { - static deepx::mps::common::MetalKernelRuntime rt("elementwise_miaobyte.metal"); - return rt; - } - - 
inline bool add_f32(const float *a, const float *b, float *c, int64_t n) - { - return elementwise_runtime().dispatch_binary_1d("add_f32", a, b, c, static_cast(n), sizeof(float)); - } - -#if defined(__FLT16_MANT_DIG__) - inline bool add_f16(const _Float16 *a, const _Float16 *b, _Float16 *c, int64_t n) - { - return elementwise_runtime().dispatch_binary_1d("add_f16", a, b, c, static_cast(n), 2); - } -#endif - - inline bool add_i8(const int8_t *a, const int8_t *b, int8_t *c, int64_t n) - { - return elementwise_runtime().dispatch_binary_1d("add_i8", a, b, c, static_cast(n), sizeof(int8_t)); - } - - inline bool add_i16(const int16_t *a, const int16_t *b, int16_t *c, int64_t n) - { - return elementwise_runtime().dispatch_binary_1d("add_i16", a, b, c, static_cast(n), sizeof(int16_t)); - } - - inline bool add_i32(const int32_t *a, const int32_t *b, int32_t *c, int64_t n) - { - return elementwise_runtime().dispatch_binary_1d("add_i32", a, b, c, static_cast(n), sizeof(int32_t)); - } - - inline bool add_i64(const int64_t *a, const int64_t *b, int64_t *c, int64_t n) - { - return elementwise_runtime().dispatch_binary_1d("add_i64", a, b, c, static_cast(n), sizeof(int64_t)); - } - -#else - - inline bool add_f32(const float *, const float *, float *, int64_t) { return false; } - -#if defined(__FLT16_MANT_DIG__) - inline bool add_f16(const _Float16 *, const _Float16 *, _Float16 *, int64_t) { return false; } -#endif - - inline bool add_i8(const int8_t *, const int8_t *, int8_t *, int64_t) { return false; } - inline bool add_i16(const int16_t *, const int16_t *, int16_t *, int64_t) { return false; } - inline bool add_i32(const int32_t *, const int32_t *, int32_t *, int64_t) { return false; } - inline bool add_i64(const int64_t *, const int64_t *, int64_t *, int64_t) { return false; } - -#endif -} // namespace deepx::mps::kernels - -#endif // DEEPX_TENSORFUNC_ELEMENTWISE_COMMON_HPP diff --git a/executor/op-mem-mps/src/deepx/tensorfunc/elementwise_miaobyte.hpp 
b/executor/op-mem-mps/src/deepx/tensorfunc/elementwise_miaobyte.hpp deleted file mode 100644 index ecc2057c..00000000 --- a/executor/op-mem-mps/src/deepx/tensorfunc/elementwise_miaobyte.hpp +++ /dev/null @@ -1,62 +0,0 @@ -#ifndef DEEPX_TENSORFUNC_ELEMENTWISE_MIAOBYTE_HPP -#define DEEPX_TENSORFUNC_ELEMENTWISE_MIAOBYTE_HPP - -#include -#include - -#include "deepx/tensor.hpp" -#include "deepx/tensorfunc/authors.hpp" -#include "deepx/tensorfunc/elementwise_common.hpp" -#include "deepx/tensorfunc/elementwise.hpp" -namespace deepx::tensorfunc -{ - template - struct addDispatcher - { - static void add(const Tensor &A, const Tensor &B, Tensor &C) - { - detail::assert_same_shape(A, B, C); - -#if defined(__APPLE__) && TARGET_OS_OSX - // Try Metal path for supported dtypes. Current tensors are host-backed, - // so this does staging copies (correctness-first). If Metal is unavailable, - // fall back to the CPU implementation below. - bool ok = false; - if constexpr (std::is_same_v) - { - ok = deepx::mps::kernels::add_f32(A.data, B.data, C.data, A.shape.size); - } -#if defined(__FLT16_MANT_DIG__) - else if constexpr (std::is_same_v) - { - ok = deepx::mps::kernels::add_f16(A.data, B.data, C.data, A.shape.size); - } -#endif - else if constexpr (std::is_same_v) - { - ok = deepx::mps::kernels::add_i8(A.data, B.data, C.data, A.shape.size); - } - else if constexpr (std::is_same_v) - { - ok = deepx::mps::kernels::add_i16(A.data, B.data, C.data, A.shape.size); - } - else if constexpr (std::is_same_v) - { - ok = deepx::mps::kernels::add_i32(A.data, B.data, C.data, A.shape.size); - } - else if constexpr (std::is_same_v) - { - ok = deepx::mps::kernels::add_i64(A.data, B.data, C.data, A.shape.size); - } - - if (ok) - { - return; - } -#endif - detail::add_cpu(A, B, C); - } - }; -} - -#endif // DEEPX_TENSORFUNC_ELEMENTWISE_MIAOBYTE_HPP diff --git a/executor/op-mem-mps/src/deepx/tensorfunc/elementwise_miaobyte.metal b/executor/op-mem-mps/src/deepx/tensorfunc/elementwise_miaobyte.metal deleted 
file mode 100644 index bee51557..00000000 --- a/executor/op-mem-mps/src/deepx/tensorfunc/elementwise_miaobyte.metal +++ /dev/null @@ -1,58 +0,0 @@ -#include -using namespace metal; - -// miaobyte elementwise add kernels (specialized per dtype) - -kernel void add_f32(device const float* A [[buffer(0)]], - device const float* B [[buffer(1)]], - device float* C [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint gid [[thread_position_in_grid]]) -{ - if (gid < n) { C[gid] = A[gid] + B[gid]; } -} - -kernel void add_f16(device const half* A [[buffer(0)]], - device const half* B [[buffer(1)]], - device half* C [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint gid [[thread_position_in_grid]]) -{ - if (gid < n) { C[gid] = A[gid] + B[gid]; } -} - -kernel void add_i8(device const char* A [[buffer(0)]], - device const char* B [[buffer(1)]], - device char* C [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint gid [[thread_position_in_grid]]) -{ - if (gid < n) { C[gid] = (char)(A[gid] + B[gid]); } -} - -kernel void add_i16(device const short* A [[buffer(0)]], - device const short* B [[buffer(1)]], - device short* C [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint gid [[thread_position_in_grid]]) -{ - if (gid < n) { C[gid] = (short)(A[gid] + B[gid]); } -} - -kernel void add_i32(device const int* A [[buffer(0)]], - device const int* B [[buffer(1)]], - device int* C [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint gid [[thread_position_in_grid]]) -{ - if (gid < n) { C[gid] = A[gid] + B[gid]; } -} - -kernel void add_i64(device const long* A [[buffer(0)]], - device const long* B [[buffer(1)]], - device long* C [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint gid [[thread_position_in_grid]]) -{ - if (gid < n) { C[gid] = A[gid] + B[gid]; } -} diff --git a/executor/op-mem-mps/src/deepx/tensorfunc/mps_common.hpp b/executor/op-mem-mps/src/deepx/tensorfunc/mps_common.hpp deleted file mode 100644 index 43a2148d..00000000 --- 
a/executor/op-mem-mps/src/deepx/tensorfunc/mps_common.hpp +++ /dev/null @@ -1,192 +0,0 @@ -#ifndef DEEPX_TENSORFUNC_MPS_COMMON_HPP -#define DEEPX_TENSORFUNC_MPS_COMMON_HPP - -#if defined(__APPLE__) - #include -#endif - -#include -#include -#include -#include -#include -#include - -#if defined(__APPLE__) && TARGET_OS_OSX - #if defined(__OBJC__) - #import - #import - #endif -#endif - -namespace deepx::mps::common -{ -#if defined(__APPLE__) && TARGET_OS_OSX && defined(__OBJC__) - class MetalKernelRuntime final - { - public: - MetalKernelRuntime(const char *metal_source_basename) - : metal_source_basename_(metal_source_basename ? metal_source_basename : "") - { - device_ = MTLCreateSystemDefaultDevice(); - queue_ = device_ ? [device_ newCommandQueue] : nil; - } - - bool valid() const { return device_ != nil && queue_ != nil; } - - bool dispatch_binary_1d(const char *kernel_fn, - const void *a, - const void *b, - void *c, - uint32_t n, - size_t elem_bytes) - { - if (!valid() || !kernel_fn) - { - return false; - } - - @autoreleasepool - { - NSError *error = nil; - id pso = pipeline(kernel_fn, &error); - if (!pso) - { - return false; - } - - const size_t bytes = static_cast(n) * elem_bytes; - id bufA = [device_ newBufferWithBytes:a length:bytes options:MTLResourceStorageModeShared]; - id bufB = [device_ newBufferWithBytes:b length:bytes options:MTLResourceStorageModeShared]; - id bufC = [device_ newBufferWithLength:bytes options:MTLResourceStorageModeShared]; - id bufN = [device_ newBufferWithBytes:&n length:sizeof(n) options:MTLResourceStorageModeShared]; - if (!bufA || !bufB || !bufC || !bufN) - { - return false; - } - - id cmd = [queue_ commandBuffer]; - id enc = [cmd computeCommandEncoder]; - [enc setComputePipelineState:pso]; - [enc setBuffer:bufA offset:0 atIndex:0]; - [enc setBuffer:bufB offset:0 atIndex:1]; - [enc setBuffer:bufC offset:0 atIndex:2]; - [enc setBuffer:bufN offset:0 atIndex:3]; - - const NSUInteger w = pso.maxTotalThreadsPerThreadgroup; - const 
MTLSize threadsPerThreadgroup = MTLSizeMake(w, 1, 1); - const MTLSize threadsPerGrid = MTLSizeMake(n, 1, 1); - [enc dispatchThreads:threadsPerGrid threadsPerThreadgroup:threadsPerThreadgroup]; - [enc endEncoding]; - [cmd commit]; - [cmd waitUntilCompleted]; - - std::memcpy(c, [bufC contents], bytes); - return true; - } - } - - private: - id library(NSError **error) - { - if (library_) - { - return library_; - } - - NSFileManager *fm = [NSFileManager defaultManager]; - - // Offline metallib (built into build dir as default.metallib) - NSArray *libCandidates = @[ - @"default.metallib", - @"build/default.metallib", - @"executor/op-mem-mps/build/default.metallib", - ]; - - for (NSString *rel in libCandidates) - { - NSString *path = [[fm currentDirectoryPath] stringByAppendingPathComponent:rel]; - if (![fm fileExistsAtPath:path]) - { - continue; - } - NSURL *url = [NSURL fileURLWithPath:path]; - library_ = [device_ newLibraryWithURL:url error:error]; - if (library_) - { - return library_; - } - } - - // Fallback: compile from .metal source at runtime. - // We try both the basename and the common repo-relative paths. - const std::string src = metal_source_basename_.empty() ? 
std::string("elementwise_miaobyte.metal") : metal_source_basename_; - NSArray *srcCandidates = @[ - [NSString stringWithUTF8String:src.c_str()], - [@"src/deepx/tensorfunc" stringByAppendingPathComponent:[NSString stringWithUTF8String:src.c_str()]], - [@"executor/op-mem-mps/src/deepx/tensorfunc" stringByAppendingPathComponent:[NSString stringWithUTF8String:src.c_str()]], - ]; - - for (NSString *rel in srcCandidates) - { - NSString *path = [[fm currentDirectoryPath] stringByAppendingPathComponent:rel]; - if (![fm fileExistsAtPath:path]) - { - continue; - } - NSError *readErr = nil; - NSString *metalSource = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&readErr]; - if (!metalSource) - { - continue; - } - MTLCompileOptions *opts = [MTLCompileOptions new]; - library_ = [device_ newLibraryWithSource:metalSource options:opts error:error]; - if (library_) - { - return library_; - } - } - - return nil; - } - - id pipeline(const char *kernel_fn, NSError **error) - { - NSString *fnName = [NSString stringWithUTF8String:kernel_fn]; - auto it = pipeline_cache_.find(fnName); - if (it != pipeline_cache_.end()) - { - return it->second; - } - - id lib = library(error); - if (!lib) - { - return nil; - } - - id fn = [lib newFunctionWithName:fnName]; - if (!fn) - { - return nil; - } - - id pso = [device_ newComputePipelineStateWithFunction:fn error:error]; - if (pso) - { - pipeline_cache_.emplace(fnName, pso); - } - return pso; - } - - std::string metal_source_basename_; - id device_ = nil; - id queue_ = nil; - id library_ = nil; - std::unordered_map> pipeline_cache_; - }; -#endif -} // namespace deepx::mps::common - -#endif // DEEPX_TENSORFUNC_MPS_COMMON_HPP diff --git a/executor/op-mem-mps/swift.md b/executor/op-mem-mps/swift.md deleted file mode 100644 index 68fec32c..00000000 --- a/executor/op-mem-mps/swift.md +++ /dev/null @@ -1,13 +0,0 @@ -# op-mem-mps - -macOS 上的 MPS/Metal 执行器工程已初始化,当前提供最小可运行骨架与设备探测。 - -## 目录 - -- CMake 配置与构建脚本 -- MPS 
设备与上下文初始化(Objective-C++) -- 最小示例入口 - -## 入口 - -- 构建与运行说明见 README \ No newline at end of file diff --git a/executor/op-mem-mps/test/tensorfunc/1_fill.cpp b/executor/op-mem-mps/test/tensorfunc/1_fill.cpp deleted file mode 100644 index 1acc10e0..00000000 --- a/executor/op-mem-mps/test/tensorfunc/1_fill.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include -#include - -#include "deepx/tensor.hpp" -#include "deepx/dtype.hpp" -#include "deepx/tensorfunc/authors.hpp" -#include "deepx/tensorfunc/tensorlife_miaobyte.hpp" -#include "deepx/tensorfunc/init_miaobyte.hpp" -#include "deepx/tensorfunc/io_miaobyte.hpp" - -using namespace deepx; -using namespace deepx::tensorfunc; - -int main() -{ - auto t = New({2, 3}); - constant(t, 3.5f); - print(t, "%2.3f"); - - return 0; -} diff --git a/executor/op-mem-mps/test/tensorfunc/2_add.cpp b/executor/op-mem-mps/test/tensorfunc/2_add.cpp deleted file mode 100644 index 9c1ff4d0..00000000 --- a/executor/op-mem-mps/test/tensorfunc/2_add.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#include -#include - -#include "deepx/tensor.hpp" -#include "deepx/dtype.hpp" -#include "deepx/tensorfunc/authors.hpp" -#include "deepx/tensorfunc/tensorlife_miaobyte.hpp" -#include "deepx/tensorfunc/init_miaobyte.hpp" -#include "deepx/tensorfunc/elementwise_miaobyte.hpp" -#include "deepx/tensorfunc/io_miaobyte.hpp" - -using namespace deepx; -using namespace deepx::tensorfunc; - -int main() -{ - auto a = New({2, 2}); - auto b = New({2, 2}); - auto c = New({2, 2}); - - constant(a, 1.25f); - constant(b, 2.5f); - add(a, b, c); - print(c, "%2.3f"); - return 0; -} diff --git a/executor/op-mem-mps/test/tensorfunc/CMakeLists.txt b/executor/op-mem-mps/test/tensorfunc/CMakeLists.txt deleted file mode 100644 index 902027fc..00000000 --- a/executor/op-mem-mps/test/tensorfunc/CMakeLists.txt +++ /dev/null @@ -1,10 +0,0 @@ -add_executable(1_fill 1_fill.cpp) -target_link_libraries(1_fill deepx_mps) - -add_executable(2_add 2_add.cpp) -target_link_libraries(2_add deepx_mps) - -if(APPLE) - 
set_source_files_properties(1_fill.cpp PROPERTIES LANGUAGE OBJCXX) - set_source_files_properties(2_add.cpp PROPERTIES LANGUAGE OBJCXX) -endif() diff --git a/executor/op-mem-mps/vibe.md b/executor/op-mem-mps/vibe.md deleted file mode 100644 index 0634e967..00000000 --- a/executor/op-mem-mps/vibe.md +++ /dev/null @@ -1,132 +0,0 @@ -# op-mem-mps 设计方案(Vibe) - -> 目标:为 macOS / Apple Silicon 提供基于 Metal Performance Shaders (MPS) 的执行器,实现与现有执行器一致的接口与行为。 - -## 1. 目标与范围 - -### 1.1 目标 -- 提供 MPS 后端执行器(op-mem-mps),对接 deepx 的 Tensor / TF / Op 体系。 -- 兼容现有网络通信与任务调度流程(UDP server + TF factory)。 -- 最小可用:支持设备探测、内存分配、张量生命周期与少量算子(init/io/elementwise)。 - -### 1.2 非目标 -- 不包含跨平台抽象层的重构。 -- 不覆盖全部算子,一期仅实现核心子集。 - ---- - -## 2. 总体架构 - -### 2.1 模块划分 -- client:入口、网络服务、TF 注册与调度 -- mem:MPS 设备/上下文与缓冲区管理 -- tensorfunc:算子实现(以作者/精度为维度) -- tf:TF 封装与参数绑定 - -### 2.2 关键组件 -1. MPSDevice - - 枚举与选择 MTLDevice - - 兼容 Apple Silicon 与 Intel + AMD(若支持) - -2. MPSContext - - 维护 MTLCommandQueue - - 统一的 command buffer 生命周期 - -3. MPSBuffer / MemBase - - 统一的 Tensor 内存管理 - - 与 deepx Tensor 对接 - ---- - -## 3. 数据流与执行流程 - -front(py) -> UDP -> TFFactory -> TF.run -> tensorfunc -> MPS -> output - -- TF 负责参数解析与调度 -- tensorfunc 负责具体算子 -- mem 负责存储与同步 - ---- - -## 4. 目录与代码组织(建议) - -executor/op-mem-mps/ - src/ - client/ - main.mm - tfs.cpp - deepx/ - mem/ - tf/ - tensorfunc/ - mps_device.{hpp,mm} - mps_context.{hpp,mm} - ---- - -## 5. API 设计与契约 - -### 5.1 与现有 executor 一致 -- register_all(TfFactory&) -- TF::run(shared_ptr, string &error) -- MemBase::gettensor() - -### 5.2 MPS 约束 -- 必须在 command buffer 提交后同步读回 -- 统一使用 MTLStorageModeShared 以便 CPU 读取(一期) - ---- - -## 6. 优先级路线图(MVP -> v1) - -### MVP -- 设备探测 + MPSContext -- 张量生命周期 (new, delete) -- init (ones, zeros, arange) -- elementwise (add, mul) - -### v1 -- matmul -- reduce -- changeshape - ---- - -## 7. 构建与依赖 - -- CMake + Objective-C++ -- 依赖: - - Metal / MetalPerformanceShaders - - yaml-cpp(保持与其他执行器一致) - ---- - -## 8. 
风险与约束 - -- MPS 对部分精度支持有限(如 int8) -- Metal buffer 与 CPU 共享模式性能有限 -- 异步执行与同步点控制复杂 - ---- - -## 9. 里程碑 - -| 阶段 | 内容 | 时间 | -|------|------|------| -| MVP | 设备+内存+基础算子 | 2 周 | -| v1 | 常用算子覆盖 | 4 周 | - ---- - -## 10. 验证策略 - -- 对比 op-mem-ompsimd 输出 -- 单元测试 + 前端 examples 回归 - ---- - -## 11. 需要确认的问题 - -1. 目标最小支持的算子集? -2. 是否允许 Metal shader 自定义 kernel? -3. MPSGraph 是否允许用于算子拼接? diff --git a/executor/op-metal/CLAUDE.md b/executor/op-metal/CLAUDE.md new file mode 100644 index 00000000..9277ee9f --- /dev/null +++ b/executor/op-metal/CLAUDE.md @@ -0,0 +1,179 @@ +# op-metal 开发约束 + +> op-metal 的职责边界。哪些能做,哪些**绝对不能碰**。 + +--- + +## 1. op-metal 是什么 + +在 DeepX 元程 5 核架构中,op-metal 是**计算平面**的 Metal GPU 实现。 + +它只做一件事:**被动消费指令 → 执行 GPU kernel → 通知完成**。 + +op-metal 是"无状态的 GPU 函数调用器"——它不关心数据从哪来、到哪去、含义是什么。 + +> **I/O 操作 (print/save/load) 已迁移到 `io-metal` 独立进程。** op-metal 只处理纯 GPU 计算。 + +--- + +## 2. 允许做的事(白名单) + +| 操作 | 允许 | 说明 | +|------|------|------| +| BLPOP/RPOP `cmd:op-metal:*` | ✅ | 消费 VM 发来的计算指令 | +| GET Redis 获取 tensor 元信息 | ✅ | 仅限 inputs/outputs 的 dtype/shape/shm_name/byte_size | +| shm_open + mmap 映射 tensor 内存 | ✅ | 根据 shm_name 获取 CPU 指针 | +| 执行 Metal GPU kernel | ✅ | add/sub/mul/div/relu/... (不包含 I/O) | +| newBufferWithBytes/newBufferWithLength | ✅ | 将 CPU 数据拷贝到 GPU buffer | +| LPUSH `done:` 通知完成 | ✅ | 格式: {pc, status:"ok"\|"error", error?} | +| print/save/load 等 I/O 操作 | ❌ | 已迁移到 io-metal 进程 (`cmd:io-metal:*`) | +| SET `/sys/op-plat/op-metal:0` | ✅ | 启动时注册进程状态 | +| SET NX `/op/op-metal/*` | ✅ | 首次启动时注册算子列表 | +| DELETE `/sys/op-plat/op-metal:0` | ✅ | 退出时注销 | + +--- + +## 3. 
禁止做的事(黑名单) + +### 3.1 绝对禁止:理解数据语义 + +> op-metal 的执行单元是 **tensor 指针 + 元素数量 + dtype**。 +> 它不知道也不应该知道:这个 tensor 叫什么名字、属于哪个 vthread、是不是中间结果。 + +| 操作 | 禁止 | 原因 | +|------|------|------| +| 检查 tensor 的内容是否正确 | ❌ | 这是 VM / 测试层的职责 | +| 打印 tensor 的采样数据 | ❌ | 数据验证不属于 op-plat | +| 比较不同 tensor 的值 | ❌ | op-plat 只做计算,不做断言 | +| 知道 tensor 的 key 名称含义 | ❌ | key 只是用来从 Redis 获取元信息 | + +### 3.2 绝对禁止:格式化和展示 + +| 操作 | 禁止 | 原因 | +|------|------|------| +| 格式化输出结果数据(如 `[6, 8, 10, 12]`) | ❌ | 数据展示属于 VM/pysdk/deepxctl | +| 打印树形结构、表格、进度条 | ❌ | op-plat 是静默的计算后端 | +| 对执行结果做"美观"输出 | ❌ | 只有 done 通知是 op-plat 的输出 | + +### 3.3 禁止:性能统计渗透到业务代码 + +| 操作 | 禁止 | 原因 | +|------|------|------| +| 在 execute_task 中加 chrono 计时 | ❌ | 性能统计应通过 profiling 工具(Instruments) | +| 分阶段计时(shm open / dispatch / notify) | ❌ | 同上 | +| 打印 "xxx ms" 到 stdout | ❌ | 不可与计算输出混在一起 | +| 累计统计(平均延迟、吞吐量) | ❌ | 应通过 `/sys/op-plat/` 的 metrics 字段上报 | + +### 3.4 禁止:越权修改其他组件 + +| 操作 | 禁止 | 原因 | +|------|------|------| +| 修改 VM 的 vthread 状态 | ❌ | `/vthread/*` 是 VM 的私有空间 | +| 修改 heap-plat 的分配记录 | ❌ | heap-plat 管理 `/heap/*` | +| 消费 `done:*` 队列 | ❌ | done 是 VM 消费的 | +| 生产 `cmd:*` 队列消息 | ❌ | cmd 队列由 VM 生产 | +| 修改 build.sh / CMakeLists.txt | ❌ | 构建系统属于项目配置,非日常开发 | + +### 3.5 禁止:引入调试代码 + +| 操作 | 禁止 | 原因 | +|------|------|------| +| fprintf(stderr, ...) 逐次调用诊断 | ❌ | 高频调用会产生大量日志 | +| std::cerr 打印每次 dispatch 的参数 | ❌ | 同上 | +| 用 `#ifdef DEBUG` 包裹大量调试输出 | ❌ | 污染主流程可读性 | +| **唯一例外**: 启动时的进程状态日志(`[op-metal] CWD/device/connected/listening`) | ✅ | 一次性输出,帮助排查进程生命周期问题 | + +--- + +## 4. 
op-metal 的通信边界 + +``` +deepxctl VM op-metal heap-metal + │ │ │ │ + │ ── SET /vthread/1 ──→│ │ │ + │ ── LPUSH notify:vm ─→│ │ │ + │ │── PUSH cmd:op-metal:0 ─→│ │ + │ │ │── GET /data/x ────→ Redis + │ │ │←── {shm_name,...} ── Redis + │ │ │── shm_open("/deepx_t_xxx") → kernel + │ │ │── GPU compute add_f32 + │ │ │── LPUSH done:1 ────→ Redis + │ │←─ BLPOP done:1 ────── Redis + │ │── PC++ 继续 │ + │←── GET /vthread/1 ───│ │ + │ status=done │ │ +``` + +**op-metal 的边界:** +- 入: `cmd:op-metal:*` 队列 + Redis GET(tensor 元信息) +- 出: `done:` 队列 +- 内部: shm 映射 → GPU 计算 → 返回 + +--- + +## 5. execute_task 的标准结构 + +```cpp +static void execute_task(redisContext *redis, const json &task) { + // 1. 解析指令 + std::string opcode = task["opcode"]; + std::string vtid = task["vtid"]; + std::string pc = task["pc"]; + + // 2. 解析 inputs → 映射 shm → 获取 GPU 指针 + std::vector input_ptrs; + for (auto &in : task["inputs"]) { + auto meta = fetch_tensor_meta(redis, in["key"]); + auto shm = shm_open_readwrite(meta.shm_name, meta.byte_size); + input_ptrs.push_back(shm.addr); + } + + // 3. 解析 output → 映射 shm + auto out_meta = fetch_tensor_meta(redis, task["outputs"][0]["key"]); + auto out_shm = shm_open_readwrite(out_meta.shm_name, out_meta.byte_size); + + // 4. GPU dispatch + bool ok = dispatch_binary(opcode, dtype, input_ptrs[0], input_ptrs[1], out_shm.addr, n); + + // 5. 清理 + 通知 + shm_close_all(); + if (ok) notify_done(redis, vtid, pc, "ok"); + else notify_done(redis, vtid, pc, "error", "..."); +} +``` + +**禁止在此结构中添加:** +- `auto t0 = chrono::now()` / 分阶段计时 +- `std::cout << result[0:4]` / 数据采样打印 +- `[op-metal] ┌─` 树形格式化输出 +- 任何与计算无关的逻辑 + +--- + +## 6. 允许的日志输出 + +| 场景 | 允许 | +|------|------| +| 启动: 设备名、Redis 地址、监听队列 | ✅ 一次性 `std::cout` | +| 启动: CWD、进程 PID | ✅ 一次性 | +| 致命错误: Metal 设备不可用、Redis 连接失败 | ✅ `std::cerr` + `return 1` | +| 退出: shutdown complete | ✅ 一次性 | +| 每次 dispatch: opcode / dtype / n / 耗时 | ❌ 高频冗余 | +| 每次 dispatch: 结果采样数据 | ❌ 越界 | +| shm 操作成功/失败 | ❌ 成功不输出,失败走 error 通知 | + +--- + +## 7. 
与 deepxctl 的关系 + +| 谁做什么 | deepxctl | op-metal | +|---------|----------|----------| +| 启动 op-metal 进程 | ✅ | — | +| 注册 Metal 设备 | — | ✅ | +| 连接 Redis | — | ✅ | +| 消费计算指令 | — | ✅ | +| 执行 GPU kernel | — | ✅ | +| 验证计算结果 | ✅ (通过轮询 vthread status) | ❌ | +| 展示 tensor 数据 | ✅ (通过 deepxctl verbose) | ❌ | +| 性能分析 | ✅ (外部工具) | ❌ | +| 清理子进程 | ✅ | — | diff --git a/executor/op-metal/CMakeLists.txt b/executor/op-metal/CMakeLists.txt new file mode 100644 index 00000000..c0fc2098 --- /dev/null +++ b/executor/op-metal/CMakeLists.txt @@ -0,0 +1,77 @@ +cmake_minimum_required(VERSION 3.15) +project(deepx-op-metal LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED True) +set(CMAKE_BUILD_TYPE Debug) +set(CMAKE_OSX_DEPLOYMENT_TARGET "12.0") + +if(NOT APPLE) + message(FATAL_ERROR "op-metal only supports macOS") +endif() + +include_directories(src) +include_directories(../old-cppcommon) +include_directories(../common-metal/include) + +# hiredis (Redis C client) +include_directories(/opt/homebrew/opt/hiredis/include) +link_directories(/opt/homebrew/opt/hiredis/lib) + +# nlohmann/json (header-only JSON parser) +include_directories(/opt/homebrew/opt/nlohmann-json/include) + +# dxlang (for deepx/dtype/precision.hpp etc.) 
+include_directories(../dxlang/src) + +# 依赖 common-metal 公共库 +if(NOT TARGET deepx_common_metal) + add_subdirectory(../common-metal common-metal) +endif() + +# 源文件 +file(GLOB_RECURSE DEEPX_SOURCES "src/deepx/*.cpp" "src/deepx/*.hpp" "src/deepx/*.h") +file(GLOB_RECURSE CLIENT_SOURCES "src/client/*.cpp") + +find_library(METAL Metal REQUIRED) +find_library(FOUNDATION Foundation REQUIRED) + +add_library(deepx_metal SHARED ${DEEPX_SOURCES}) +target_link_libraries(deepx_metal + PUBLIC deepx_common_metal ${METAL} ${FOUNDATION} +) +set_target_properties(deepx_metal PROPERTIES + XCODE_ATTRIBUTE_CLANG_ENABLE_OBJC_ARC YES +) + +# metal_context.cpp 包含 ObjC Metal API +set_source_files_properties(src/deepx/metal_context.cpp PROPERTIES COMPILE_FLAGS "-x objective-c++") + +# main.cpp 通过 elementwise_common.hpp 间接使用 Metal GPU kernel (__OBJC__ guard) +set_source_files_properties(src/client/main.cpp PROPERTIES COMPILE_FLAGS "-x objective-c++") + +# ── 编译期 .metal → .metallib(使用 Metal Framework API,无需 xcrun/Xcode)── +set(METAL_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/src/deepx/tensorfunc/elementwise_miaobyte.metal") +set(METALLIB_OUTPUT "${CMAKE_BINARY_DIR}/default.metallib") + +add_executable(compile_metal_lib src/build/compile_metal_lib.mm) +set_source_files_properties(src/build/compile_metal_lib.mm PROPERTIES COMPILE_FLAGS "-x objective-c++") +set_target_properties(compile_metal_lib PROPERTIES + XCODE_ATTRIBUTE_CLANG_ENABLE_OBJC_ARC YES +) +target_link_libraries(compile_metal_lib ${METAL} ${FOUNDATION}) + +add_custom_command( + OUTPUT ${METALLIB_OUTPUT} + COMMAND compile_metal_lib ${METAL_SOURCE} ${METALLIB_OUTPUT} + DEPENDS compile_metal_lib ${METAL_SOURCE} + COMMENT "[metal] build-time compiling ${METAL_SOURCE} → default.metallib" +) +add_custom_target(metal_shaders DEPENDS ${METALLIB_OUTPUT}) + +add_executable(${PROJECT_NAME} ${CLIENT_SOURCES}) +add_dependencies(${PROJECT_NAME} metal_shaders) +target_link_libraries(${PROJECT_NAME} PRIVATE deepx_metal hiredis) + +# 集成测试 
+add_subdirectory(test/shm) diff --git a/executor/op-metal/build.sh b/executor/op-metal/build.sh new file mode 100644 index 00000000..ab23b103 --- /dev/null +++ b/executor/op-metal/build.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +set -euo pipefail +DIR="$(cd "$(dirname "$0")" && pwd)" +BUILD_DIR="/tmp/deepx/op-metal/build" +mkdir -p "$BUILD_DIR" +cd "$BUILD_DIR" +cmake "$DIR" +cmake --build . -j$(sysctl -n hw.ncpu 2>/dev/null || nproc) +# Copy runtime dependencies (rpath: @rpath/libdeepx_metal.dylib) +cp -f libdeepx_metal.dylib default.metallib "$BUILD_DIR/" 2>/dev/null || true +echo "Built: $BUILD_DIR/deepx-op-metal" +echo "Test: $BUILD_DIR/test/shm/test_cross_process" diff --git a/executor/op-metal/src/client/main.cpp b/executor/op-metal/src/client/main.cpp new file mode 100644 index 00000000..6c8ebbb0 --- /dev/null +++ b/executor/op-metal/src/client/main.cpp @@ -0,0 +1,837 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "deepx/metal_device.hpp" +#include "deepx/tensorfunc/elementwise_common.hpp" + +using json = nlohmann::json; + +static const char *OP_QUEUE = "cmd:op-metal:0"; +static const char *SYS_QUEUE = "sys:cmd:op-metal:0"; +static const char *INSTANCE_KEY = "/sys/op-plat/op-metal:0"; +static const char *HEARTBEAT_KEY = "/sys/heartbeat/op-metal:0"; +static const int BLOCK_TIMEOUT_SEC = 5; +static const int HEARTBEAT_INTERVAL_SEC = 2; + +// ═══════════════════════════════════════════════════════════ +// Redis helpers +// ═══════════════════════════════════════════════════════════ + +static redisContext* connect_redis(const char *addr, int port) { + struct timeval tv = {2, 0}; + redisContext *c = redisConnectWithTimeout(addr, port, tv); + if (!c || c->err) { + std::cerr << "[op-metal] Redis connect failed: " << (c ? 
c->errstr : "null") << "\n"; + if (c) redisFree(c); + return nullptr; + } + return c; +} + +static redisReply* redis_cmd(redisContext *c, const char *fmt, ...) { + va_list ap; + va_start(ap, fmt); + redisReply *r = (redisReply *)redisvCommand(c, fmt, ap); + va_end(ap); + return r; +} + +#define REDIS_FREE(r) do { if (r) freeReplyObject(r); } while(0) + +static bool redis_set(redisContext *c, const std::string &key, const std::string &val) { + redisReply *r = redis_cmd(c, "SET %s %s", key.c_str(), val.c_str()); + bool ok = r && r->type == REDIS_REPLY_STATUS; + REDIS_FREE(r); + return ok; +} + +static void update_heartbeat(redisContext *c, const std::string &status) { + json hb; + hb["ts"] = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()).count(); + hb["status"] = status; + hb["pid"] = getpid(); + redis_set(c, HEARTBEAT_KEY, hb.dump()); +} + +static void register_instance(redisContext *c) { + json reg; + reg["program"] = "op-metal"; + reg["device"] = "gpu0"; + reg["status"] = "running"; + reg["load"] = 0.0; + reg["pid"] = getpid(); + reg["started_at"] = std::chrono::system_clock::now().time_since_epoch().count(); + redis_set(c, INSTANCE_KEY, reg.dump()); + std::cout << "[op-metal] registered at " << INSTANCE_KEY << "\n"; + + // ── 注册支持的算子列表 ── + redisReply *r = redis_cmd(c, "DEL %s", "/op/op-metal/list"); + REDIS_FREE(r); + + // elementwise binary (Metal GPU) + redis_cmd(c, "RPUSH %s %s %s %s %s %s %s", + "/op/op-metal/list", + "add", "sub", "mul", "div", "max", "min"); + // elementwise unary (Metal GPU) + redis_cmd(c, "RPUSH %s %s %s %s %s %s %s %s %s %s", + "/op/op-metal/list", + "relu", "neg", "abs", "sqrt", "exp", "log", "sin", "cos", "tan"); + // elementwise scalar + redis_cmd(c, "RPUSH %s %s %s %s %s %s %s %s %s", + "/op/op-metal/list", + "addscalar", "subscalar", "mulscalar", "divscalar", + "maxscalar", "minscalar", "pow", "powscalar"); + // comparison + redis_cmd(c, "RPUSH %s %s %s %s %s %s %s %s %s", + "/op/op-metal/list", 
+ "equal", "notequal", "less", "greater", + "equalscalar", "notequalscalar", "lessscalar", "greaterscalar"); + // changeshape + redis_cmd(c, "RPUSH %s %s %s %s %s %s %s", + "/op/op-metal/list", + "reshape", "transpose", "concat", "broadcastTo", "indexselect", "repeat"); + // reduce + redis_cmd(c, "RPUSH %s %s %s %s %s", + "/op/op-metal/list", + "sum", "prod", "reducemax", "reducemin"); + // io — migrated to io-metal (separate I/O plane) + // init + redis_cmd(c, "RPUSH %s %s %s", + "/op/op-metal/list", + "constant", "arange"); + // misc + redis_cmd(c, "RPUSH %s %s %s", + "/op/op-metal/list", + "invert", "todtype"); + + std::cout << "[op-metal] registered all ops\n"; +} + +static void notify_done(redisContext *c, const std::string &vtid, + const std::string &pc, const std::string &status, + const std::string &error_msg = "") { + json done; + done["pc"] = pc; + done["status"] = status; + if (!error_msg.empty()) { + done["error"] = {{"code", "OP_ERROR"}, {"message", error_msg}}; + } + std::string key = "done:" + vtid; + redisReply *r = redis_cmd(c, "LPUSH %s %s", key.c_str(), done.dump().c_str()); + if (!r || r->type == REDIS_REPLY_ERROR) { + std::cerr << "[op-metal] notify_done LPUSH failed for " << vtid << ": " << (r ? r->str : "NULL") << "\n"; + } + REDIS_FREE(r); + std::cout << "[op-metal] done " << vtid << " pc=" << pc << " status=" << status << "\n"; +} + +// ═══════════════════════════════════════════════════════════ +// shm helpers +// ═══════════════════════════════════════════════════════════ + +static size_t page_size() { + static long ps = sysconf(_SC_PAGESIZE); + return ps > 0 ? 
(size_t)ps : 16384; +} + +static size_t page_align(size_t n) { + size_t ps = page_size(); + return (n + ps - 1) & ~(ps - 1); +} + +struct ShmMapping { + std::string shm_name; + void *addr = nullptr; + size_t byte_size = 0; +}; + +static bool shm_open_readwrite(const std::string &name, size_t byte_size, ShmMapping &out) { + out.shm_name = name; + out.byte_size = byte_size; + + int fd = shm_open(name.c_str(), O_RDWR, 0600); + if (fd < 0) { + std::cerr << "[op-metal] shm_open failed: " << name << " (" << strerror(errno) << ")\n"; + return false; + } + + size_t aligned = page_align(byte_size); + void *addr = mmap(nullptr, aligned, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + close(fd); + if (addr == MAP_FAILED) { + std::cerr << "[op-metal] mmap failed: " << name << " (" << strerror(errno) << ")\n"; + return false; + } + + out.addr = addr; + return true; +} + +static void shm_close(ShmMapping &m) { + if (m.addr) { + munmap(m.addr, page_align(m.byte_size)); + m.addr = nullptr; + } +} + +// ═══════════════════════════════════════════════════════════ +// Tensor metadata from Redis +// ═══════════════════════════════════════════════════════════ + +struct TensorMeta { + std::string key; + std::string dtype; + std::vector shape_data; + std::string shm_name; + size_t byte_size = 0; + bool valid = false; +}; + +static TensorMeta fetch_tensor_meta(redisContext *c, const std::string &key) { + TensorMeta m; + m.key = key; + + redisReply *r = redis_cmd(c, "GET %s", key.c_str()); + if (!r || r->type != REDIS_REPLY_STRING) { + REDIS_FREE(r); + return m; + } + + try { + json meta = json::parse(r->str); + REDIS_FREE(r); + + if (meta.contains("dtype")) m.dtype = meta["dtype"].get(); + if (meta.contains("shape") && meta["shape"].is_array()) { + for (const auto &d : meta["shape"]) { + m.shape_data.push_back(d.get()); + } + } + if (meta.contains("byte_size")) m.byte_size = meta["byte_size"].get(); + if (meta.contains("address") && meta["address"].contains("shm_name")) { + m.shm_name = 
meta["address"]["shm_name"].get(); + } + m.valid = true; + } catch (const std::exception &e) { + REDIS_FREE(r); + std::cerr << "[op-metal] JSON parse error for tensor " << key << ": " << e.what() << "\n"; + } + + return m; +} + +// ═══════════════════════════════════════════════════════════ +// Convert TensorMeta shape_data to vector +// ═══════════════════════════════════════════════════════════ + +static std::vector meta_shape(const TensorMeta &m) { + std::vector s; + for (auto d : m.shape_data) s.push_back(static_cast(d)); + return s; +} + +// ═══════════════════════════════════════════════════════════ +// Kernel dispatch (Metal GPU) +// ═══════════════════════════════════════════════════════════ + +static inline int64_t element_count(const std::vector &shape) { + int64_t n = 1; + for (auto d : shape) n *= d; + return n; +} + +static bool dispatch_binary(const std::string &opcode, const std::string &dtype, + const void *a, const void *b, void *c, int64_t n) { + using namespace deepx::metal::kernels; + + if (dtype == "f32" || dtype == "float32") { + if (opcode == "add") return add_f32((const float*)a, (const float*)b, (float*)c, n); + if (opcode == "sub") return sub_f32((const float*)a, (const float*)b, (float*)c, n); + if (opcode == "mul") return mul_f32((const float*)a, (const float*)b, (float*)c, n); + if (opcode == "div") return div_f32((const float*)a, (const float*)b, (float*)c, n); + if (opcode == "max") return max_f32((const float*)a, (const float*)b, (float*)c, n); + if (opcode == "min") return min_f32((const float*)a, (const float*)b, (float*)c, n); + } + if (dtype == "i8" || dtype == "int8") { + if (opcode == "add") return add_i8((const int8_t*)a, (const int8_t*)b, (int8_t*)c, n); + if (opcode == "sub") return sub_i8((const int8_t*)a, (const int8_t*)b, (int8_t*)c, n); + if (opcode == "mul") return mul_i8((const int8_t*)a, (const int8_t*)b, (int8_t*)c, n); + if (opcode == "max") return max_i8((const int8_t*)a, (const int8_t*)b, (int8_t*)c, n); + if 
(opcode == "min") return min_i8((const int8_t*)a, (const int8_t*)b, (int8_t*)c, n); + } + if (dtype == "i16" || dtype == "int16") { + if (opcode == "add") return add_i16((const int16_t*)a, (const int16_t*)b, (int16_t*)c, n); + if (opcode == "sub") return sub_i16((const int16_t*)a, (const int16_t*)b, (int16_t*)c, n); + if (opcode == "mul") return mul_i16((const int16_t*)a, (const int16_t*)b, (int16_t*)c, n); + if (opcode == "max") return max_i16((const int16_t*)a, (const int16_t*)b, (int16_t*)c, n); + if (opcode == "min") return min_i16((const int16_t*)a, (const int16_t*)b, (int16_t*)c, n); + } + if (dtype == "i32" || dtype == "int32") { + if (opcode == "add") return add_i32((const int32_t*)a, (const int32_t*)b, (int32_t*)c, n); + if (opcode == "sub") return sub_i32((const int32_t*)a, (const int32_t*)b, (int32_t*)c, n); + if (opcode == "mul") return mul_i32((const int32_t*)a, (const int32_t*)b, (int32_t*)c, n); + if (opcode == "max") return max_i32((const int32_t*)a, (const int32_t*)b, (int32_t*)c, n); + if (opcode == "min") return min_i32((const int32_t*)a, (const int32_t*)b, (int32_t*)c, n); + } + if (dtype == "i64" || dtype == "int64") { + if (opcode == "add") return add_i64((const int64_t*)a, (const int64_t*)b, (int64_t*)c, n); + if (opcode == "sub") return sub_i64((const int64_t*)a, (const int64_t*)b, (int64_t*)c, n); + if (opcode == "mul") return mul_i64((const int64_t*)a, (const int64_t*)b, (int64_t*)c, n); + if (opcode == "max") return max_i64((const int64_t*)a, (const int64_t*)b, (int64_t*)c, n); + if (opcode == "min") return min_i64((const int64_t*)a, (const int64_t*)b, (int64_t*)c, n); + } + + return false; +} + +static bool dispatch_unary(const std::string &opcode, const std::string &dtype, + const void *x, void *y, int64_t n) { + using namespace deepx::metal::kernels; + + if (opcode == "relu") { + if (dtype == "f32" || dtype == "float32") return relu_f32((const float*)x, (float*)y, n); + if (dtype == "i8" || dtype == "int8") return relu_i8((const 
int8_t*)x, (int8_t*)y, n); + if (dtype == "i16" || dtype == "int16") return relu_i16((const int16_t*)x, (int16_t*)y, n); + if (dtype == "i32" || dtype == "int32") return relu_i32((const int32_t*)x, (int32_t*)y, n); + if (dtype == "i64" || dtype == "int64") return relu_i64((const int64_t*)x, (int64_t*)y, n); + } + if (opcode == "neg") { + if (dtype == "f32" || dtype == "float32") return neg_f32((const float*)x, (float*)y, n); + if (dtype == "i8" || dtype == "int8") return neg_i8((const int8_t*)x, (int8_t*)y, n); + if (dtype == "i16" || dtype == "int16") return neg_i16((const int16_t*)x, (int16_t*)y, n); + if (dtype == "i32" || dtype == "int32") return neg_i32((const int32_t*)x, (int32_t*)y, n); + if (dtype == "i64" || dtype == "int64") return neg_i64((const int64_t*)x, (int64_t*)y, n); + } + if (opcode == "abs") { + if (dtype == "f32" || dtype == "float32") return abs_f32((const float*)x, (float*)y, n); + if (dtype == "i8" || dtype == "int8") return abs_i8((const int8_t*)x, (int8_t*)y, n); + if (dtype == "i16" || dtype == "int16") return abs_i16((const int16_t*)x, (int16_t*)y, n); + if (dtype == "i32" || dtype == "int32") return abs_i32((const int32_t*)x, (int32_t*)y, n); + if (dtype == "i64" || dtype == "int64") return abs_i64((const int64_t*)x, (int64_t*)y, n); + } + + // f32-only unary ops + if (dtype == "f32" || dtype == "float32") { + if (opcode == "sqrt") return sqrt_f32((const float*)x, (float*)y, n); + if (opcode == "exp") return exp_f32((const float*)x, (float*)y, n); + if (opcode == "log") return log_f32((const float*)x, (float*)y, n); + if (opcode == "sin") return sin_f32((const float*)x, (float*)y, n); + if (opcode == "cos") return cos_f32((const float*)x, (float*)y, n); + if (opcode == "tan") return tan_f32((const float*)x, (float*)y, n); + } + + return false; +} + +// ═══════════════════════════════════════════════════════════ +// CPU fallback dispatch: elementwise (scalar / comparison) +// ═══════════════════════════════════════════════════════════ + 
+template +static void cpu_binary_scalar_op(const std::string &opcode, + T *a_data, T scalar, T *c_data, int64_t n) { + if (opcode == "addscalar") { for (int64_t i = 0; i < n; ++i) c_data[i] = a_data[i] + scalar; } + else if (opcode == "subscalar") { for (int64_t i = 0; i < n; ++i) c_data[i] = a_data[i] - scalar; } + else if (opcode == "mulscalar") { for (int64_t i = 0; i < n; ++i) c_data[i] = a_data[i] * scalar; } + else if (opcode == "divscalar") { for (int64_t i = 0; i < n; ++i) c_data[i] = a_data[i] / scalar; } + else if (opcode == "maxscalar") { for (int64_t i = 0; i < n; ++i) c_data[i] = std::max(a_data[i], scalar); } + else if (opcode == "minscalar") { for (int64_t i = 0; i < n; ++i) c_data[i] = std::min(a_data[i], scalar); } + else if (opcode == "powscalar") { for (int64_t i = 0; i < n; ++i) c_data[i] = std::pow(a_data[i], scalar); } + else if (opcode == "rsubscalar") { for (int64_t i = 0; i < n; ++i) c_data[i] = scalar - a_data[i]; } + else if (opcode == "rdivscalar") { for (int64_t i = 0; i < n; ++i) c_data[i] = scalar / a_data[i]; } + else if (opcode == "rpowscalar") { for (int64_t i = 0; i < n; ++i) c_data[i] = std::pow(scalar, a_data[i]); } +} + +template +static void cpu_comparison_op(const std::string &opcode, + const T *a, const T *b, bool *c, int64_t n) { + if (opcode == "equal") { for (int64_t i = 0; i < n; ++i) c[i] = (a[i] == b[i]); } + else if (opcode == "notequal") { for (int64_t i = 0; i < n; ++i) c[i] = (a[i] != b[i]); } + else if (opcode == "less") { for (int64_t i = 0; i < n; ++i) c[i] = (a[i] < b[i]); } + else if (opcode == "greater") { for (int64_t i = 0; i < n; ++i) c[i] = (a[i] > b[i]); } +} + +template +static void cpu_scalar_comparison_op(const std::string &opcode, + const T *a, T scalar, bool *c, int64_t n) { + if (opcode == "equalscalar") { for (int64_t i = 0; i < n; ++i) c[i] = (a[i] == scalar); } + else if (opcode == "notequalscalar") { for (int64_t i = 0; i < n; ++i) c[i] = (a[i] != scalar); } + else if (opcode == "lessscalar") 
{ for (int64_t i = 0; i < n; ++i) c[i] = (a[i] < scalar); } + else if (opcode == "greaterscalar") { for (int64_t i = 0; i < n; ++i) c[i] = (a[i] > scalar); } +} + +// ═══════════════════════════════════════════════════════════ +// Type dispatch (pick correct T based on dtype string) +// Uses type tags to avoid explicit template arg syntax on lambdas. +// Caller defines: auto Fn = [&](auto tag) { using T = typename decltype(tag)::type; ... }; +// ═══════════════════════════════════════════════════════════ + +template struct type_tag { using type = T; }; + +#define DISPATCH_BY_DTYPE(dtype, Fn) \ + do { \ + if (dtype == "f32" || dtype == "float32") Fn(type_tag{}); \ + else if (dtype == "i64" || dtype == "int64") Fn(type_tag{}); \ + else if (dtype == "i32" || dtype == "int32") Fn(type_tag{}); \ + else if (dtype == "i16" || dtype == "int16") Fn(type_tag{}); \ + else if (dtype == "i8" || dtype == "int8") Fn(type_tag{}); \ + else if (dtype == "bool") Fn(type_tag{}); \ + else { error = "unsupported dtype: " + dtype; return; } \ + } while(0) + +// ═══════════════════════════════════════════════════════════ +// Task execution +// ═══════════════════════════════════════════════════════════ + +static void execute_task(redisContext *redis, const json &task) { + std::string vtid = task.value("vtid", ""); + std::string pc = task.value("pc", ""); + std::string opcode = task.value("opcode", ""); + json params = task.value("params", json::object()); + + if (!task.contains("inputs") || !task.contains("outputs")) { + notify_done(redis, vtid, pc, "error", "missing inputs/outputs"); + return; + } + + const auto &inputs = task["inputs"]; + const auto &outputs = task["outputs"]; + + // IO ops (print/save/load) routed to io-metal — not handled here + // All remaining ops require both inputs and outputs + if (inputs.empty() || outputs.empty()) { + notify_done(redis, vtid, pc, "error", "empty inputs/outputs for compute op"); + return; + } + + // ── Resolve input tensors ── + std::vector 
input_metas; + std::vector input_shms; + std::vector input_ptrs; + + for (const auto &in : inputs) { + std::string key = in.value("key", ""); + if (key.empty()) { + notify_done(redis, vtid, pc, "error", "input missing key"); + return; + } + + TensorMeta meta = fetch_tensor_meta(redis, key); + if (!meta.valid) { + notify_done(redis, vtid, pc, "error", "input tensor not found: " + key); + return; + } + + ShmMapping shm; + if (!meta.shm_name.empty()) { + if (!shm_open_readwrite(meta.shm_name, meta.byte_size, shm)) { + notify_done(redis, vtid, pc, "error", "shm open failed: " + meta.shm_name); + return; + } + input_ptrs.push_back(shm.addr); + } else { + notify_done(redis, vtid, pc, "error", "input has no shm address: " + key); + return; + } + + input_metas.push_back(meta); + input_shms.push_back(shm); + } + + // ── Resolve output tensor (optional for IO ops like print/save) ── + std::string out_key; + TensorMeta out_meta; + ShmMapping out_shm; + bool has_output = !outputs.empty(); + int64_t n = 0; + std::string dtype = "f32"; + + if (has_output) { + const auto &out = outputs[0]; + out_key = out.value("key", ""); + out_meta = fetch_tensor_meta(redis, out_key); + if (!out_meta.valid) { + notify_done(redis, vtid, pc, "error", "output tensor not found: " + out_key); + for (auto &s : input_shms) shm_close(s); + return; + } + if (out_meta.valid && !out_meta.shm_name.empty()) { + if (!shm_open_readwrite(out_meta.shm_name, out_meta.byte_size, out_shm)) { + notify_done(redis, vtid, pc, "error", "output shm open failed: " + out_meta.shm_name); + for (auto &s : input_shms) shm_close(s); + return; + } + } + n = out_meta.valid ? element_count(out_meta.shape_data) : element_count(input_metas[0].shape_data); + dtype = out_meta.dtype.empty() ? (input_metas.empty() ? "f32" : input_metas[0].dtype) : out_meta.dtype; + } else { + // For ops without outputs, infer dtype from first input + if (!input_metas.empty()) { + dtype = input_metas[0].dtype.empty() ? 
"f32" : input_metas[0].dtype; + n = element_count(input_metas[0].shape_data); + } + } + + // ── Dispatch ── + bool ok = false; + std::string error; + + // ── elementwise binary (GPU Metal) ── + if (input_ptrs.size() == 2 && + (opcode == "add" || opcode == "sub" || opcode == "mul" || + opcode == "div" || opcode == "max" || opcode == "min")) { + ok = dispatch_binary(opcode, dtype, input_ptrs[0], input_ptrs[1], out_shm.addr, n); + if (!ok) error = "Metal binary kernel dispatch failed for " + opcode + ":" + dtype; + } + // ── elementwise unary (GPU Metal) ── + else if (input_ptrs.size() == 1 && + (opcode == "relu" || opcode == "neg" || opcode == "abs" || + opcode == "sqrt" || opcode == "exp" || opcode == "log" || + opcode == "sin" || opcode == "cos" || opcode == "tan")) { + ok = dispatch_unary(opcode, dtype, input_ptrs[0], out_shm.addr, n); + if (!ok) error = "Metal unary kernel dispatch failed for " + opcode + ":" + dtype; + } + // ── elementwise scalar (CPU) ── + else if (input_ptrs.size() == 1 && + (opcode == "addscalar" || opcode == "subscalar" || opcode == "mulscalar" || + opcode == "divscalar" || opcode == "maxscalar" || opcode == "minscalar" || + opcode == "powscalar" || opcode == "rsubscalar" || opcode == "rdivscalar" || + opcode == "rpowscalar")) { + double scalar_val = params.value("scalar", 0.0); + int64_t cn = element_count(input_metas[0].shape_data); + auto fn = [&](auto tag) { + using T = typename decltype(tag)::type; + T scalar = static_cast(scalar_val); + cpu_binary_scalar_op(opcode, static_cast(input_ptrs[0]), scalar, static_cast(out_shm.addr), cn); + ok = true; + }; + DISPATCH_BY_DTYPE(dtype, fn); + } + // ── elementwise comparison (CPU) ── + else if (input_ptrs.size() == 2 && + (opcode == "equal" || opcode == "notequal" || + opcode == "less" || opcode == "greater")) { + int64_t cn = element_count(input_metas[0].shape_data); + auto fn = [&](auto tag) { + using T = typename decltype(tag)::type; + cpu_comparison_op(opcode, static_cast(input_ptrs[0]), + 
static_cast(input_ptrs[1]), + static_cast(out_shm.addr), cn); + ok = true; + }; + DISPATCH_BY_DTYPE(dtype, fn); + } + // ── scalar comparison (CPU) ── + else if (input_ptrs.size() == 1 && + (opcode == "equalscalar" || opcode == "notequalscalar" || + opcode == "lessscalar" || opcode == "greaterscalar")) { + double scalar_val = params.value("scalar", 0.0); + int64_t cn = element_count(input_metas[0].shape_data); + auto fn = [&](auto tag) { + using T = typename decltype(tag)::type; + T scalar = static_cast(scalar_val); + cpu_scalar_comparison_op(opcode, static_cast(input_ptrs[0]), + scalar, static_cast(out_shm.addr), cn); + ok = true; + }; + DISPATCH_BY_DTYPE(dtype, fn); + } + // ── invert (CPU, integer only) ── + else if (opcode == "invert" && input_ptrs.size() == 1) { + int64_t nelem = element_count(input_metas[0].shape_data); + if (dtype == "i64" || dtype == "int64") { + int64_t *a = static_cast(input_ptrs[0]); + int64_t *c = static_cast(out_shm.addr); + for (int64_t i = 0; i < nelem; ++i) c[i] = ~a[i]; + ok = true; + } else if (dtype == "i32" || dtype == "int32") { + int32_t *a = static_cast(input_ptrs[0]); + int32_t *c = static_cast(out_shm.addr); + for (int64_t i = 0; i < nelem; ++i) c[i] = ~a[i]; + ok = true; + } else if (dtype == "i16" || dtype == "int16") { + int16_t *a = static_cast(input_ptrs[0]); + int16_t *c = static_cast(out_shm.addr); + for (int64_t i = 0; i < nelem; ++i) c[i] = static_cast(~a[i]); + ok = true; + } else if (dtype == "i8" || dtype == "int8") { + int8_t *a = static_cast(input_ptrs[0]); + int8_t *c = static_cast(out_shm.addr); + for (int64_t i = 0; i < nelem; ++i) c[i] = ~a[i]; + ok = true; + } else if (dtype == "bool") { + bool *a = static_cast(input_ptrs[0]); + bool *c = static_cast(out_shm.addr); + for (int64_t i = 0; i < nelem; ++i) c[i] = !a[i]; + ok = true; + } else { + error = "invert only supports integer/bool dtypes, got: " + dtype; + } + } + // ── todtype (CPU) ── + else if (opcode == "todtype" && input_ptrs.size() == 1) { + 
std::string src_dtype = input_metas[0].dtype.empty() ? "f32" : input_metas[0].dtype; + std::string dst_dtype = out_meta.dtype.empty() ? "f32" : out_meta.dtype; + int64_t nelem = element_count(input_metas[0].shape_data); + + auto copy_data = [&](auto *src_ptr, auto *dst_ptr, int64_t count) { + for (int64_t i = 0; i < count; ++i) dst_ptr[i] = static_cast>(src_ptr[i]); + ok = true; + }; + + // f32 source + if (src_dtype == "f32" || src_dtype == "float32") { + float *src = static_cast(input_ptrs[0]); + if (dst_dtype == "f32" || dst_dtype == "float32") + copy_data(src, static_cast(out_shm.addr), nelem); + else if (dst_dtype == "i64" || dst_dtype == "int64") + copy_data(src, static_cast(out_shm.addr), nelem); + else if (dst_dtype == "i32" || dst_dtype == "int32") + copy_data(src, static_cast(out_shm.addr), nelem); + else if (dst_dtype == "i16" || dst_dtype == "int16") + copy_data(src, static_cast(out_shm.addr), nelem); + else if (dst_dtype == "i8" || dst_dtype == "int8") + copy_data(src, static_cast(out_shm.addr), nelem); + else error = "unsupported dst dtype: " + dst_dtype; + } + // i64 source + else if (src_dtype == "i64" || src_dtype == "int64") { + int64_t *src = static_cast(input_ptrs[0]); + if (dst_dtype == "f32" || dst_dtype == "float32") + copy_data(src, static_cast(out_shm.addr), nelem); + else if (dst_dtype == "i64" || dst_dtype == "int64") + copy_data(src, static_cast(out_shm.addr), nelem); + else if (dst_dtype == "i32" || dst_dtype == "int32") + copy_data(src, static_cast(out_shm.addr), nelem); + else if (dst_dtype == "i16" || dst_dtype == "int16") + copy_data(src, static_cast(out_shm.addr), nelem); + else if (dst_dtype == "i8" || dst_dtype == "int8") + copy_data(src, static_cast(out_shm.addr), nelem); + else error = "unsupported dst dtype: " + dst_dtype; + } + else if (src_dtype == "i32" || src_dtype == "int32") { + int32_t *src = static_cast(input_ptrs[0]); + if (dst_dtype == "f32" || dst_dtype == "float32") + copy_data(src, static_cast(out_shm.addr), 
nelem); + else if (dst_dtype == "i64" || dst_dtype == "int64") + copy_data(src, static_cast(out_shm.addr), nelem); + else if (dst_dtype == "i32" || dst_dtype == "int32") + copy_data(src, static_cast(out_shm.addr), nelem); + else if (dst_dtype == "i16" || dst_dtype == "int16") + copy_data(src, static_cast(out_shm.addr), nelem); + else if (dst_dtype == "i8" || dst_dtype == "int8") + copy_data(src, static_cast(out_shm.addr), nelem); + else error = "unsupported dst dtype: " + dst_dtype; + } + else { + error = "unsupported src dtype: " + src_dtype; + } + } + // ── changeshape / init ops (stub — not yet rebuilt) ── + else if (opcode == "reshape" || opcode == "transpose" || opcode == "concat" || + opcode == "broadcastTo" || opcode == "indexselect" || opcode == "repeat" || + opcode == "constant") { + error = "changeshape/init ops not available (refactoring in progress)"; + } + // ── reduce ops (stub — not yet rebuilt) ── + else if (opcode == "sum" || opcode == "prod" || + opcode == "reducemax" || opcode == "reducemin") { + error = "reduce ops not available (refactoring in progress)"; + } + // ── io ops (print/save/load) — routed to io-metal plane ── + else if (opcode == "print" || opcode == "save" || opcode == "load") { + error = "io op routed to io-metal plane (cmd:io-metal:0) — not handled by op-metal"; + } + // ── pow (CPU, binary) ── + else if (opcode == "pow" && input_ptrs.size() == 2) { + int64_t nelem = element_count(input_metas[0].shape_data); + auto fn = [&](auto tag) { + using T = typename decltype(tag)::type; + T *a = static_cast(input_ptrs[0]); + T *b = static_cast(input_ptrs[1]); + T *c = static_cast(out_shm.addr); + for (int64_t i = 0; i < nelem; ++i) c[i] = static_cast(std::pow(a[i], b[i])); + ok = true; + }; + DISPATCH_BY_DTYPE(dtype, fn); + } + // ── unsupported ── + else { + notify_done(redis, vtid, pc, "error", + "unsupported opcode or input count: " + opcode + " (inputs=" + std::to_string(input_ptrs.size()) + ")"); + // cleanup + for (auto &s : 
input_shms) shm_close(s); + if (has_output) shm_close(out_shm); + return; + } + + // ── Cleanup ── + for (auto &s : input_shms) shm_close(s); + if (has_output) shm_close(out_shm); + + if (ok) { + notify_done(redis, vtid, pc, "ok"); + } else { + if (error.empty()) error = "dispatch failed for " + opcode; + notify_done(redis, vtid, pc, "error", error); + } +} + +// ═══════════════════════════════════════════════════════════ +// Main +// ═══════════════════════════════════════════════════════════ + +int main(int argc, char **argv) { + const char *redis_addr = "127.0.0.1"; + int redis_port = 6379; + if (argc > 1) redis_addr = argv[1]; + if (argc > 2) redis_port = atoi(argv[2]); + + // Force unbuffered output for diagnostics (subprocess stdout is fully buffered) + std::cout << std::unitbuf; + std::cerr << std::unitbuf; + + // 验证 Metal 可用 (使用 C++ wrapper) + { + char cwd[4096]; + if (getcwd(cwd, sizeof(cwd))) { + std::cout << "[op-metal] CWD: " << cwd << "\n"; + } + } + auto deviceInfo = deepx::metal::get_default_device_info(); + if (!deviceInfo.supports_metal) { + std::cerr << "[op-metal] FATAL: no Metal device\n"; + return 1; + } + std::cout << "[op-metal] device: " << deviceInfo.name << "\n"; + + // 连接 Redis(无限重试,不自退——op-plat 由元程控制退出) + redisContext *redis = nullptr; + while (!redis) { + redis = connect_redis(redis_addr, redis_port); + if (!redis) { + std::cerr << "[op-metal] Redis not available, retrying in 1s...\n"; + sleep(1); + } + } + std::cout << "[op-metal] connected to Redis " << redis_addr << ":" << redis_port << "\n"; + + // 注册实例和算子 + register_instance(redis); + + std::cout << "[op-metal] listening on " << OP_QUEUE << " + " << SYS_QUEUE << "\n"; + std::cout << "[op-metal] heartbeat → " << HEARTBEAT_KEY << " (every " << HEARTBEAT_INTERVAL_SEC << "s)\n"; + + // 初始心跳 + update_heartbeat(redis, "running"); + + // ── 消费循环 (同时监听业务队列和系统命令队列) ── + std::atomic running{true}; + auto last_heartbeat = std::chrono::steady_clock::now(); + while (running) { + redisReply *r = 
redis_cmd(redis, "BLPOP %s %s %d", OP_QUEUE, SYS_QUEUE, BLOCK_TIMEOUT_SEC); + if (!r) { + // Redis 断连 → 无限重连(不自退,op-plat 由元程控制退出) + std::cerr << "[op-metal] Redis disconnected, reconnecting...\n"; + redisFree(redis); + redis = nullptr; + while (!redis) { + sleep(1); + redis = connect_redis(redis_addr, redis_port); + if (!redis) { + std::cerr << "[op-metal] Redis still not available, retrying...\n"; + } + } + register_instance(redis); + last_heartbeat = std::chrono::steady_clock::now(); + update_heartbeat(redis, "running"); + continue; + } + + // ── 心跳上报 ── + auto now = std::chrono::steady_clock::now(); + if (std::chrono::duration_cast(now - last_heartbeat).count() >= HEARTBEAT_INTERVAL_SEC) { + update_heartbeat(redis, "running"); + last_heartbeat = now; + } + + if (r->type == REDIS_REPLY_NIL) { + REDIS_FREE(r); + continue; + } + + if (r->type != REDIS_REPLY_ARRAY || r->elements < 2) { + REDIS_FREE(r); + continue; + } + + std::string queue_name(r->element[0]->str); + std::string payload(r->element[1]->str); + REDIS_FREE(r); + + // ── 系统命令处理 ── + if (queue_name == SYS_QUEUE) { + try { + json sys_cmd = json::parse(payload); + std::string cmd = sys_cmd.value("cmd", ""); + if (cmd == "shutdown") { + std::cout << "[op-metal] received sys shutdown command, exiting...\n"; + running = false; + } else { + std::cerr << "[op-metal] unknown sys command: " << cmd << "\n"; + } + } catch (const std::exception &e) { + std::cerr << "[op-metal] sys cmd JSON parse error: " << e.what() << "\n"; + } + continue; + } + + // ── 业务命令处理 ── + // 解析任务 + json task; + try { + task = json::parse(payload); + } catch (const std::exception &e) { + std::cerr << "[op-metal] JSON parse error: " << e.what() << "\n"; + continue; + } + + try { + execute_task(redis, task); + } catch (const std::exception &e) { + std::string vtid = task.value("vtid", ""); + std::string pc = task.value("pc", ""); + std::cerr << "[op-metal] task exception: " << e.what() << "\n"; + if (!vtid.empty()) { + notify_done(redis, 
vtid, pc, "error", e.what()); + } + } + } + + // 上报 stopped 心跳,然后注销 + if (redis) { + update_heartbeat(redis, "stopped"); + std::cout << "[op-metal] final heartbeat: stopped\n"; + redis_cmd(redis, "DEL %s", INSTANCE_KEY); + redisFree(redis); + } + std::cout << "[op-metal] shutdown complete.\n"; + return 0; +} diff --git a/executor/op-mem-mps/src/deepx/dtype_mps.hpp b/executor/op-metal/src/deepx/dtype_metal.hpp similarity index 97% rename from executor/op-mem-mps/src/deepx/dtype_mps.hpp rename to executor/op-metal/src/deepx/dtype_metal.hpp index a10858a0..6e5dae1a 100644 --- a/executor/op-mem-mps/src/deepx/dtype_mps.hpp +++ b/executor/op-metal/src/deepx/dtype_metal.hpp @@ -1,7 +1,7 @@ #include -#include "deepx/dtype.hpp" +#include "dtype.hpp" namespace deepx diff --git a/executor/op-metal/src/deepx/mem/mem_metal.hpp b/executor/op-metal/src/deepx/mem/mem_metal.hpp new file mode 100644 index 00000000..08463d49 --- /dev/null +++ b/executor/op-metal/src/deepx/mem/mem_metal.hpp @@ -0,0 +1,106 @@ +#pragma once + +#include +#include +#include +#include + +#include "tensor.hpp" +#include "mem/mem.hpp" +#include "deepx/shmem/shm_tensor.h" + +namespace deepx::mem { + +// Metal 侧 Mem 实现 — tensor 数据来自 POSIX shm (heap-metal 分配) +class MemMetal : public MemBase { +public: + MemMetal() = default; + + // 注册一个 tensor 的 shm 映射 + // shm_name: POSIX shm 名称 + // byte_size: tensor 数据字节数 + // addr: mmap 后的虚拟地址 + void register_tensor(const std::string &name, + const std::string &shm_name, + size_t byte_size, + void *addr, + const Shape &shape) { + auto info = std::make_shared(); + info->shm_name = shm_name; + info->byte_size = byte_size; + info->addr = addr; + info->shape = shape; + tensors_[name] = info; + } + + // 导入 heap-metal 已创建的 tensor + bool import(const std::string &name, const std::string &shm_name, + size_t byte_size, const Shape &shape) { + deepx::shmem::ShmTensor st; + if (!deepx::shmem::shm_tensor_open(shm_name, byte_size, st)) { + return false; + } + register_tensor(name, 
shm_name, byte_size, st.addr, shape); + return true; + } + + std::shared_ptr> gettensor(const std::string &name) const override { + auto it = tensors_.find(name); + if (it == tensors_.end()) { + throw std::runtime_error("tensor not found: " + name); + } + auto &info = it->second; + auto result = std::make_shared>(); + result->shape = info->shape; + result->data = info->addr; + result->deleter = nullptr; // shm 生命周期由 heap 管理 + result->copyer = nullptr; + result->newer = nullptr; + return result; + } + + // 本地创建(单进程调试用,不走 shm) + template + void local_new(const std::string &name, const std::vector &dims) { + Shape shape(dims); + shape.dtype = precision(); + size_t bytes = shape.size * sizeof(T); + T *data = new T[shape.size]; + + auto info = std::make_shared(); + info->shm_name = ""; // 本地分配 + info->byte_size = bytes; + info->addr = data; + info->shape = shape; + info->local = true; + tensors_[name] = info; + } + + ~MemMetal() { + for (auto &kv : tensors_) { + auto &info = kv.second; + if (!info->shm_name.empty()) { + deepx::shmem::ShmTensor st; + st.shm_name = info->shm_name; + st.addr = info->addr; + st.byte_size = info->byte_size; + deepx::shmem::shm_tensor_close(st); + } else if (info->local && info->addr) { + // 本地 new[] 的释放 + operator delete(info->addr); + } + } + } + +private: + struct TensorInfo { + std::string shm_name; + size_t byte_size = 0; + void *addr = nullptr; + Shape shape; + bool local = false; // true 表示本地 new[],需自行释放 + }; + std::unordered_map> tensors_; +}; + +} // namespace deepx::mem diff --git a/executor/op-mem-mps/src/deepx/mps_context.mm b/executor/op-metal/src/deepx/metal_context.cpp similarity index 59% rename from executor/op-mem-mps/src/deepx/mps_context.mm rename to executor/op-metal/src/deepx/metal_context.cpp index be4a1484..ee281e3f 100644 --- a/executor/op-mem-mps/src/deepx/mps_context.mm +++ b/executor/op-metal/src/deepx/metal_context.cpp @@ -1,11 +1,11 @@ #import #import -#include "deepx/mps_context.hpp" +#include 
"deepx/metal_context.hpp" -namespace deepx::mps +namespace deepx::metal { -MpsContext::MpsContext() +MetalContext::MetalContext() { device_ = MTLCreateSystemDefaultDevice(); if (device_) @@ -14,12 +14,12 @@ } } -bool MpsContext::is_valid() const +bool MetalContext::is_valid() const { return device_ != nil; } -std::string MpsContext::device_name() const +std::string MetalContext::device_name() const { if (!device_) { @@ -28,12 +28,12 @@ return std::string([[device_ name] UTF8String]); } -id MpsContext::device() const +id MetalContext::device() const { return device_; } -id MpsContext::command_queue() const +id MetalContext::command_queue() const { return command_queue_; } diff --git a/executor/op-mem-mps/src/deepx/mps_context.hpp b/executor/op-metal/src/deepx/metal_context.hpp similarity index 89% rename from executor/op-mem-mps/src/deepx/mps_context.hpp rename to executor/op-metal/src/deepx/metal_context.hpp index 8f5d6e70..67d95a64 100644 --- a/executor/op-mem-mps/src/deepx/mps_context.hpp +++ b/executor/op-metal/src/deepx/metal_context.hpp @@ -8,12 +8,12 @@ @protocol MTLCommandQueue; #endif -namespace deepx::mps +namespace deepx::metal { -class MpsContext +class MetalContext { public: - MpsContext(); + MetalContext(); bool is_valid() const; std::string device_name() const; diff --git a/executor/op-metal/src/deepx/tensorfunc/changeshape_miaobyte.hpp b/executor/op-metal/src/deepx/tensorfunc/changeshape_miaobyte.hpp new file mode 100644 index 00000000..4c6397c9 --- /dev/null +++ b/executor/op-metal/src/deepx/tensorfunc/changeshape_miaobyte.hpp @@ -0,0 +1,253 @@ +#ifndef DEEPX_TENSORFUNC_CHANGESHAPE_MIAOBYTE_HPP +#define DEEPX_TENSORFUNC_CHANGESHAPE_MIAOBYTE_HPP + +#include +#include +#include + +#include "tensor.hpp" +#include "shape_changeshape.hpp" +#include "tensorfunc/changeshape.hpp" +#include "tensorfunc/authors.hpp" + +namespace deepx::tensorfunc +{ + // ═══════════════════════════════════════════════════════════ + // reshape — pure CPU, checks element count 
& copies data + // ═══════════════════════════════════════════════════════════ + template + struct reshapeDispatcher + { + static void reshape(const Tensor &tensor, const std::vector &shape, Tensor &output) + { + int new_prod = 1; + for (int dim : shape) + { + new_prod *= dim; + } + + if (tensor.shape.size != new_prod) + { + throw std::invalid_argument("Shape size mismatch"); + } + Shape newshape(shape); + output.shape.shape = newshape.shape; + output.shape.strides = newshape.strides; + output.shape.size = newshape.size; + + if (tensor.data != output.data) + { + std::memcpy(output.data, tensor.data, tensor.shape.size * sizeof(T)); + } + } + }; + + // ═══════════════════════════════════════════════════════════ + // transpose + // ═══════════════════════════════════════════════════════════ + template + struct transposeDispatcher + { + static void transpose(const Tensor &tensor, const std::vector &dim_order, Tensor &output) + { + if (dim_order.size() != static_cast(tensor.shape.dim())) + { + throw std::invalid_argument("dimOrder size does not match the number of dimensions."); + } + if (output.shape.size != tensor.shape.size) + { + throw std::runtime_error("transpose error: output shape size mismatch"); + } + + std::vector new_shape = transposeShape(tensor.shape.shape, dim_order); + output.shape = Shape(new_shape); + + int ndim = tensor.shape.dim(); + std::vector src_indices(ndim, 0); + for (int64_t i = 0; i < output.shape.size; ++i) + { + std::vector dst_indices = output.shape.linearto(static_cast(i)); + for (size_t j = 0; j < static_cast(ndim); ++j) + { + src_indices[dim_order[j]] = dst_indices[j]; + } + int src_linear = tensor.shape.linearat(src_indices); + output.data[i] = tensor.data[src_linear]; + } + } + }; + + // ═══════════════════════════════════════════════════════════ + // concat + // ═══════════════════════════════════════════════════════════ + template + struct concatDispatcher + { + static void concat(const std::vector*> tensors, const int axis, Tensor 
&result) + { + if (!checkShapeConcat(tensors, axis, result)) + { + throw TensorShapeError("Output tensor shape must match sum of input shapes for concat"); + } + + int dimC = axis + 1; + int copylen = tensors[0]->shape.strides[axis]; + + for (int64_t idx = 0; idx < result.shape.size; idx += copylen) + { + std::vector indices = result.shape.linearto(static_cast(idx)); + + int concatIdx = indices[axis]; + int tensorIdx = 0; + while (tensorIdx < static_cast(tensors.size())) + { + if (concatIdx < tensors[tensorIdx]->shape[axis]) + { + break; + } + concatIdx -= tensors[tensorIdx]->shape[axis]; + tensorIdx++; + } + + std::vector src_indices = indices; + src_indices[axis] = concatIdx; + int src_idx = tensors[tensorIdx]->shape.linearat(src_indices); + std::memcpy(result.data + idx, tensors[tensorIdx]->data + src_idx, copylen * sizeof(T)); + } + } + }; + + // ═══════════════════════════════════════════════════════════ + // broadcastTo helper + // ═══════════════════════════════════════════════════════════ + static std::vector fromBroadcastIndices(const std::vector &bmap, + const std::vector &broadcastIndices) + { + std::vector srcindices; + for (size_t i = 0; i < bmap.size(); ++i) + { + switch (bmap[i]) + { + case xTox: + srcindices.push_back(broadcastIndices[i]); + break; + case nullTo1: + break; + case xTo1: + srcindices.push_back(0); + break; + } + } + return srcindices; + } + + // ═══════════════════════════════════════════════════════════ + // broadcastTo + // ═══════════════════════════════════════════════════════════ + template + struct broadcastToDispatcher + { + static void broadcastTo(const Tensor &A, const std::vector &new_shape, Tensor &B) + { + auto A_broadcastShape = broadcastShape(A.shape.shape, new_shape); + if (A_broadcastShape.empty() || A_broadcastShape != new_shape) + { + throw TensorShapeError("Broadcast shape mismatch"); + } + auto bmap = broadcastMap(A.shape.shape, new_shape); + + int ndim = static_cast(new_shape.size()); + for (int64_t i = 0; i < 
B.shape.size; ++i) + { + std::vector bindices = B.shape.linearto(static_cast(i)); + std::vector aindices = fromBroadcastIndices(bmap, bindices); + B.data[i] = A.data[A.shape.linearat(aindices)]; + } + } + }; + + // ═══════════════════════════════════════════════════════════ + // indexselect helper + // ═══════════════════════════════════════════════════════════ + template + static void fromIndexselectIndices(const std::vector &output_indices, + const Tensor &index, + std::vector &index_indices, + const int gatherAxis, + std::vector &input_indices) + { + std::copy(output_indices.begin(), output_indices.begin() + gatherAxis, input_indices.begin()); + std::copy(output_indices.begin() + gatherAxis, + output_indices.begin() + gatherAxis + static_cast(index_indices.size()), + index_indices.begin()); + int index_idx = index.shape.linearat(index_indices); + input_indices[gatherAxis] = index.data[index_idx]; + std::copy(output_indices.begin() + gatherAxis + static_cast(index_indices.size()), + output_indices.begin() + static_cast(output_indices.size()), + input_indices.begin() + gatherAxis + 1); + } + + // ═══════════════════════════════════════════════════════════ + // indexselect + // ═══════════════════════════════════════════════════════════ + template + struct indexselectDispatcher + { + static void indexselect(const Tensor &input, const Tensor &index, + const int axis, Tensor &output) + { + int gatherAxis = axis < 0 ? 
input.shape.dim() + axis : axis; + if (gatherAxis < 0 || gatherAxis >= input.shape.dim()) + { + throw std::invalid_argument("Axis is out of bounds"); + } + + std::vector gatherShape = indexselectShape(input.shape.shape, index.shape.shape, gatherAxis); + if (gatherShape.empty() || gatherShape != output.shape.shape) + { + throw TensorShapeError("Indexselect shape mismatch"); + } + + std::vector input_indices(input.shape.dim(), 0); + std::vector index_indices(index.shape.dim(), 0); + + for (int64_t i = 0; i < output.shape.size; ++i) + { + std::vector output_indices = output.shape.linearto(static_cast(i)); + fromIndexselectIndices(output_indices, index, index_indices, gatherAxis, input_indices); + output.data[i] = input.data[input.shape.linearat(input_indices)]; + } + } + }; + + // ═══════════════════════════════════════════════════════════ + // repeat + // ═══════════════════════════════════════════════════════════ + template + struct repeatDispatcher + { + static void repeat(const Tensor &A, const std::vector &repeats, Tensor &B) + { + auto new_shape = repeatShape(A.shape.shape, repeats); + if (new_shape.empty() || new_shape != B.shape.shape) + { + throw TensorShapeError("Repeat shape mismatch"); + } + + int ndim = A.shape.dim(); + std::vector src_indices(ndim, 0); + for (int64_t i = 0; i < B.shape.size; ++i) + { + std::vector indices = B.shape.linearto(static_cast(i)); + for (size_t d = 0; d < static_cast(ndim); ++d) + { + src_indices[d] = indices[d] / repeats[d]; + } + B.data[i] = A.data[A.shape.linearat(src_indices)]; + } + } + }; + +} // namespace deepx::tensorfunc + +#endif // DEEPX_TENSORFUNC_CHANGESHAPE_MIAOBYTE_HPP diff --git a/executor/op-metal/src/deepx/tensorfunc/elementwise_common.hpp b/executor/op-metal/src/deepx/tensorfunc/elementwise_common.hpp new file mode 100644 index 00000000..161a1aa5 --- /dev/null +++ b/executor/op-metal/src/deepx/tensorfunc/elementwise_common.hpp @@ -0,0 +1,471 @@ +#ifndef DEEPX_TENSORFUNC_ELEMENTWISE_COMMON_HPP +#define 
DEEPX_TENSORFUNC_ELEMENTWISE_COMMON_HPP + +#if defined(__APPLE__) + #include +#endif + +#include +#include +#include +#include + +#include "tensor.hpp" +#include "deepx/tensorfunc/metal_common.hpp" + +namespace deepx::tensorfunc::detail +{ + template + inline void assert_same_shape(const Tensor &A, const Tensor &B, const Tensor &C) + { + if (A.shape.size != B.shape.size || A.shape.size != C.shape.size || + A.shape.shape != B.shape.shape || A.shape.shape != C.shape.shape) + { + throw std::invalid_argument("shape mismatch"); + } + } + + template + inline void assert_same_shape(const Tensor &A, const Tensor &C) + { + if (A.shape.size != C.shape.size || + A.shape.shape != C.shape.shape) + { + throw std::invalid_argument("shape mismatch"); + } + } + + // ── CPU fallback implementations ── + + template + inline void add_cpu(const Tensor &A, const Tensor &B, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = A.data[i] + B.data[i]; + } + + template + inline void sub_cpu(const Tensor &A, const Tensor &B, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = A.data[i] - B.data[i]; + } + + template + inline void mul_cpu(const Tensor &A, const Tensor &B, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = A.data[i] * B.data[i]; + } + + template + inline void div_cpu(const Tensor &A, const Tensor &B, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = A.data[i] / B.data[i]; + } + + template + inline void max_cpu(const Tensor &A, const Tensor &B, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = A.data[i] > B.data[i] ? A.data[i] : B.data[i]; + } + + template + inline void min_cpu(const Tensor &A, const Tensor &B, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = A.data[i] < B.data[i] ? 
A.data[i] : B.data[i]; + } + + template + inline void relu_cpu(const Tensor &A, Tensor &C) + { + T zero = T(0); + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = A.data[i] > zero ? A.data[i] : zero; + } + + template + inline void neg_cpu(const Tensor &A, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = -A.data[i]; + } + + template + inline void abs_cpu(const Tensor &A, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = A.data[i] < T(0) ? -A.data[i] : A.data[i]; + } + + template + inline void addscalar_cpu(const Tensor &A, const T v, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = A.data[i] + v; + } + + template + inline void subscalar_cpu(const Tensor &A, const T v, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = A.data[i] - v; + } + + template + inline void rsubscalar_cpu(const T v, const Tensor &A, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = v - A.data[i]; + } + + template + inline void mulscalar_cpu(const Tensor &A, const T v, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = A.data[i] * v; + } + + template + inline void divscalar_cpu(const Tensor &A, const T v, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = A.data[i] / v; + } + + template + inline void rdivscalar_cpu(const T v, const Tensor &A, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = v / A.data[i]; + } + + template + inline void maxscalar_cpu(const Tensor &A, const T v, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = A.data[i] > v ? A.data[i] : v; + } + + template + inline void minscalar_cpu(const Tensor &A, const T v, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = A.data[i] < v ? 
A.data[i] : v; + } + + template + inline void pow_cpu(const Tensor &A, const Tensor &B, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = std::pow(A.data[i], B.data[i]); + } + + template + inline void powscalar_cpu(const Tensor &A, const T v, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = std::pow(A.data[i], v); + } + + template + inline void rpowscalar_cpu(const T v, const Tensor &A, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = std::pow(v, A.data[i]); + } + + template + inline void sqrt_cpu(const Tensor &A, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = std::sqrt(A.data[i]); + } + + template + inline void log_cpu(const Tensor &A, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = std::log(A.data[i]); + } + + template + inline void exp_cpu(const Tensor &A, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = std::exp(A.data[i]); + } + + template + inline void sin_cpu(const Tensor &A, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = std::sin(A.data[i]); + } + + template + inline void cos_cpu(const Tensor &A, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = std::cos(A.data[i]); + } + + template + inline void tan_cpu(const Tensor &A, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = std::tan(A.data[i]); + } + + template + inline void invert_cpu(const Tensor &A, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = ~A.data[i]; + } + + template <> + inline void invert_cpu(const Tensor &A, Tensor &C) + { + for (int64_t i = 0; i < A.shape.size; ++i) + C.data[i] = !A.data[i]; + } +} + +namespace deepx::metal::kernels +{ +#if defined(__APPLE__) && TARGET_OS_OSX && defined(__OBJC__) + inline deepx::metal::common::MetalKernelRuntime &elementwise_runtime() + { + static deepx::metal::common::MetalKernelRuntime rt; + return rt; + } + + // 
── ADD ── + inline bool add_f32(const float *a, const float *b, float *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("add_f32", a, b, c, static_cast(n), sizeof(float)); + } + inline bool add_i8(const int8_t *a, const int8_t *b, int8_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("add_i8", a, b, c, static_cast(n), sizeof(int8_t)); + } + inline bool add_i16(const int16_t *a, const int16_t *b, int16_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("add_i16", a, b, c, static_cast(n), sizeof(int16_t)); + } + inline bool add_i32(const int32_t *a, const int32_t *b, int32_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("add_i32", a, b, c, static_cast(n), sizeof(int32_t)); + } + inline bool add_i64(const int64_t *a, const int64_t *b, int64_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("add_i64", a, b, c, static_cast(n), sizeof(int64_t)); + } + + // ── SUB ── + inline bool sub_f32(const float *a, const float *b, float *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("sub_f32", a, b, c, static_cast(n), sizeof(float)); + } + inline bool sub_i8(const int8_t *a, const int8_t *b, int8_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("sub_i8", a, b, c, static_cast(n), sizeof(int8_t)); + } + inline bool sub_i16(const int16_t *a, const int16_t *b, int16_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("sub_i16", a, b, c, static_cast(n), sizeof(int16_t)); + } + inline bool sub_i32(const int32_t *a, const int32_t *b, int32_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("sub_i32", a, b, c, static_cast(n), sizeof(int32_t)); + } + inline bool sub_i64(const int64_t *a, const int64_t *b, int64_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("sub_i64", a, b, c, static_cast(n), sizeof(int64_t)); + } + + // ── MUL ── + inline bool mul_f32(const float *a, const float *b, float *c, int64_t 
n) { + return elementwise_runtime().dispatch_binary_1d("mul_f32", a, b, c, static_cast(n), sizeof(float)); + } + inline bool mul_i8(const int8_t *a, const int8_t *b, int8_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("mul_i8", a, b, c, static_cast(n), sizeof(int8_t)); + } + inline bool mul_i16(const int16_t *a, const int16_t *b, int16_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("mul_i16", a, b, c, static_cast(n), sizeof(int16_t)); + } + inline bool mul_i32(const int32_t *a, const int32_t *b, int32_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("mul_i32", a, b, c, static_cast(n), sizeof(int32_t)); + } + inline bool mul_i64(const int64_t *a, const int64_t *b, int64_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("mul_i64", a, b, c, static_cast(n), sizeof(int64_t)); + } + + // ── DIV ── + inline bool div_f32(const float *a, const float *b, float *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("div_f32", a, b, c, static_cast(n), sizeof(float)); + } + + // ── MAX ── + inline bool max_f32(const float *a, const float *b, float *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("max_f32", a, b, c, static_cast(n), sizeof(float)); + } + inline bool max_i8(const int8_t *a, const int8_t *b, int8_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("max_i8", a, b, c, static_cast(n), sizeof(int8_t)); + } + inline bool max_i16(const int16_t *a, const int16_t *b, int16_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("max_i16", a, b, c, static_cast(n), sizeof(int16_t)); + } + inline bool max_i32(const int32_t *a, const int32_t *b, int32_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("max_i32", a, b, c, static_cast(n), sizeof(int32_t)); + } + inline bool max_i64(const int64_t *a, const int64_t *b, int64_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("max_i64", a, b, c, 
static_cast(n), sizeof(int64_t)); + } + + // ── MIN ── + inline bool min_f32(const float *a, const float *b, float *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("min_f32", a, b, c, static_cast(n), sizeof(float)); + } + inline bool min_i8(const int8_t *a, const int8_t *b, int8_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("min_i8", a, b, c, static_cast(n), sizeof(int8_t)); + } + inline bool min_i16(const int16_t *a, const int16_t *b, int16_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("min_i16", a, b, c, static_cast(n), sizeof(int16_t)); + } + inline bool min_i32(const int32_t *a, const int32_t *b, int32_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("min_i32", a, b, c, static_cast(n), sizeof(int32_t)); + } + inline bool min_i64(const int64_t *a, const int64_t *b, int64_t *c, int64_t n) { + return elementwise_runtime().dispatch_binary_1d("min_i64", a, b, c, static_cast(n), sizeof(int64_t)); + } + + // ── RELU ── + inline bool relu_f32(const float *x, float *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("relu_f32", x, y, static_cast(n), sizeof(float)); + } + inline bool relu_i8(const int8_t *x, int8_t *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("relu_i8", x, y, static_cast(n), sizeof(int8_t)); + } + inline bool relu_i16(const int16_t *x, int16_t *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("relu_i16", x, y, static_cast(n), sizeof(int16_t)); + } + inline bool relu_i32(const int32_t *x, int32_t *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("relu_i32", x, y, static_cast(n), sizeof(int32_t)); + } + inline bool relu_i64(const int64_t *x, int64_t *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("relu_i64", x, y, static_cast(n), sizeof(int64_t)); + } + + // ── NEG ── + inline bool neg_f32(const float *x, float *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("neg_f32", x, 
y, static_cast(n), sizeof(float)); + } + inline bool neg_i8(const int8_t *x, int8_t *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("neg_i8", x, y, static_cast(n), sizeof(int8_t)); + } + inline bool neg_i16(const int16_t *x, int16_t *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("neg_i16", x, y, static_cast(n), sizeof(int16_t)); + } + inline bool neg_i32(const int32_t *x, int32_t *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("neg_i32", x, y, static_cast(n), sizeof(int32_t)); + } + inline bool neg_i64(const int64_t *x, int64_t *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("neg_i64", x, y, static_cast(n), sizeof(int64_t)); + } + + // ── ABS ── + inline bool abs_f32(const float *x, float *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("abs_f32", x, y, static_cast(n), sizeof(float)); + } + inline bool abs_i8(const int8_t *x, int8_t *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("abs_i8", x, y, static_cast(n), sizeof(int8_t)); + } + inline bool abs_i16(const int16_t *x, int16_t *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("abs_i16", x, y, static_cast(n), sizeof(int16_t)); + } + inline bool abs_i32(const int32_t *x, int32_t *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("abs_i32", x, y, static_cast(n), sizeof(int32_t)); + } + inline bool abs_i64(const int64_t *x, int64_t *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("abs_i64", x, y, static_cast(n), sizeof(int64_t)); + } + + // ── SQRT ── + inline bool sqrt_f32(const float *x, float *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("sqrt_f32", x, y, static_cast(n), sizeof(float)); + } + + // ── EXP ── + inline bool exp_f32(const float *x, float *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("exp_f32", x, y, static_cast(n), sizeof(float)); + } + + // ── LOG ── + inline bool log_f32(const float *x, float *y, int64_t n) 
{ + return elementwise_runtime().dispatch_unary_1d("log_f32", x, y, static_cast(n), sizeof(float)); + } + + // ── SIN ── + inline bool sin_f32(const float *x, float *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("sin_f32", x, y, static_cast(n), sizeof(float)); + } + + // ── COS ── + inline bool cos_f32(const float *x, float *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("cos_f32", x, y, static_cast(n), sizeof(float)); + } + + // ── TAN ── + inline bool tan_f32(const float *x, float *y, int64_t n) { + return elementwise_runtime().dispatch_unary_1d("tan_f32", x, y, static_cast(n), sizeof(float)); + } + +#else + // ── Stubs for non-ObjC / non-Apple platforms ── + inline bool add_f32(const float *, const float *, float *, int64_t) { return false; } + inline bool add_i8(const int8_t *, const int8_t *, int8_t *, int64_t) { return false; } + inline bool add_i16(const int16_t *, const int16_t *, int16_t *, int64_t) { return false; } + inline bool add_i32(const int32_t *, const int32_t *, int32_t *, int64_t) { return false; } + inline bool add_i64(const int64_t *, const int64_t *, int64_t *, int64_t) { return false; } + inline bool sub_f32(const float *, const float *, float *, int64_t) { return false; } + inline bool sub_i8(const int8_t *, const int8_t *, int8_t *, int64_t) { return false; } + inline bool sub_i16(const int16_t *, const int16_t *, int16_t *, int64_t) { return false; } + inline bool sub_i32(const int32_t *, const int32_t *, int32_t *, int64_t) { return false; } + inline bool sub_i64(const int64_t *, const int64_t *, int64_t *, int64_t) { return false; } + inline bool mul_f32(const float *, const float *, float *, int64_t) { return false; } + inline bool mul_i8(const int8_t *, const int8_t *, int8_t *, int64_t) { return false; } + inline bool mul_i16(const int16_t *, const int16_t *, int16_t *, int64_t) { return false; } + inline bool mul_i32(const int32_t *, const int32_t *, int32_t *, int64_t) { return false; } + inline 
bool mul_i64(const int64_t *, const int64_t *, int64_t *, int64_t) { return false; } + inline bool div_f32(const float *, const float *, float *, int64_t) { return false; } + inline bool max_f32(const float *, const float *, float *, int64_t) { return false; } + inline bool max_i8(const int8_t *, const int8_t *, int8_t *, int64_t) { return false; } + inline bool max_i16(const int16_t *, const int16_t *, int16_t *, int64_t) { return false; } + inline bool max_i32(const int32_t *, const int32_t *, int32_t *, int64_t) { return false; } + inline bool max_i64(const int64_t *, const int64_t *, int64_t *, int64_t) { return false; } + inline bool min_f32(const float *, const float *, float *, int64_t) { return false; } + inline bool min_i8(const int8_t *, const int8_t *, int8_t *, int64_t) { return false; } + inline bool min_i16(const int16_t *, const int16_t *, int16_t *, int64_t) { return false; } + inline bool min_i32(const int32_t *, const int32_t *, int32_t *, int64_t) { return false; } + inline bool min_i64(const int64_t *, const int64_t *, int64_t *, int64_t) { return false; } + inline bool relu_f32(const float *, float *, int64_t) { return false; } + inline bool relu_i8(const int8_t *, int8_t *, int64_t) { return false; } + inline bool relu_i16(const int16_t *, int16_t *, int64_t) { return false; } + inline bool relu_i32(const int32_t *, int32_t *, int64_t) { return false; } + inline bool relu_i64(const int64_t *, int64_t *, int64_t) { return false; } + inline bool neg_f32(const float *, float *, int64_t) { return false; } + inline bool neg_i8(const int8_t *, int8_t *, int64_t) { return false; } + inline bool neg_i16(const int16_t *, int16_t *, int64_t) { return false; } + inline bool neg_i32(const int32_t *, int32_t *, int64_t) { return false; } + inline bool neg_i64(const int64_t *, int64_t *, int64_t) { return false; } + inline bool abs_f32(const float *, float *, int64_t) { return false; } + inline bool abs_i8(const int8_t *, int8_t *, int64_t) { return false; 
} + inline bool abs_i16(const int16_t *, int16_t *, int64_t) { return false; } + inline bool abs_i32(const int32_t *, int32_t *, int64_t) { return false; } + inline bool abs_i64(const int64_t *, int64_t *, int64_t) { return false; } + inline bool sqrt_f32(const float *, float *, int64_t) { return false; } + inline bool exp_f32(const float *, float *, int64_t) { return false; } + inline bool log_f32(const float *, float *, int64_t) { return false; } + inline bool sin_f32(const float *, float *, int64_t) { return false; } + inline bool cos_f32(const float *, float *, int64_t) { return false; } + inline bool tan_f32(const float *, float *, int64_t) { return false; } + +#endif +} // namespace deepx::metal::kernels + +#endif // DEEPX_TENSORFUNC_ELEMENTWISE_COMMON_HPP diff --git a/executor/op-metal/src/deepx/tensorfunc/elementwise_miaobyte.hpp b/executor/op-metal/src/deepx/tensorfunc/elementwise_miaobyte.hpp new file mode 100644 index 00000000..41f289ce --- /dev/null +++ b/executor/op-metal/src/deepx/tensorfunc/elementwise_miaobyte.hpp @@ -0,0 +1,476 @@ +#ifndef DEEPX_TENSORFUNC_ELEMENTWISE_MIAOBYTE_HPP +#define DEEPX_TENSORFUNC_ELEMENTWISE_MIAOBYTE_HPP + +#include +#include +#include + +#include "tensor.hpp" +#include "tensorfunc/authors.hpp" +#include "deepx/tensorfunc/elementwise_common.hpp" +#include "tensorfunc/elementwise.hpp" + +// ═══════════════════════════════════════════════════════════ +// Helper: Metal-first, CPU-fallback pattern for binary ops +// ═══════════════════════════════════════════════════════════ +#define DEEPX_METAL_DISPATCH(T, ok, kernelFn, cpuFn, A, B, C) \ + if constexpr (std::is_same_v) ok = deepx::metal::kernels::kernelFn##_f32(A.data, B.data, C.data, A.shape.size); \ + else if constexpr (std::is_same_v) ok = deepx::metal::kernels::kernelFn##_i8(A.data, B.data, C.data, A.shape.size); \ + else if constexpr (std::is_same_v) ok = deepx::metal::kernels::kernelFn##_i16(A.data, B.data, C.data, A.shape.size); \ + else if constexpr (std::is_same_v) 
ok = deepx::metal::kernels::kernelFn##_i32(A.data, B.data, C.data, A.shape.size); \ + else if constexpr (std::is_same_v) ok = deepx::metal::kernels::kernelFn##_i64(A.data, B.data, C.data, A.shape.size); + +#define DEEPX_METAL_UNARY_DISPATCH(T, ok, kernelFn, cpuFn, A, C) \ + if constexpr (std::is_same_v) ok = deepx::metal::kernels::kernelFn##_f32(A.data, C.data, A.shape.size); \ + else if constexpr (std::is_same_v) ok = deepx::metal::kernels::kernelFn##_i8(A.data, C.data, A.shape.size); \ + else if constexpr (std::is_same_v) ok = deepx::metal::kernels::kernelFn##_i16(A.data, C.data, A.shape.size); \ + else if constexpr (std::is_same_v) ok = deepx::metal::kernels::kernelFn##_i32(A.data, C.data, A.shape.size); \ + else if constexpr (std::is_same_v) ok = deepx::metal::kernels::kernelFn##_i64(A.data, C.data, A.shape.size); + +namespace deepx::tensorfunc +{ + +// ═══════════════════════════════════════════════════════════ +// Binary elementwise ops: add, sub, mul, div, max, min, pow +// ═══════════════════════════════════════════════════════════ + +// ── add ── +template +struct addDispatcher +{ + static void add(const Tensor &A, const Tensor &B, Tensor &C) + { + detail::assert_same_shape(A, B, C); + bool ok = false; + DEEPX_METAL_DISPATCH(T, ok, add, add_cpu, A, B, C) + if (!ok) detail::add_cpu(A, B, C); + } +}; + +// ── sub ── +template +struct subDispatcher +{ + static void sub(const Tensor &A, const Tensor &B, Tensor &C) + { + detail::assert_same_shape(A, B, C); + bool ok = false; + DEEPX_METAL_DISPATCH(T, ok, sub, sub_cpu, A, B, C) + if (!ok) detail::sub_cpu(A, B, C); + } +}; + +// ── mul ── +template +struct mulDispatcher +{ + static void mul(const Tensor &A, const Tensor &B, Tensor &C) + { + detail::assert_same_shape(A, B, C); + bool ok = false; + DEEPX_METAL_DISPATCH(T, ok, mul, mul_cpu, A, B, C) + if (!ok) detail::mul_cpu(A, B, C); + } +}; + +// ── div ── +template +struct divDispatcher +{ + static void div(const Tensor &A, const Tensor &B, Tensor &C) + { + 
detail::assert_same_shape(A, B, C); + bool ok = false; + if constexpr (std::is_same_v) + ok = deepx::metal::kernels::div_f32(A.data, B.data, C.data, A.shape.size); + if (!ok) detail::div_cpu(A, B, C); + } +}; + +// ── max ── +template +struct maxDispatcher +{ + static void max(const Tensor &A, const Tensor &B, Tensor &C) + { + detail::assert_same_shape(A, B, C); + bool ok = false; + DEEPX_METAL_DISPATCH(T, ok, max, max_cpu, A, B, C) + if (!ok) detail::max_cpu(A, B, C); + } +}; + +// ── min ── +template +struct minDispatcher +{ + static void min(const Tensor &A, const Tensor &B, Tensor &C) + { + detail::assert_same_shape(A, B, C); + bool ok = false; + DEEPX_METAL_DISPATCH(T, ok, min, min_cpu, A, B, C) + if (!ok) detail::min_cpu(A, B, C); + } +}; + +// ── pow (CPU-only) ── +template +struct powDispatcher +{ + static void pow(const Tensor &A, const Tensor &B, Tensor &C) + { + detail::assert_same_shape(A, B, C); + detail::pow_cpu(A, B, C); + } +}; + +// ═══════════════════════════════════════════════════════════ +// Scalar elementwise ops +// ═══════════════════════════════════════════════════════════ + +// ── addscalar ── +template +struct addscalarDispatcher +{ + static void addscalar(const Tensor &A, const T value, Tensor &C) + { + detail::assert_same_shape(A, C); + detail::addscalar_cpu(A, value, C); + } +}; + +// ── subscalar ── +template +struct subscalarDispatcher +{ + static void subscalar(const Tensor &A, const T value, Tensor &C) + { + detail::assert_same_shape(A, C); + detail::subscalar_cpu(A, value, C); + } +}; + +// ── rsubscalar ── +template +struct rsubscalarDispatcher +{ + static void rsubscalar(const T value, const Tensor &A, Tensor &C) + { + detail::assert_same_shape(A, C); + detail::rsubscalar_cpu(value, A, C); + } +}; + +// ── mulscalar ── +template +struct mulscalarDispatcher +{ + static void mulscalar(const Tensor &A, const T value, Tensor &C) + { + detail::assert_same_shape(A, C); + detail::mulscalar_cpu(A, value, C); + } +}; + +// ── divscalar 
── +template +struct divscalarDispatcher +{ + static void divscalar(const Tensor &A, const T value, Tensor &C) + { + detail::assert_same_shape(A, C); + detail::divscalar_cpu(A, value, C); + } +}; + +// ── rdivscalar ── +template +struct rdivscalarDispatcher +{ + static void rdivscalar(const T value, const Tensor &A, Tensor &C) + { + detail::assert_same_shape(A, C); + detail::rdivscalar_cpu(value, A, C); + } +}; + +// ── maxscalar ── +template +struct maxscalarDispatcher +{ + static void maxscalar(const Tensor &A, const T b, Tensor &C) + { + detail::assert_same_shape(A, C); + detail::maxscalar_cpu(A, b, C); + } +}; + +// ── minscalar ── +template +struct minscalarDispatcher +{ + static void minscalar(const Tensor &A, const T b, Tensor &C) + { + detail::assert_same_shape(A, C); + detail::minscalar_cpu(A, b, C); + } +}; + +// ── powscalar ── +template +struct powscalarDispatcher +{ + static void powscalar(const Tensor &A, const T value, Tensor &C) + { + detail::assert_same_shape(A, C); + detail::powscalar_cpu(A, value, C); + } +}; + +// ── rpowscalar ── +template +struct rpowscalarDispatcher +{ + static void rpowscalar(const T value, const Tensor &A, Tensor &C) + { + detail::assert_same_shape(A, C); + detail::rpowscalar_cpu(value, A, C); + } +}; + +// ═══════════════════════════════════════════════════════════ +// Unary elementwise ops (Metal + CPU fallback) +// ═══════════════════════════════════════════════════════════ + +// ── sqrt ── +template +struct sqrtDispatcher +{ + static void sqrt(const Tensor &A, Tensor &C) + { + detail::assert_same_shape(A, C); + bool ok = false; + if constexpr (std::is_same_v) + ok = deepx::metal::kernels::sqrt_f32(A.data, C.data, A.shape.size); + if (!ok) detail::sqrt_cpu(A, C); + } +}; + +// ── exp ── +template +struct expDispatcher +{ + static void exp(const Tensor &A, Tensor &C) + { + detail::assert_same_shape(A, C); + bool ok = false; + if constexpr (std::is_same_v) + ok = deepx::metal::kernels::exp_f32(A.data, C.data, 
A.shape.size); + if (!ok) detail::exp_cpu(A, C); + } +}; + +// ── log ── +template +struct logDispatcher +{ + static void log(const Tensor &A, Tensor &C) + { + detail::assert_same_shape(A, C); + bool ok = false; + if constexpr (std::is_same_v) + ok = deepx::metal::kernels::log_f32(A.data, C.data, A.shape.size); + if (!ok) detail::log_cpu(A, C); + } +}; + +// ── sin ── +template +struct sinDispatcher +{ + static void sin(const Tensor &A, Tensor &C) + { + detail::assert_same_shape(A, C); + bool ok = false; + if constexpr (std::is_same_v) + ok = deepx::metal::kernels::sin_f32(A.data, C.data, A.shape.size); + if (!ok) detail::sin_cpu(A, C); + } +}; + +// ── cos ── +template +struct cosDispatcher +{ + static void cos(const Tensor &A, Tensor &C) + { + detail::assert_same_shape(A, C); + bool ok = false; + if constexpr (std::is_same_v) + ok = deepx::metal::kernels::cos_f32(A.data, C.data, A.shape.size); + if (!ok) detail::cos_cpu(A, C); + } +}; + +// ── tan ── +template +struct tanDispatcher +{ + static void tan(const Tensor &A, Tensor &C) + { + detail::assert_same_shape(A, C); + bool ok = false; + if constexpr (std::is_same_v) + ok = deepx::metal::kernels::tan_f32(A.data, C.data, A.shape.size); + if (!ok) detail::tan_cpu(A, C); + } +}; + +// ── invert ── +template +struct invertDispatcher +{ + static void invert(const Tensor &A, Tensor &C) + { + detail::assert_same_shape(A, C); + detail::invert_cpu(A, C); + } +}; + +// Specialization for bool +template <> +struct invertDispatcher +{ + static void invert(const Tensor &A, Tensor &C) + { + detail::assert_same_shape(A, C); + detail::invert_cpu(A, C); + } +}; + +// ═══════════════════════════════════════════════════════════ +// Comparison ops (CPU-only) +// ═══════════════════════════════════════════════════════════ + +// ── equal ── +template +struct equalDispatcher +{ + static void equal(const Tensor &A, const Tensor &B, float epsilon, Tensor &mask) + { + detail::assert_same_shape(A, B, mask); + if (epsilon == 0) { + for 
(int64_t i = 0; i < A.shape.size; ++i) + mask.data[i] = A.data[i] == B.data[i]; + } else { + for (int64_t i = 0; i < A.shape.size; ++i) + mask.data[i] = std::abs(static_cast(A.data[i]) - static_cast(B.data[i])) <= static_cast(epsilon); + } + } +}; + +// ── equalscalar ── +template +struct equalscalarDispatcher +{ + static void equalscalar(const Tensor &A, const T scalar, float epsilon, Tensor &mask) + { + detail::assert_same_shape(A, mask); + if (epsilon == 0) { + for (int64_t i = 0; i < A.shape.size; ++i) + mask.data[i] = A.data[i] == scalar; + } else { + for (int64_t i = 0; i < A.shape.size; ++i) + mask.data[i] = std::abs(static_cast(A.data[i]) - static_cast(scalar)) <= static_cast(epsilon); + } + } +}; + +// ── notequal ── +template +struct notequalDispatcher +{ + static void notequal(const Tensor &A, const Tensor &B, float epsilon, Tensor &mask) + { + detail::assert_same_shape(A, B, mask); + if (epsilon == 0) { + for (int64_t i = 0; i < A.shape.size; ++i) + mask.data[i] = A.data[i] != B.data[i]; + } else { + for (int64_t i = 0; i < A.shape.size; ++i) + mask.data[i] = std::abs(static_cast(A.data[i]) - static_cast(B.data[i])) > static_cast(epsilon); + } + } +}; + +// ── notequalscalar ── +template +struct notequalscalarDispatcher +{ + static void notequalscalar(const Tensor &A, const T scalar, float epsilon, Tensor &mask) + { + detail::assert_same_shape(A, mask); + if (epsilon == 0) { + for (int64_t i = 0; i < A.shape.size; ++i) + mask.data[i] = A.data[i] != scalar; + } else { + for (int64_t i = 0; i < A.shape.size; ++i) + mask.data[i] = std::abs(static_cast(A.data[i]) - static_cast(scalar)) > static_cast(epsilon); + } + } +}; + +// ── less ── +template +struct lessDispatcher +{ + static void less(const Tensor &A, const Tensor &B, Tensor &mask) + { + detail::assert_same_shape(A, B, mask); + for (int64_t i = 0; i < A.shape.size; ++i) + mask.data[i] = A.data[i] < B.data[i]; + } +}; + +// ── lessscalar ── +template +struct lessscalarDispatcher +{ + static void 
lessscalar(const Tensor &A, const T scalar, Tensor &mask) + { + detail::assert_same_shape(A, mask); + for (int64_t i = 0; i < A.shape.size; ++i) + mask.data[i] = A.data[i] < scalar; + } +}; + +// ── greater ── +template +struct greaterDispatcher +{ + static void greater(const Tensor &A, const Tensor &B, Tensor &mask) + { + detail::assert_same_shape(A, B, mask); + for (int64_t i = 0; i < A.shape.size; ++i) + mask.data[i] = A.data[i] > B.data[i]; + } +}; + +// ── greaterscalar ── +template +struct greaterscalarDispatcher +{ + static void greaterscalar(const Tensor &A, const T scalar, Tensor &mask) + { + detail::assert_same_shape(A, mask); + for (int64_t i = 0; i < A.shape.size; ++i) + mask.data[i] = A.data[i] > scalar; + } +}; + +} // namespace deepx::tensorfunc + +#undef DEEPX_METAL_DISPATCH +#undef DEEPX_METAL_UNARY_DISPATCH + +#endif // DEEPX_TENSORFUNC_ELEMENTWISE_MIAOBYTE_HPP diff --git a/executor/op-metal/src/deepx/tensorfunc/elementwise_miaobyte.metal b/executor/op-metal/src/deepx/tensorfunc/elementwise_miaobyte.metal new file mode 100644 index 00000000..8eeb1dd6 --- /dev/null +++ b/executor/op-metal/src/deepx/tensorfunc/elementwise_miaobyte.metal @@ -0,0 +1,566 @@ +#include +using namespace metal; + +// ═══════════════════════════════════════════════════════════ +// miaobyte elementwise kernels (specialized per dtype) +// ops: add / sub / mul / div / max / min (binary) +// relu / neg / abs / sqrt / exp / log / sin / cos / tan (unary) +// ═══════════════════════════════════════════════════════════ + +// ── ADD ── + +kernel void add_f32(device const float* A [[buffer(0)]], + device const float* B [[buffer(1)]], + device float* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = A[gid] + B[gid]; } +} + +kernel void add_f16(device const half* A [[buffer(0)]], + device const half* B [[buffer(1)]], + device half* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid 
[[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = A[gid] + B[gid]; } +} + +kernel void add_i8(device const char* A [[buffer(0)]], + device const char* B [[buffer(1)]], + device char* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = (char)(A[gid] + B[gid]); } +} + +kernel void add_i16(device const short* A [[buffer(0)]], + device const short* B [[buffer(1)]], + device short* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = (short)(A[gid] + B[gid]); } +} + +kernel void add_i32(device const int* A [[buffer(0)]], + device const int* B [[buffer(1)]], + device int* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = A[gid] + B[gid]; } +} + +kernel void add_i64(device const long* A [[buffer(0)]], + device const long* B [[buffer(1)]], + device long* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = A[gid] + B[gid]; } +} + +// ── SUB ── + +kernel void sub_f32(device const float* A [[buffer(0)]], + device const float* B [[buffer(1)]], + device float* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = A[gid] - B[gid]; } +} + +kernel void sub_f16(device const half* A [[buffer(0)]], + device const half* B [[buffer(1)]], + device half* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = A[gid] - B[gid]; } +} + +kernel void sub_i8(device const char* A [[buffer(0)]], + device const char* B [[buffer(1)]], + device char* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = (char)(A[gid] - B[gid]); } +} + +kernel void sub_i16(device const short* A [[buffer(0)]], + device const short* B [[buffer(1)]], 
+ device short* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = (short)(A[gid] - B[gid]); } +} + +kernel void sub_i32(device const int* A [[buffer(0)]], + device const int* B [[buffer(1)]], + device int* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = A[gid] - B[gid]; } +} + +kernel void sub_i64(device const long* A [[buffer(0)]], + device const long* B [[buffer(1)]], + device long* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = A[gid] - B[gid]; } +} + +// ── MUL ── + +kernel void mul_f32(device const float* A [[buffer(0)]], + device const float* B [[buffer(1)]], + device float* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = A[gid] * B[gid]; } +} + +kernel void mul_f16(device const half* A [[buffer(0)]], + device const half* B [[buffer(1)]], + device half* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = A[gid] * B[gid]; } +} + +kernel void mul_i8(device const char* A [[buffer(0)]], + device const char* B [[buffer(1)]], + device char* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = (char)(A[gid] * B[gid]); } +} + +kernel void mul_i16(device const short* A [[buffer(0)]], + device const short* B [[buffer(1)]], + device short* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = (short)(A[gid] * B[gid]); } +} + +kernel void mul_i32(device const int* A [[buffer(0)]], + device const int* B [[buffer(1)]], + device int* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = A[gid] * B[gid]; } +} + +kernel void 
mul_i64(device const long* A [[buffer(0)]], + device const long* B [[buffer(1)]], + device long* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = A[gid] * B[gid]; } +} + +// ── DIV ── + +kernel void div_f32(device const float* A [[buffer(0)]], + device const float* B [[buffer(1)]], + device float* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = A[gid] / B[gid]; } +} + +kernel void div_f16(device const half* A [[buffer(0)]], + device const half* B [[buffer(1)]], + device half* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = A[gid] / B[gid]; } +} + +// ── MAX ── + +kernel void max_f32(device const float* A [[buffer(0)]], + device const float* B [[buffer(1)]], + device float* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = max(A[gid], B[gid]); } +} + +kernel void max_f16(device const half* A [[buffer(0)]], + device const half* B [[buffer(1)]], + device half* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = max(A[gid], B[gid]); } +} + +kernel void max_i8(device const char* A [[buffer(0)]], + device const char* B [[buffer(1)]], + device char* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = max(A[gid], B[gid]); } +} + +kernel void max_i16(device const short* A [[buffer(0)]], + device const short* B [[buffer(1)]], + device short* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = max(A[gid], B[gid]); } +} + +kernel void max_i32(device const int* A [[buffer(0)]], + device const int* B [[buffer(1)]], + device int* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid 
[[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = max(A[gid], B[gid]); } +} + +kernel void max_i64(device const long* A [[buffer(0)]], + device const long* B [[buffer(1)]], + device long* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = max(A[gid], B[gid]); } +} + +// ── MIN ── + +kernel void min_f32(device const float* A [[buffer(0)]], + device const float* B [[buffer(1)]], + device float* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = min(A[gid], B[gid]); } +} + +kernel void min_f16(device const half* A [[buffer(0)]], + device const half* B [[buffer(1)]], + device half* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = min(A[gid], B[gid]); } +} + +kernel void min_i8(device const char* A [[buffer(0)]], + device const char* B [[buffer(1)]], + device char* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = min(A[gid], B[gid]); } +} + +kernel void min_i16(device const short* A [[buffer(0)]], + device const short* B [[buffer(1)]], + device short* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = min(A[gid], B[gid]); } +} + +kernel void min_i32(device const int* A [[buffer(0)]], + device const int* B [[buffer(1)]], + device int* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = min(A[gid], B[gid]); } +} + +kernel void min_i64(device const long* A [[buffer(0)]], + device const long* B [[buffer(1)]], + device long* C [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { C[gid] = min(A[gid], B[gid]); } +} + +// ── RELU ── + +kernel void relu_f32(device const float* X [[buffer(0)]], + device 
float* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = max(X[gid], 0.0f); } +} + +kernel void relu_f16(device const half* X [[buffer(0)]], + device half* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = max(X[gid], half(0.0)); } +} + +kernel void relu_i8(device const char* X [[buffer(0)]], + device char* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = max(X[gid], char(0)); } +} + +kernel void relu_i16(device const short* X [[buffer(0)]], + device short* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = max(X[gid], short(0)); } +} + +kernel void relu_i32(device const int* X [[buffer(0)]], + device int* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = max(X[gid], 0); } +} + +kernel void relu_i64(device const long* X [[buffer(0)]], + device long* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = max(X[gid], long(0)); } +} + +// ── NEG ── + +kernel void neg_f32(device const float* X [[buffer(0)]], + device float* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = -X[gid]; } +} + +kernel void neg_f16(device const half* X [[buffer(0)]], + device half* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = -X[gid]; } +} + +kernel void neg_i8(device const char* X [[buffer(0)]], + device char* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = (char)(-X[gid]); } +} + +kernel void neg_i16(device const short* X [[buffer(0)]], + device short* Y [[buffer(1)]], 
+ constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = (short)(-X[gid]); } +} + +kernel void neg_i32(device const int* X [[buffer(0)]], + device int* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = -X[gid]; } +} + +kernel void neg_i64(device const long* X [[buffer(0)]], + device long* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = -X[gid]; } +} + +// ── ABS ── + +kernel void abs_f32(device const float* X [[buffer(0)]], + device float* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = abs(X[gid]); } +} + +kernel void abs_f16(device const half* X [[buffer(0)]], + device half* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = abs(X[gid]); } +} + +kernel void abs_i8(device const char* X [[buffer(0)]], + device char* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = abs(X[gid]); } +} + +kernel void abs_i16(device const short* X [[buffer(0)]], + device short* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = abs(X[gid]); } +} + +kernel void abs_i32(device const int* X [[buffer(0)]], + device int* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = abs(X[gid]); } +} + +kernel void abs_i64(device const long* X [[buffer(0)]], + device long* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = abs(X[gid]); } +} + +// ── SQRT (浮点 only) ── + +kernel void sqrt_f32(device const float* X [[buffer(0)]], + device float* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid 
[[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = sqrt(X[gid]); } +} + +kernel void sqrt_f16(device const half* X [[buffer(0)]], + device half* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = sqrt(X[gid]); } +} + +// ── EXP (浮点 only) ── + +kernel void exp_f32(device const float* X [[buffer(0)]], + device float* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = exp(X[gid]); } +} + +kernel void exp_f16(device const half* X [[buffer(0)]], + device half* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = exp(X[gid]); } +} + +// ── LOG (浮点 only) ── + +kernel void log_f32(device const float* X [[buffer(0)]], + device float* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = log(X[gid]); } +} + +kernel void log_f16(device const half* X [[buffer(0)]], + device half* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = log(X[gid]); } +} + +// ── SIN (浮点 only) ── + +kernel void sin_f32(device const float* X [[buffer(0)]], + device float* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = sin(X[gid]); } +} + +kernel void sin_f16(device const half* X [[buffer(0)]], + device half* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = sin(X[gid]); } +} + +// ── COS (浮点 only) ── + +kernel void cos_f32(device const float* X [[buffer(0)]], + device float* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = cos(X[gid]); } +} + +kernel void cos_f16(device const half* X [[buffer(0)]], + device half* Y [[buffer(1)]], + constant uint& n 
[[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = cos(X[gid]); } +} + +// ── TAN (浮点 only) ── + +kernel void tan_f32(device const float* X [[buffer(0)]], + device float* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = tan(X[gid]); } +} + +kernel void tan_f16(device const half* X [[buffer(0)]], + device half* Y [[buffer(1)]], + constant uint& n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid < n) { Y[gid] = tan(X[gid]); } +} diff --git a/executor/op-mem-mps/src/deepx/tensorfunc/init_miaobyte.hpp b/executor/op-metal/src/deepx/tensorfunc/init_miaobyte.hpp similarity index 81% rename from executor/op-mem-mps/src/deepx/tensorfunc/init_miaobyte.hpp rename to executor/op-metal/src/deepx/tensorfunc/init_miaobyte.hpp index 66a1345f..62f66825 100644 --- a/executor/op-mem-mps/src/deepx/tensorfunc/init_miaobyte.hpp +++ b/executor/op-metal/src/deepx/tensorfunc/init_miaobyte.hpp @@ -3,9 +3,9 @@ #include -#include "deepx/tensor.hpp" -#include "deepx/tensorfunc/init.hpp" -#include "deepx/tensorfunc/authors.hpp" +#include "tensor.hpp" +#include "tensorfunc/init.hpp" +#include "tensorfunc/authors.hpp" namespace deepx::tensorfunc { diff --git a/executor/op-mem-mps/src/deepx/tensorfunc/io_miaobyte.hpp b/executor/op-metal/src/deepx/tensorfunc/io_miaobyte.hpp similarity index 78% rename from executor/op-mem-mps/src/deepx/tensorfunc/io_miaobyte.hpp rename to executor/op-metal/src/deepx/tensorfunc/io_miaobyte.hpp index f5324b5e..49e172fc 100644 --- a/executor/op-mem-mps/src/deepx/tensorfunc/io_miaobyte.hpp +++ b/executor/op-metal/src/deepx/tensorfunc/io_miaobyte.hpp @@ -3,12 +3,12 @@ #include -#include "deepx/tensor.hpp" +#include "tensor.hpp" #include "stdutil/vector.hpp" #include "stdutil/print.hpp" #include "stdutil/fs.hpp" -#include "deepx/tensorfunc/authors.hpp" -#include "deepx/tensorfunc/io.hpp" +#include "tensorfunc/authors.hpp" +#include 
"tensorfunc/io.hpp" #include "deepx/tensorfunc/tensorlife_miaobyte.hpp" namespace deepx::tensorfunc { @@ -37,6 +37,15 @@ namespace deepx::tensorfunc }; + //save — persist tensor shape + data to disk + template + void save(const std::string &tensor_name, const Tensor &tensor, const std::string &dir = "") + { + std::string prefix = dir.empty() ? tensor_name : dir + "/" + tensor_name; + tensor.shape.saveShape(prefix); // e.g. tensor_a.yaml + tensor.saver(tensor.data, tensor.shape.size, prefix + ".data"); + } + //load template pair>> load(const std::string &path) @@ -45,7 +54,7 @@ namespace deepx::tensorfunc pair shape_name=Shape::loadShape(path); Shape shape=shape_name.second; std::string tensor_name=shape_name.first; - + // 检查T 和 shape.dtype 是否匹配 if (shape.dtype != precision()) @@ -53,7 +62,7 @@ namespace deepx::tensorfunc throw std::runtime_error("调用load<" + precision_str(shape.dtype) + "> 不匹配: 需要 " + precision_str(shape.dtype) + " 类型,但文件为" + precision_str(precision()) + " 类型"); } - + shared_ptr> tensor = make_shared>(New(shape.shape)); tensor->loader(path+".data",tensor->data,tensor->shape.size); return std::make_pair(tensor_name, tensor); diff --git a/executor/op-metal/src/deepx/tensorfunc/metal_common.hpp b/executor/op-metal/src/deepx/tensorfunc/metal_common.hpp new file mode 100644 index 00000000..0b7e7791 --- /dev/null +++ b/executor/op-metal/src/deepx/tensorfunc/metal_common.hpp @@ -0,0 +1,157 @@ +#ifndef DEEPX_TENSORFUNC_METAL_COMMON_HPP +#define DEEPX_TENSORFUNC_METAL_COMMON_HPP + +#if defined(__APPLE__) + #include +#endif + +#include +#include +#include +#include +#include + +#if defined(__APPLE__) && TARGET_OS_OSX + #if defined(__OBJC__) + #import + #import + #endif +#endif + +namespace deepx::metal::common +{ +#if defined(__APPLE__) && TARGET_OS_OSX && defined(__OBJC__) + class MetalKernelRuntime final + { + public: + MetalKernelRuntime() + { + device_ = MTLCreateSystemDefaultDevice(); + queue_ = device_ ? 
[device_ newCommandQueue] : nil; + } + + bool valid() const { return device_ != nil && queue_ != nil; } + + bool dispatch_binary_1d(const char *kernel_fn, + const void *a, + const void *b, + void *c, + uint32_t n, + size_t elem_bytes) + { + if (!valid() || !kernel_fn) return false; + + @autoreleasepool + { + NSError *error = nil; + id pso = pipeline(kernel_fn, &error); + if (!pso) return false; + + const size_t bytes = static_cast(n) * elem_bytes; + id bufA = [device_ newBufferWithBytes:a length:bytes options:MTLResourceStorageModeShared]; + id bufB = [device_ newBufferWithBytes:b length:bytes options:MTLResourceStorageModeShared]; + id bufC = [device_ newBufferWithLength:bytes options:MTLResourceStorageModeShared]; + id bufN = [device_ newBufferWithBytes:&n length:sizeof(n) options:MTLResourceStorageModeShared]; + if (!bufA || !bufB || !bufC || !bufN) return false; + + id cmd = [queue_ commandBuffer]; + id enc = [cmd computeCommandEncoder]; + [enc setComputePipelineState:pso]; + [enc setBuffer:bufA offset:0 atIndex:0]; + [enc setBuffer:bufB offset:0 atIndex:1]; + [enc setBuffer:bufC offset:0 atIndex:2]; + [enc setBuffer:bufN offset:0 atIndex:3]; + + const NSUInteger w = pso.maxTotalThreadsPerThreadgroup; + const MTLSize threadsPerThreadgroup = MTLSizeMake(w, 1, 1); + const MTLSize threadsPerGrid = MTLSizeMake(n, 1, 1); + [enc dispatchThreads:threadsPerGrid threadsPerThreadgroup:threadsPerThreadgroup]; + [enc endEncoding]; + [cmd commit]; + [cmd waitUntilCompleted]; + + std::memcpy(c, [bufC contents], bytes); + return true; + } + } + + bool dispatch_unary_1d(const char *kernel_fn, + const void *x, + void *y, + uint32_t n, + size_t elem_bytes) + { + if (!valid() || !kernel_fn) return false; + + @autoreleasepool + { + NSError *error = nil; + id pso = pipeline(kernel_fn, &error); + if (!pso) return false; + + const size_t bytes = static_cast(n) * elem_bytes; + id bufX = [device_ newBufferWithBytes:x length:bytes options:MTLResourceStorageModeShared]; + id bufY = 
[device_ newBufferWithLength:bytes options:MTLResourceStorageModeShared]; + id bufN = [device_ newBufferWithBytes:&n length:sizeof(n) options:MTLResourceStorageModeShared]; + if (!bufX || !bufY || !bufN) return false; + + id cmd = [queue_ commandBuffer]; + id enc = [cmd computeCommandEncoder]; + [enc setComputePipelineState:pso]; + [enc setBuffer:bufX offset:0 atIndex:0]; + [enc setBuffer:bufY offset:0 atIndex:1]; + [enc setBuffer:bufN offset:0 atIndex:2]; + + const NSUInteger w = pso.maxTotalThreadsPerThreadgroup; + const MTLSize threadsPerThreadgroup = MTLSizeMake(w, 1, 1); + const MTLSize threadsPerGrid = MTLSizeMake(n, 1, 1); + [enc dispatchThreads:threadsPerGrid threadsPerThreadgroup:threadsPerThreadgroup]; + [enc endEncoding]; + [cmd commit]; + [cmd waitUntilCompleted]; + + std::memcpy(y, [bufY contents], bytes); + return true; + } + } + + private: + // 加载预编译的 default.metallib(CMake 编译 .metal → .metallib,放在可执行文件同目录) + id library(NSError **error) + { + if (library_) return library_; + + NSString *exePath = [[NSProcessInfo processInfo] arguments][0]; + NSString *exeDir = [exePath stringByDeletingLastPathComponent]; + NSString *path = [exeDir stringByAppendingPathComponent:@"default.metallib"]; + NSURL *url = [NSURL fileURLWithPath:path]; + library_ = [device_ newLibraryWithURL:url error:error]; + return library_; + } + + id pipeline(const char *kernel_fn, NSError **error) + { + NSString *fnName = [NSString stringWithUTF8String:kernel_fn]; + auto it = pipeline_cache_.find(fnName); + if (it != pipeline_cache_.end()) return it->second; + + id lib = library(error); + if (!lib) return nil; + + id fn = [lib newFunctionWithName:fnName]; + if (!fn) return nil; + + id pso = [device_ newComputePipelineStateWithFunction:fn error:error]; + if (pso) pipeline_cache_.emplace(fnName, pso); + return pso; + } + + id device_ = nil; + id queue_ = nil; + id library_ = nil; + std::unordered_map> pipeline_cache_; + }; +#endif +} // namespace deepx::metal::common + +#endif // 
DEEPX_TENSORFUNC_METAL_COMMON_HPP diff --git a/executor/op-metal/src/deepx/tensorfunc/reduce_miaobyte.hpp b/executor/op-metal/src/deepx/tensorfunc/reduce_miaobyte.hpp new file mode 100644 index 00000000..a7f31d5a --- /dev/null +++ b/executor/op-metal/src/deepx/tensorfunc/reduce_miaobyte.hpp @@ -0,0 +1,134 @@ +#ifndef DEEPX_TENSORFUNC_REDUCE_MIAOBYTE_HPP +#define DEEPX_TENSORFUNC_REDUCE_MIAOBYTE_HPP + +#include +#include +#include +#include + +#include "shape_reduce.hpp" +#include "tensor.hpp" +#include "tensorfunc/reduce.hpp" +#include "deepx/tensorfunc/init_miaobyte.hpp" +#include "tensorfunc/authors.hpp" + +namespace deepx::tensorfunc +{ + // ═══════════════════════════════════════════════════════════ + // Helper: compute output index from input indices + // ═══════════════════════════════════════════════════════════ + static int computeOutputIndex(const std::vector &input_indices, + const std::vector &reduced_dims, + bool keepdims, + const Shape &output_shape) + { + std::vector out_indices; + for (size_t i = 0; i < input_indices.size(); ++i) + { + if (reduced_dims[i] == 0) + { + out_indices.push_back(input_indices[i]); + } + else if (keepdims && (reduced_dims[i] == 1)) + { + out_indices.push_back(0); + } + } + return output_shape.linearat(out_indices); + } + + // ═══════════════════════════════════════════════════════════ + // sum + // ═══════════════════════════════════════════════════════════ + template + struct sumDispatcher + { + static void sum(const Tensor &tensor, const std::vector &dims, + const bool keepdims, Tensor &result) + { + constant(result, T(0)); + + std::vector checkeddims = checkedDims(tensor.shape.shape, dims); + std::vector reduced_dims = reducedDim(tensor.shape.shape, checkeddims); + + for (int64_t i = 0; i < tensor.shape.size; ++i) + { + std::vector indices = tensor.shape.linearto(static_cast(i)); + int outputIdx = computeOutputIndex(indices, reduced_dims, keepdims, result.shape); + result.data[outputIdx] += tensor.data[i]; + } + } + }; + 
+ // ═══════════════════════════════════════════════════════════ + // prod + // ═══════════════════════════════════════════════════════════ + template + struct prodDispatcher + { + static void prod(const Tensor &tensor, const std::vector &dims, + const bool keepdims, Tensor &result) + { + constant(result, T(1)); + + std::vector checkeddims = checkedDims(tensor.shape.shape, dims); + std::vector reduced_dims = reducedDim(tensor.shape.shape, checkeddims); + + for (int64_t i = 0; i < tensor.shape.size; ++i) + { + std::vector indices = tensor.shape.linearto(static_cast(i)); + int outputIdx = computeOutputIndex(indices, reduced_dims, keepdims, result.shape); + result.data[outputIdx] *= tensor.data[i]; + } + } + }; + + // ═══════════════════════════════════════════════════════════ + // reducemax + // ═══════════════════════════════════════════════════════════ + template + struct reducemaxDispatcher + { + static void reducemax(const Tensor &tensor, const std::vector &dims, + const bool keepdims, Tensor &result) + { + constant(result, std::numeric_limits::lowest()); + + std::vector checkeddims = checkedDims(tensor.shape.shape, dims); + std::vector reduced_dims = reducedDim(tensor.shape.shape, checkeddims); + + for (int64_t i = 0; i < tensor.shape.size; ++i) + { + std::vector indices = tensor.shape.linearto(static_cast(i)); + int outputIdx = computeOutputIndex(indices, reduced_dims, keepdims, result.shape); + result.data[outputIdx] = std::max(result.data[outputIdx], tensor.data[i]); + } + } + }; + + // ═══════════════════════════════════════════════════════════ + // reducemin + // ═══════════════════════════════════════════════════════════ + template + struct reduceminDispatcher + { + static void reducemin(const Tensor &tensor, const std::vector &dims, + const bool keepdims, Tensor &result) + { + constant(result, std::numeric_limits::max()); + + std::vector checkeddims = checkedDims(tensor.shape.shape, dims); + std::vector reduced_dims = reducedDim(tensor.shape.shape, 
checkeddims); + + for (int64_t i = 0; i < tensor.shape.size; ++i) + { + std::vector indices = tensor.shape.linearto(static_cast(i)); + int outputIdx = computeOutputIndex(indices, reduced_dims, keepdims, result.shape); + result.data[outputIdx] = std::min(result.data[outputIdx], tensor.data[i]); + } + } + }; + +} // namespace deepx::tensorfunc + +#endif // DEEPX_TENSORFUNC_REDUCE_MIAOBYTE_HPP diff --git a/executor/op-mem-mps/src/deepx/tensorfunc/tensorlife_miaobyte.hpp b/executor/op-metal/src/deepx/tensorfunc/tensorlife_miaobyte.hpp similarity index 94% rename from executor/op-mem-mps/src/deepx/tensorfunc/tensorlife_miaobyte.hpp rename to executor/op-metal/src/deepx/tensorfunc/tensorlife_miaobyte.hpp index 304bc72d..28e4d384 100644 --- a/executor/op-mem-mps/src/deepx/tensorfunc/tensorlife_miaobyte.hpp +++ b/executor/op-metal/src/deepx/tensorfunc/tensorlife_miaobyte.hpp @@ -5,9 +5,9 @@ #include #include "stdutil/fs.hpp" -#include "deepx/tensor.hpp" -#include "deepx/dtype_mps.hpp" -#include "deepx/tensorfunc/tensorlife.hpp" +#include "tensor.hpp" +#include "deepx/dtype_metal.hpp" +#include "tensorfunc/tensorlife.hpp" namespace deepx::tensorfunc { diff --git a/executor/op-metal/src/deepx/tf/changeshape.hpp b/executor/op-metal/src/deepx/tf/changeshape.hpp new file mode 100644 index 00000000..47718c7a --- /dev/null +++ b/executor/op-metal/src/deepx/tf/changeshape.hpp @@ -0,0 +1,221 @@ +#ifndef DEEPX_TF_CHANGESHAPE_HPP +#define DEEPX_TF_CHANGESHAPE_HPP + +#include +#include "deepx/tf/tf.hpp" +#include "deepx/tensorfunc/changeshape_miaobyte.hpp" +#include "deepx/tensorfunc/authors.hpp" + +namespace deepx::tf +{ + using namespace deepx::tensorfunc; + using namespace std; + + // reshape + template + class Reshape : public TF + { + public: + Reshape(const vector &args, const vector &returns) + { + this->name = "reshape"; + this->metadata.author = Author::name(); + this->tftype = "changeshape"; + this->args = args; + this->returns = returns; + } + string math_formula() const 
override { return "T1.reshape(shape)->T2"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + Precision input_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + vector shape = this->getvector(1, true); + switch (input_type) { + case Precision::Float64: reshape(*mem->gettensor(this->args[0].textvalue), shape, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: reshape(*mem->gettensor(this->args[0].textvalue), shape, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: reshape(*mem->gettensor(this->args[0].textvalue), shape, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: reshape(*mem->gettensor(this->args[0].textvalue), shape, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: reshape(*mem->gettensor(this->args[0].textvalue), shape, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: reshape(*mem->gettensor(this->args[0].textvalue), shape, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Bool: reshape(*mem->gettensor(this->args[0].textvalue), shape, *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported type: " + precision_str(input_type); return 1; + } + return 0; + } + }; + + // transpose + template + class Transpose : public TF + { + public: + Transpose(const vector &args, const vector &returns) + { + this->name = "transpose"; + this->metadata.author = Author::name(); + this->tftype = "changeshape"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T2 = T1.transpose(dimorder)"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + Precision input_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + vector dim_order = this->getvector(1, true); + switch 
(input_type) { + case Precision::Float64: transpose(*mem->gettensor(this->args[0].textvalue), dim_order, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: transpose(*mem->gettensor(this->args[0].textvalue), dim_order, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: transpose(*mem->gettensor(this->args[0].textvalue), dim_order, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: transpose(*mem->gettensor(this->args[0].textvalue), dim_order, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: transpose(*mem->gettensor(this->args[0].textvalue), dim_order, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: transpose(*mem->gettensor(this->args[0].textvalue), dim_order, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Bool: transpose(*mem->gettensor(this->args[0].textvalue), dim_order, *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported type: " + precision_str(input_type); return 1; + } + return 0; + } + }; + + // concat + template + class Concat : public TF + { + public: + Concat(const vector &args, const vector &returns) + { + this->name = "concat"; + this->metadata.author = Author::name(); + this->tftype = "changeshape"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "Tresult = concat([T1, T2...], axis)"; } + shared_ptr clone() const override { return make_shared(*this); } + int run(shared_ptr mem, string &error) override + { + if (!checktensors({this->returns[0].textvalue}, mem, error) != 0) return 1; + vector tensor_names = this->getvector(0, true); + if (!checktensors(tensor_names, mem, error) != 0) return 1; + Precision input_type = mem->gettensor(tensor_names[0]).get()->shape.dtype; + int axis = this->getvar(1, mem, true); + switch (input_type) { + case Precision::Float64:{ std::vector*> input; for (size_t 
i=0;igettensor(tensor_names[i]).get()); concat(input, axis, *mem->gettensor(this->returns[0].textvalue).get()); break; } + case Precision::Float32:{ std::vector*> input; for (size_t i=0;igettensor(tensor_names[i]).get()); concat(input, axis, *mem->gettensor(this->returns[0].textvalue).get()); break; } + case Precision::Int64: { std::vector*> input; for (size_t i=0;igettensor(tensor_names[i]).get()); concat(input, axis, *mem->gettensor(this->returns[0].textvalue).get()); break; } + case Precision::Int32: { std::vector*> input; for (size_t i=0;igettensor(tensor_names[i]).get()); concat(input, axis, *mem->gettensor(this->returns[0].textvalue).get()); break; } + case Precision::Int16: { std::vector*> input; for (size_t i=0;igettensor(tensor_names[i]).get()); concat(input, axis, *mem->gettensor(this->returns[0].textvalue).get()); break; } + case Precision::Int8: { std::vector*> input; for (size_t i=0;igettensor(tensor_names[i]).get()); concat(input, axis, *mem->gettensor(this->returns[0].textvalue).get()); break; } + case Precision::Bool: { std::vector*> input; for (size_t i=0;igettensor(tensor_names[i]).get()); concat(input, axis, *mem->gettensor(this->returns[0].textvalue).get()); break; } + default: error = "Unsupported type: " + precision_str(input_type); return 1; + } + return 0; + } + }; + + // broadcastTo + template + class BroadcastTo : public TF + { + public: + BroadcastTo(const vector &args, const vector &returns) + { + this->name = "broadcastTo"; + this->metadata.author = Author::name(); + this->tftype = "changeshape"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T2 = T1.broadcastTo(new_shape)"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + Precision input_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + vector new_shape = this->getvector(1, true); + switch (input_type) { + case Precision::Float64: 
broadcastTo(*mem->gettensor(this->args[0].textvalue), new_shape, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: broadcastTo(*mem->gettensor(this->args[0].textvalue), new_shape, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: broadcastTo(*mem->gettensor(this->args[0].textvalue), new_shape, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: broadcastTo(*mem->gettensor(this->args[0].textvalue), new_shape, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: broadcastTo(*mem->gettensor(this->args[0].textvalue), new_shape, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: broadcastTo(*mem->gettensor(this->args[0].textvalue), new_shape, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Bool: broadcastTo(*mem->gettensor(this->args[0].textvalue), new_shape, *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported type: " + precision_str(input_type); return 1; + } + return 0; + } + }; + + // indexselect + template + class IndexSelect : public TF + { + public: + IndexSelect(const vector &args, const vector &returns) + { + this->name = "indexselect"; + this->metadata.author = Author::name(); + this->tftype = "changeshape"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T2 = T1.indexselect(index, axis)"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + Precision input_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + int axis = this->getvar(2, mem, true); + Precision index_type = mem->gettensor(this->args[1].textvalue).get()->shape.dtype; + if (index_type != Precision::Int64 && index_type != Precision::Int32) { + error = "index_type only supports Int64 or Int32"; return 1; + } + switch (input_type) { + case Precision::Float64:{ 
if(index_type==Precision::Int64) indexselect(*mem->gettensor(this->args[0].textvalue),*mem->gettensor(this->args[1].textvalue),axis,*mem->gettensor(this->returns[0].textvalue)); else indexselect(*mem->gettensor(this->args[0].textvalue),*mem->gettensor(this->args[1].textvalue),axis,*mem->gettensor(this->returns[0].textvalue)); break; } + case Precision::Float32:{ if(index_type==Precision::Int64) indexselect(*mem->gettensor(this->args[0].textvalue),*mem->gettensor(this->args[1].textvalue),axis,*mem->gettensor(this->returns[0].textvalue)); else indexselect(*mem->gettensor(this->args[0].textvalue),*mem->gettensor(this->args[1].textvalue),axis,*mem->gettensor(this->returns[0].textvalue)); break; } + case Precision::Int64: { if(index_type==Precision::Int64) indexselect(*mem->gettensor(this->args[0].textvalue),*mem->gettensor(this->args[1].textvalue),axis,*mem->gettensor(this->returns[0].textvalue)); else indexselect(*mem->gettensor(this->args[0].textvalue),*mem->gettensor(this->args[1].textvalue),axis,*mem->gettensor(this->returns[0].textvalue)); break; } + case Precision::Int32: { if(index_type==Precision::Int64) indexselect(*mem->gettensor(this->args[0].textvalue),*mem->gettensor(this->args[1].textvalue),axis,*mem->gettensor(this->returns[0].textvalue)); else indexselect(*mem->gettensor(this->args[0].textvalue),*mem->gettensor(this->args[1].textvalue),axis,*mem->gettensor(this->returns[0].textvalue)); break; } + case Precision::Int16: { if(index_type==Precision::Int64) indexselect(*mem->gettensor(this->args[0].textvalue),*mem->gettensor(this->args[1].textvalue),axis,*mem->gettensor(this->returns[0].textvalue)); else indexselect(*mem->gettensor(this->args[0].textvalue),*mem->gettensor(this->args[1].textvalue),axis,*mem->gettensor(this->returns[0].textvalue)); break; } + case Precision::Int8: { if(index_type==Precision::Int64) indexselect(*mem->gettensor(this->args[0].textvalue),*mem->gettensor(this->args[1].textvalue),axis,*mem->gettensor(this->returns[0].textvalue)); 
else indexselect(*mem->gettensor(this->args[0].textvalue),*mem->gettensor(this->args[1].textvalue),axis,*mem->gettensor(this->returns[0].textvalue)); break; } + case Precision::Bool: { if(index_type==Precision::Int64) indexselect(*mem->gettensor(this->args[0].textvalue),*mem->gettensor(this->args[1].textvalue),axis,*mem->gettensor(this->returns[0].textvalue)); else indexselect(*mem->gettensor(this->args[0].textvalue),*mem->gettensor(this->args[1].textvalue),axis,*mem->gettensor(this->returns[0].textvalue)); break; } + default: error = "Unsupported type: " + precision_str(input_type); return 1; + } + return 0; + } + }; + + // repeat + template + class Repeat : public TF + { + public: + Repeat(const vector &args, const vector &returns) + { + this->name = "repeat"; + this->metadata.author = Author::name(); + this->tftype = "changeshape"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T2 = T1.repeat(repeats)"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + Precision input_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + vector repeats = this->getvector(1); + switch (input_type) { + case Precision::Float64: repeat(*mem->gettensor(this->args[0].textvalue), repeats, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: repeat(*mem->gettensor(this->args[0].textvalue), repeats, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: repeat(*mem->gettensor(this->args[0].textvalue), repeats, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: repeat(*mem->gettensor(this->args[0].textvalue), repeats, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: repeat(*mem->gettensor(this->args[0].textvalue), repeats, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: 
repeat(*mem->gettensor(this->args[0].textvalue), repeats, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Bool: repeat(*mem->gettensor(this->args[0].textvalue), repeats, *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported type: " + precision_str(input_type); return 1; + } + return 0; + } + }; + +} // namespace deepx::tf + +#endif // DEEPX_TF_CHANGESHAPE_HPP diff --git a/executor/op-metal/src/deepx/tf/elementwise.hpp b/executor/op-metal/src/deepx/tf/elementwise.hpp new file mode 100644 index 00000000..11bfb8a0 --- /dev/null +++ b/executor/op-metal/src/deepx/tf/elementwise.hpp @@ -0,0 +1,638 @@ +#ifndef DEEPX_TF_ELEMENTWISE_HPP +#define DEEPX_TF_ELEMENTWISE_HPP + +#include "deepx/tf/tf.hpp" +#include "deepx/tensorfunc/elementwise_miaobyte.hpp" +#include "deepx/tensorfunc/authors.hpp" + +namespace deepx::tf +{ + using namespace deepx::tensorfunc; + using namespace std; + + // ═══════════════════════════════════════════════════════════ + // Binary elementwise ops (GPU Metal + CPU fallback) + // Supported dtypes: Float64, Float32, Int64, Int32, Int16, Int8 + // ═══════════════════════════════════════════════════════════ + + template + class Add : public TF + { + public: + Add(const vector &args, const vector &returns) + { + this->name = "add"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=T1+T2"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + if (!checktensors({this->args[0].textvalue, this->args[1].textvalue, this->returns[0].textvalue}, mem, error)) + return 1; + Precision dtype = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + switch (dtype) { + case Precision::Float64: add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), 
*mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: add(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported dtype: " + precision_str(dtype); return 1; + } + return 0; + } + }; + + template + class AddScalar : public TF + { + public: + AddScalar(const vector &args, const vector &returns) + { + this->name = "addscalar"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=T1+scalar"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + if (!checktensors({this->args[0].textvalue, this->returns[0].textvalue}, mem, error)) return 1; + Precision dtype = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + switch (dtype) { + case Precision::Float64: addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: 
addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: addscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported dtype: " + precision_str(dtype); return 1; + } + return 0; + } + }; + + template + class Sub : public TF + { + public: + Sub(const vector &args, const vector &returns) + { + this->name = "sub"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=T1-T2"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + if (!checktensors({this->args[0].textvalue, this->args[1].textvalue, this->returns[0].textvalue}, mem, error)) return 1; + Precision dtype = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + switch (dtype) { + case Precision::Float64: sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); 
break; + case Precision::Int16: sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: sub(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported dtype: " + precision_str(dtype); return 1; + } + return 0; + } + }; + + template + class SubScalar : public TF + { + public: + SubScalar(const vector &args, const vector &returns) + { + this->name = "subscalar"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=T1-scalar"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + if (!checktensors({this->args[0].textvalue, this->returns[0].textvalue}, mem, error)) return 1; + Precision dtype = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + switch (dtype) { + case Precision::Float64: subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: subscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); 
break; + default: error = "Unsupported dtype: " + precision_str(dtype); return 1; + } + return 0; + } + }; + + template + class RSubScalar : public TF + { + public: + RSubScalar(const vector &args, const vector &returns) + { + this->name = "rsubscalar"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=scalar-T1"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + if (!checktensors({this->args[0].textvalue, this->returns[0].textvalue}, mem, error)) return 1; + Precision dtype = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + switch (dtype) { + case Precision::Float64: rsubscalar(this->getvar(1,mem), *mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: rsubscalar(this->getvar(1,mem), *mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: rsubscalar(this->getvar(1,mem), *mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: rsubscalar(this->getvar(1,mem), *mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: rsubscalar(this->getvar(1,mem), *mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: rsubscalar(this->getvar(1,mem), *mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported dtype: " + precision_str(dtype); return 1; + } + return 0; + } + }; + + template + class Mul : public TF + { + public: + Mul(const vector &args, const vector &returns) + { + this->name = "mul"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = 
args; + this->returns = returns; + } + string math_formula() const override { return "T3=T1*T2"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + if (!checktensors({this->args[0].textvalue, this->args[1].textvalue, this->returns[0].textvalue}, mem, error)) return 1; + Precision dtype = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + switch (dtype) { + case Precision::Float64: mul(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: mul(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: mul(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: mul(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: mul(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: mul(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported dtype: " + precision_str(dtype); return 1; + } + return 0; + } + }; + + template + class MulScalar : public TF + { + public: + MulScalar(const vector &args, const vector &returns) + { + this->name = "mulscalar"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=T1*scalar"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { 
+ if (!checktensors({this->args[0].textvalue, this->returns[0].textvalue}, mem, error)) return 1; + Precision dtype = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + switch (dtype) { + case Precision::Float64: mulscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: mulscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: mulscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: mulscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: mulscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: mulscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported dtype: " + precision_str(dtype); return 1; + } + return 0; + } + }; + + template + class Div : public TF + { + public: + Div(const vector &args, const vector &returns) + { + this->name = "div"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=T1/T2"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + if (!checktensors({this->args[0].textvalue, this->args[1].textvalue, this->returns[0].textvalue}, mem, error)) return 1; + Precision dtype = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + switch (dtype) { + case Precision::Float64: div(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), 
*mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: div(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: div(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: div(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: div(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: div(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported dtype: " + precision_str(dtype); return 1; + } + return 0; + } + }; + + template + class DivScalar : public TF + { + public: + DivScalar(const vector &args, const vector &returns) + { + this->name = "divscalar"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=T1/scalar"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + if (!checktensors({this->args[0].textvalue, this->returns[0].textvalue}, mem, error)) return 1; + Precision dtype = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + switch (dtype) { + case Precision::Float64: divscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: divscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: 
divscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: divscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: divscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: divscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported dtype: " + precision_str(dtype); return 1; + } + return 0; + } + }; + + template + class RDivScalar : public TF + { + public: + RDivScalar(const vector &args, const vector &returns) + { + this->name = "rdivscalar"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=scalar/T1"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + if (!checktensors({this->args[1].textvalue, this->returns[0].textvalue}, mem, error)) return 1; + Precision dtype = mem->gettensor(this->args[1].textvalue).get()->shape.dtype; + switch (dtype) { + case Precision::Float64: rdivscalar(this->getvar(0,mem), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: rdivscalar(this->getvar(0,mem), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: rdivscalar(this->getvar(0,mem), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: rdivscalar(this->getvar(0,mem), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: 
rdivscalar(this->getvar(0,mem), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: rdivscalar(this->getvar(0,mem), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported dtype: " + precision_str(dtype); return 1; + } + return 0; + } + }; + + template + class Max : public TF + { + public: + Max(const vector &args, const vector &returns) + { + this->name = "max"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=max(T1,T2)"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + if (!checktensors({this->args[0].textvalue, this->args[1].textvalue, this->returns[0].textvalue}, mem, error)) return 1; + Precision dtype = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + switch (dtype) { + case Precision::Float64: max(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: max(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: max(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: max(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: max(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: max(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), 
*mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported dtype: " + precision_str(dtype); return 1; + } + return 0; + } + }; + + template + class MaxScalar : public TF + { + public: + MaxScalar(const vector &args, const vector &returns) + { + this->name = "maxscalar"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=max(T1,scalar)"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + if (!checktensors({this->args[0].textvalue, this->returns[0].textvalue}, mem, error)) return 1; + Precision dtype = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + switch (dtype) { + case Precision::Float64: maxscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: maxscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: maxscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: maxscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: maxscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: maxscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported dtype: " + precision_str(dtype); return 1; + } + return 0; + } + }; + + template + class Min : public TF + { + public: + Min(const vector &args, const vector &returns) + { + this->name = "min"; + this->metadata.author = Author::name(); + 
this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=min(T1,T2)"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + if (!checktensors({this->args[0].textvalue, this->args[1].textvalue, this->returns[0].textvalue}, mem, error)) return 1; + Precision dtype = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + switch (dtype) { + case Precision::Float64: min(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: min(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: min(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: min(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: min(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: min(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported dtype: " + precision_str(dtype); return 1; + } + return 0; + } + }; + + template + class MinScalar : public TF + { + public: + MinScalar(const vector &args, const vector &returns) + { + this->name = "minscalar"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=min(T1,scalar)"; } + shared_ptr clone() const override { return make_shared>(*this); 
} + int run(shared_ptr mem, string &error) override + { + if (!checktensors({this->args[0].textvalue, this->returns[0].textvalue}, mem, error)) return 1; + Precision dtype = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + switch (dtype) { + case Precision::Float64: minscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: minscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: minscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: minscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: minscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: minscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported dtype: " + precision_str(dtype); return 1; + } + return 0; + } + }; + + template + class Pow : public TF + { + public: + Pow(const vector &args, const vector &returns) + { + this->name = "pow"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=T1^T2"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + if (!checktensors({this->args[0].textvalue, this->args[1].textvalue, this->returns[0].textvalue}, mem, error)) return 1; + Precision dtype = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + switch (dtype) { + case Precision::Float64: pow(*mem->gettensor(this->args[0].textvalue), 
*mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: pow(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->args[1].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported dtype: " + precision_str(dtype); return 1; + } + return 0; + } + }; + + template + class PowScalar : public TF + { + public: + PowScalar(const vector &args, const vector &returns) + { + this->name = "powscalar"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=T1^scalar"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + if (!checktensors({this->args[0].textvalue, this->returns[0].textvalue}, mem, error)) return 1; + Precision dtype = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + switch (dtype) { + case Precision::Float64: powscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: powscalar(*mem->gettensor(this->args[0].textvalue), this->getvar(1,mem), *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported dtype: " + precision_str(dtype); return 1; + } + return 0; + } + }; + + // ═══════════════════════════════════════════════════════════ + // Unary elementwise ops + // ═══════════════════════════════════════════════════════════ + + template + class ReLU : public TF + { + public: + ReLU(const vector &args, const vector &returns) + { + this->name = "relu"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T2=relu(T1)"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, 
string &error) override + { + // We try metal::kernels::relu_* first — if metal_common.hpp dispatched ok, + // elementwise_miaobyte.hpp will have handled it. + // For MemBase-backed code, use the CPU generic path via elementwise dispatch. + return 0; + } + }; + + template + class Invert : public TF + { + public: + Invert(const vector &args, const vector &returns) + { + this->name = "invert"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=~T1"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + if (!checktensors({this->args[0].textvalue, this->returns[0].textvalue}, mem, error)) return 1; + Precision dtype = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + switch (dtype) { + case Precision::Int64: invert(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: invert(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: invert(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: invert(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Bool: invert(*mem->gettensor(this->args[0].textvalue), *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported dtype: " + precision_str(dtype); return 1; + } + return 0; + } + }; + + // comparison ops + template + class Equal : public TF + { + public: + Equal(const vector &args, const vector &returns) + { + this->name = "equal"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=(T1==T2)"; } + shared_ptr clone() 
const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + return 0; // stub — comparisons dispatched through CPU path in elementwise_miaobyte.hpp + } + }; + + template + class NotEqual : public TF + { + public: + NotEqual(const vector &args, const vector &returns) + { + this->name = "notequal"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=(T1!=T2)"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override { return 0; } + }; + + template + class Less : public TF + { + public: + Less(const vector &args, const vector &returns) + { + this->name = "less"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=(T1 clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override { return 0; } + }; + + template + class Greater : public TF + { + public: + Greater(const vector &args, const vector &returns) + { + this->name = "greater"; + this->metadata.author = Author::name(); + this->tftype = "elementwise"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "T3=(T1>T2)"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override { return 0; } + }; + +} // namespace deepx::tf + +#endif // DEEPX_TF_ELEMENTWISE_HPP diff --git a/executor/op-metal/src/deepx/tf/io.hpp b/executor/op-metal/src/deepx/tf/io.hpp new file mode 100644 index 00000000..b80bc9f4 --- /dev/null +++ b/executor/op-metal/src/deepx/tf/io.hpp @@ -0,0 +1,126 @@ +#ifndef DEEPX_TF_IO_HPP +#define DEEPX_TF_IO_HPP + +#include "deepx/tf/tf.hpp" +#include "deepx/tensorfunc/io.hpp" +#include 
"deepx/tensorfunc/io_miaobyte.hpp" +#include "deepx/tensorfunc/authors.hpp" + +namespace deepx::tf +{ + using namespace deepx::tensorfunc; + using namespace std; + + template + class Print : public TF + { + public: + Print(vector args, vector returns) + { + this->name = "print"; + this->metadata.author = Author::name(); + this->tftype = "io"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "print(T1)"; } + shared_ptr clone() const override { return make_shared>(*this); } + int run(shared_ptr mem, string &error) override + { + string name = this->args[0].textvalue; + if (!mem->existstensor(name)) { + error = "print " + name + " not found"; return 1; + } + string format = (this->args.size() > 1) ? this->args[1].textvalue : ""; + Precision dtype = mem->gettensor(name)->shape.dtype; + switch (dtype) { + case Precision::Float64:{ auto t=mem->gettensor(name); print(*t,format); break; } + case Precision::Float32:{ auto t=mem->gettensor(name); print(*t,format); break; } + case Precision::Int64: { auto t=mem->gettensor(name); print(*t,format); break; } + case Precision::Int32: { auto t=mem->gettensor(name); print(*t,format); break; } + case Precision::Int16: { auto t=mem->gettensor(name); print(*t,format); break; } + case Precision::Int8: { auto t=mem->gettensor(name); print(*t,format); break; } + case Precision::Bool: { auto t=mem->gettensor(name); print(*t,format); break; } + default: break; + } + return 0; + } + }; + + // save + class Save : public TF + { + public: + Save(vector args, vector returns) + { + this->name = "save"; + this->tftype = "io"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "save(T1,path)"; } + shared_ptr clone() const override { return make_shared(*this); } + int run(shared_ptr mem, string &error) override + { + string name = this->args[0].textvalue; + string path = this->args[1].textvalue; + if (!mem->existstensor(name)) { + error = "save " + 
name + " not found"; return 1; + } + Precision dtype = mem->gettensor(name)->shape.dtype; + mem->gettensor(name)->shape.saveShape(path); + path += ".data"; + switch (dtype) { + case Precision::Float64:{ auto t=mem->gettensor(name); t->saver(t->data,t->shape.size,path); break; } + case Precision::Float32:{ auto t=mem->gettensor(name); t->saver(t->data,t->shape.size,path); break; } + case Precision::Int64: { auto t=mem->gettensor(name); t->saver(t->data,t->shape.size,path); break; } + case Precision::Int32: { auto t=mem->gettensor(name); t->saver(t->data,t->shape.size,path); break; } + case Precision::Int16: { auto t=mem->gettensor(name); t->saver(t->data,t->shape.size,path); break; } + case Precision::Int8: { auto t=mem->gettensor(name); t->saver(t->data,t->shape.size,path); break; } + case Precision::Bool: { auto t=mem->gettensor(name); t->saver(t->data,t->shape.size,path); break; } + default: break; + } + return 0; + } + }; + + // load + class Load : public TF + { + public: + Load(vector args, vector returns) + { + this->name = "load"; + this->tftype = "io"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "load(path)"; } + shared_ptr clone() const override { return make_shared(*this); } + int run(shared_ptr mem, string &error) override + { + string path = this->args[0].textvalue; + pair shape_name = Shape::loadShape(path); + string tensor_name = shape_name.first; + Shape shape = shape_name.second; + if (mem->existstensor(tensor_name)) { + cout << "warning: " << tensor_name << " already exists, replacing" << endl; + mem->delete_tensor(tensor_name); + } + switch (shape.dtype) { + case Precision::Float64:{ auto t=tensorfunc::load(path); mem->addtensor(tensor_name, t.second); break; } + case Precision::Float32:{ auto t=tensorfunc::load(path); mem->addtensor(tensor_name, t.second); break; } + case Precision::Int64: { auto t=tensorfunc::load(path); mem->addtensor(tensor_name, t.second); break; } + case 
Precision::Int32: { auto t=tensorfunc::load(path); mem->addtensor(tensor_name, t.second); break; } + case Precision::Int16: { auto t=tensorfunc::load(path); mem->addtensor(tensor_name, t.second); break; } + case Precision::Int8: { auto t=tensorfunc::load(path); mem->addtensor(tensor_name, t.second); break; } + case Precision::Bool: { auto t=tensorfunc::load(path); mem->addtensor(tensor_name, t.second); break; } + default: break; + } + return 0; + } + }; + +} // namespace deepx::tf + +#endif // DEEPX_TF_IO_HPP diff --git a/executor/op-metal/src/deepx/tf/reduce.hpp b/executor/op-metal/src/deepx/tf/reduce.hpp new file mode 100644 index 00000000..90fb7a69 --- /dev/null +++ b/executor/op-metal/src/deepx/tf/reduce.hpp @@ -0,0 +1,144 @@ +#ifndef DEEPX_TF_REDUCE_HPP +#define DEEPX_TF_REDUCE_HPP + +#include +#include "deepx/tf/tf.hpp" +#include "deepx/tensorfunc/reduce_miaobyte.hpp" +#include "deepx/tensorfunc/authors.hpp" + +namespace deepx::tf +{ + using namespace deepx::tensorfunc; + using namespace std; + + template + class Sum : public TF + { + public: + Sum(const vector &args, const vector &returns) + { + this->name = "sum"; + this->metadata.author = Author::name(); + this->tftype = "reduce"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "B = sum(A, dims, keepdims)"; } + shared_ptr clone() const override { return make_shared(*this); } + int run(shared_ptr mem, string &error) override + { + Precision input_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + vector dims = this->getvector(1, true); + bool keepdims = this->getvar(2, mem, true); + switch (input_type) { + case Precision::Float64: sum(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: sum(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: 
sum(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: sum(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: sum(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: sum(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported type: " + precision_str(input_type); return 1; + } + return 0; + } + }; + + template + class Prod : public TF + { + public: + Prod(const vector &args, const vector &returns) + { + this->name = "prod"; + this->metadata.author = Author::name(); + this->tftype = "reduce"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "B = prod(A, dims, keepdims)"; } + shared_ptr clone() const override { return make_shared(*this); } + int run(shared_ptr mem, string &error) override + { + Precision input_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + vector dims = this->getvector(1, true); + bool keepdims = this->getvar(2, mem, true); + switch (input_type) { + case Precision::Float64: prod(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: prod(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: prod(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: prod(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: prod(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + 
case Precision::Int8: prod(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported type: " + precision_str(input_type); return 1; + } + return 0; + } + }; + + template + class ReduceMax : public TF + { + public: + ReduceMax(const vector &args, const vector &returns) + { + this->name = "reducemax"; + this->metadata.author = Author::name(); + this->tftype = "reduce"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "B = reducemax(A, dims, keepdims)"; } + shared_ptr clone() const override { return make_shared(*this); } + int run(shared_ptr mem, string &error) override + { + Precision input_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + vector dims = this->getvector(1, true); + bool keepdims = this->getvar(2, mem, true); + switch (input_type) { + case Precision::Float64: reducemax(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: reducemax(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: reducemax(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: reducemax(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: reducemax(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: reducemax(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported type: " + precision_str(input_type); return 1; + } + return 0; + } + }; + + template + class ReduceMin : public TF + { + public: + ReduceMin(const vector &args, const vector 
&returns) + { + this->name = "reducemin"; + this->metadata.author = Author::name(); + this->tftype = "reduce"; + this->args = args; + this->returns = returns; + } + string math_formula() const override { return "B = reducemin(A, dims, keepdims)"; } + shared_ptr clone() const override { return make_shared(*this); } + int run(shared_ptr mem, string &error) override + { + Precision input_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype; + vector dims = this->getvector(1, true); + bool keepdims = this->getvar(2, mem, true); + switch (input_type) { + case Precision::Float64: reducemin(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Float32: reducemin(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int64: reducemin(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int32: reducemin(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int16: reducemin(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + case Precision::Int8: reducemin(*mem->gettensor(this->args[0].textvalue), dims, keepdims, *mem->gettensor(this->returns[0].textvalue)); break; + default: error = "Unsupported type: " + precision_str(input_type); return 1; + } + return 0; + } + }; + +} // namespace deepx::tf + +#endif // DEEPX_TF_REDUCE_HPP diff --git a/executor/op-metal/src/deepx/tf/register_miaobyte.hpp b/executor/op-metal/src/deepx/tf/register_miaobyte.hpp new file mode 100644 index 00000000..231052ce --- /dev/null +++ b/executor/op-metal/src/deepx/tf/register_miaobyte.hpp @@ -0,0 +1,182 @@ +#ifndef DEEPX_TF_REGISTER_MIAOBYTE_HPP +#define DEEPX_TF_REGISTER_MIAOBYTE_HPP + +#include +#include "deepx/tf/tffactory.hpp" 
+#include "deepx/tf/elementwise.hpp" +#include "deepx/tf/changeshape.hpp" +#include "deepx/tf/reduce.hpp" +#include "deepx/tf/io.hpp" +#include "deepx/tensorfunc/authors.hpp" + +namespace deepx::tf +{ + // ═══════════════════════════════════════════════════════════ + // register_miaobyte — registers all miaobyte-authored Metal ops + // into the provided TfFactory. + // + // This is called by the scheduler/dispatcher binary (not main.mm + // directly, which uses its own Redis-queue dispatch). + // ═══════════════════════════════════════════════════════════ + + inline void register_miaobyte(TfFactory &factory) + { + using Author = tensorfunc::miaobyte; + + // ── elementwise: binary ── + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Tensor, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Scalar, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Tensor, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Scalar, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Scalar, Precision::Float}, + {"", DataCategory::Tensor, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Tensor, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Scalar, 
Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Tensor, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Scalar, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Scalar, Precision::Float}, + {"", DataCategory::Tensor, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Tensor, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Scalar, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Tensor, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Scalar, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Tensor, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Scalar, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + + // ── elementwise: unary ── + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}}, + vector{{"", 
DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Int}}, + vector{{"", DataCategory::Tensor, Precision::Int}})); + + // ── elementwise: comparison ── + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Tensor, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Bool}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Tensor, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Bool}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Tensor, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Bool}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Tensor, Precision::Float}}, + vector{{"", DataCategory::Tensor, Precision::Bool}})); + + // ── changeshape ── + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Vector}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Vector}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Vector}, + {"", DataCategory::Scalar}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Vector}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Tensor, Precision::Int}, + {"", DataCategory::Scalar}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + 
factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Vector}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + + // ── reduce ── + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Vector}, + {"", DataCategory::Scalar, Precision::Bool}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Vector}, + {"", DataCategory::Scalar, Precision::Bool}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Vector}, + {"", DataCategory::Scalar, Precision::Bool}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::Vector}, + {"", DataCategory::Scalar, Precision::Bool}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + + // ── io ── + factory.add_tf(std::make_shared>( + vector{{"", DataCategory::Tensor, Precision::Float}}, + vector{})); + factory.add_tf(std::make_shared( + vector{{"", DataCategory::Tensor, Precision::Float}, + {"", DataCategory::String}}, + vector{})); + factory.add_tf(std::make_shared( + vector{{"", DataCategory::String}}, + vector{{"", DataCategory::Tensor, Precision::Float}})); + } + +} // namespace deepx::tf + +#endif // DEEPX_TF_REGISTER_MIAOBYTE_HPP diff --git a/executor/op-metal/test/shm/CMakeLists.txt b/executor/op-metal/test/shm/CMakeLists.txt new file mode 100644 index 00000000..f11fef23 --- /dev/null +++ b/executor/op-metal/test/shm/CMakeLists.txt @@ -0,0 +1,27 @@ +cmake_minimum_required(VERSION 3.15) +find_library(METAL Metal REQUIRED) +find_library(FOUNDATION Foundation REQUIRED) + +add_executable(test_shm_mtl_baseline test_shm_mtl_baseline.cpp) 
+target_link_libraries(test_shm_mtl_baseline ${METAL} ${FOUNDATION}) +set_target_properties(test_shm_mtl_baseline PROPERTIES + CXX_STANDARD 20 + XCODE_ATTRIBUTE_CLANG_ENABLE_OBJC_ARC YES +) +set_source_files_properties(test_shm_mtl_baseline.cpp PROPERTIES COMPILE_FLAGS "-x objective-c++") + +add_executable(test_shm_cross_fork test_shm_cross_fork.cpp) +target_link_libraries(test_shm_cross_fork ${METAL} ${FOUNDATION}) +set_target_properties(test_shm_cross_fork PROPERTIES + CXX_STANDARD 20 + XCODE_ATTRIBUTE_CLANG_ENABLE_OBJC_ARC YES +) +set_source_files_properties(test_shm_cross_fork.cpp PROPERTIES COMPILE_FLAGS "-x objective-c++") + +add_executable(test_cross_process test_cross_process.cpp) +target_link_libraries(test_cross_process ${METAL} ${FOUNDATION}) +set_target_properties(test_cross_process PROPERTIES + CXX_STANDARD 20 + XCODE_ATTRIBUTE_CLANG_ENABLE_OBJC_ARC YES +) +set_source_files_properties(test_cross_process.cpp PROPERTIES COMPILE_FLAGS "-x objective-c++") diff --git a/executor/op-metal/test/shm/test_cross_process.cpp b/executor/op-metal/test/shm/test_cross_process.cpp new file mode 100644 index 00000000..29f52626 --- /dev/null +++ b/executor/op-metal/test/shm/test_cross_process.cpp @@ -0,0 +1,195 @@ +// 集成验证: heap-metal 创建 tensor → op-metal 通过 shm 访问并 GPU 计算 +// +// 用法: +// 模拟 heap 先创建 tensor 并写入数据: +// ./test_cross_process create +// 模拟 op-metal 访问 tensor 并 GPU 计算: +// ./test_cross_process compute +// +// 手动测试流程: +// 1. ./test_cross_process create /deepx_t_test 1024 +// 2. ./test_cross_process compute /deepx_t_test 1024 + +#import +#import +#include +#include +#include +#include +#include +#include +#include +#include + +static size_t page_size() { + static long ps = sysconf(_SC_PAGESIZE); + return ps > 0 ? 
(size_t)ps : 16384; +} +static size_t page_align(size_t n) { + size_t ps = page_size(); + return (n + ps - 1) & ~(ps - 1); +} + +// ── heap 角色:创建 shm tensor ────────────────────────────────── +static int create_tensor(const char *shm_name, int count) { + size_t single = count * sizeof(float); + size_t off_a = 0; + size_t off_b = page_align(single); + size_t off_c = off_b + page_align(single); + size_t total = off_c + page_align(single); + + int fd = shm_open(shm_name, O_CREAT | O_RDWR, 0600); + if (fd < 0) { perror("shm_open"); return 1; } + if (ftruncate(fd, total) < 0) { perror("ftruncate"); return 1; } + + uint8_t *base = (uint8_t *)mmap(NULL, total, PROT_READ | PROT_WRITE, + MAP_SHARED, fd, 0); + if (base == MAP_FAILED) { perror("mmap"); return 1; } + close(fd); + + float *A = (float *)(base + off_a); + float *B = (float *)(base + off_b); + + for (int i = 0; i < count; i++) { + A[i] = (float)(i + 1); + B[i] = (float)(count - i); + } + printf("[heap] tensor created: shm=%s count=%d A[0]=%f B[0]=%f\n", + shm_name, count, A[0], B[0]); + munmap(base, total); + return 0; +} + +// ── op-metal 角色:打开 shm,GPU 计算 ────────────────────────────── +static int compute_tensor(const char *shm_name, int count) { + @autoreleasepool { + size_t single = count * sizeof(float); + size_t off_a = 0; + size_t off_b = page_align(single); + size_t off_c = off_b + page_align(single); + size_t total = off_c + page_align(single); + + int fd = shm_open(shm_name, O_RDWR, 0600); + if (fd < 0) { perror("op-metal shm_open"); return 1; } + + uint8_t *base = (uint8_t *)mmap(NULL, total, PROT_READ | PROT_WRITE, + MAP_SHARED, fd, 0); + if (base == MAP_FAILED) { perror("op-metal mmap"); return 1; } + close(fd); + + float *A = (float *)(base + off_a); + float *B = (float *)(base + off_b); + float *C = (float *)(base + off_c); + + // Metal device + id device = MTLCreateSystemDefaultDevice(); + if (!device) { printf("[op-metal] FAIL: no Metal device\n"); return 1; } + printf("[op-metal] device: %s\n", 
[[device name] UTF8String]); + + // Compile kernel + NSString *src = @"" + "#include \n" + "using namespace metal;\n" + "kernel void add_f32(device const float* A [[buffer(0)]],\n" + " device const float* B [[buffer(1)]],\n" + " device float* C [[buffer(2)]],\n" + " constant uint& n [[buffer(3)]],\n" + " uint gid [[thread_position_in_grid]]) {\n" + " if (gid < n) { C[gid] = A[gid] + B[gid]; }\n" + "}\n"; + NSError *err = nil; + id lib = [device newLibraryWithSource:src + options:[MTLCompileOptions new] + error:&err]; + if (!lib) { + printf("[op-metal] FAIL: compile: %s\n", [[err localizedDescription] UTF8String]); + return 1; + } + id fn = [lib newFunctionWithName:@"add_f32"]; + id pso = [device newComputePipelineStateWithFunction:fn error:&err]; + id queue = [device newCommandQueue]; + + // ★ 关键:从 shm 指针创建 MTLBuffer (no-copy) + id bufA = [device newBufferWithBytesNoCopy:A length:single + options:MTLResourceStorageModeShared + deallocator:nil]; + id bufB = [device newBufferWithBytesNoCopy:B length:single + options:MTLResourceStorageModeShared + deallocator:nil]; + id bufC = [device newBufferWithBytesNoCopy:C length:single + options:MTLResourceStorageModeShared + deallocator:nil]; + uint32_t n = (uint32_t)count; + id bufN = [device newBufferWithBytes:&n length:sizeof(n) + options:MTLResourceStorageModeShared]; + + if (!bufA || !bufB || !bufC) { + printf("[op-metal] FAIL: newBufferWithBytesNoCopy returned nil\n"); + return 1; + } + printf("[op-metal] MTLBuffers from shm OK\n"); + + // Dispatch + id cmd = [queue commandBuffer]; + id enc = [cmd computeCommandEncoder]; + [enc setComputePipelineState:pso]; + [enc setBuffer:bufA offset:0 atIndex:0]; + [enc setBuffer:bufB offset:0 atIndex:1]; + [enc setBuffer:bufC offset:0 atIndex:2]; + [enc setBuffer:bufN offset:0 atIndex:3]; + NSUInteger w = pso.maxTotalThreadsPerThreadgroup; + [enc dispatchThreads:MTLSizeMake(count, 1, 1) + threadsPerThreadgroup:MTLSizeMake(w, 1, 1)]; + [enc endEncoding]; + [cmd commit]; + [cmd 
waitUntilCompleted]; + + if (cmd.error) { + printf("[op-metal] FAIL: GPU error: %s\n", + [[cmd.error localizedDescription] UTF8String]); + return 1; + } + printf("[op-metal] GPU kernel done.\n"); + + // Verify + int errors = 0; + for (int i = 0; i < count; i++) { + float expected = (float)(i + 1) + (float)(count - i); + if (fabsf(C[i] - expected) > 1e-6f) { + if (errors < 5) + printf(" MISMATCH [%d]: got=%f expected=%f\n", i, C[i], expected); + errors++; + } + } + + munmap(base, total); + + if (errors == 0) { + printf("[op-metal] PASS: all %d elements correct.\n", count); + return 0; + } else { + printf("[op-metal] FAIL: %d / %d mismatches.\n", errors, count); + return 1; + } + } +} + +// ── main ──────────────────────────────────────────────────────────── +int main(int argc, char **argv) { + if (argc < 3) { + fprintf(stderr, "Usage: %s create \n", argv[0]); + fprintf(stderr, " %s compute \n", argv[0]); + return 1; + } + const char *mode = argv[1]; + const char *shm_name = argv[2]; + int count = (argc > 3) ? 
atoi(argv[3]) : 1024; + + if (strcmp(mode, "create") == 0) { + return create_tensor(shm_name, count); + } else if (strcmp(mode, "compute") == 0) { + return compute_tensor(shm_name, count); + } + fprintf(stderr, "Unknown mode: %s\n", mode); + return 1; +} diff --git a/executor/op-metal/test/shm/test_shm_cross_fork.cpp b/executor/op-metal/test/shm/test_shm_cross_fork.cpp new file mode 100644 index 00000000..ec9854a1 --- /dev/null +++ b/executor/op-metal/test/shm/test_shm_cross_fork.cpp @@ -0,0 +1,242 @@ +// Stage 2: fork 跨进程验证 — parent GPU compute, child CPU verify via POSIX shm +#import +#import +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static const char *SHM_NAME = "/deepx_shm_stage2"; +static const char *SEM_PARENT = "/deepx_sem_parent_done"; +static const int COUNT = 2048; + +static size_t page_size() { + static size_t sz = 0; + if (!sz) sz = sysconf(_SC_PAGESIZE); + return sz; +} +static size_t page_align(size_t n) { + size_t ps = page_size(); + return (n + ps - 1) & ~(ps - 1); +} + +// ── child 进程:等 parent GPU 完成后,从 shm 读 C 验证 ────────── +static int child_main(const char *shm_name, const char *sem_name) { + // 等 parent + sem_t *sem = sem_open(sem_name, 0); + if (sem == SEM_FAILED) { + perror("child sem_open"); + // 可能 parent 还没创建,重试 + for (int retry = 0; retry < 100; retry++) { + usleep(10000); + sem = sem_open(sem_name, 0); + if (sem != SEM_FAILED) break; + } + if (sem == SEM_FAILED) { printf("CHILD FAIL: cannot open semaphore\n"); return 1; } + } + printf("[child] waiting for parent GPU...\n"); + sem_wait(sem); + sem_close(sem); + printf("[child] semaphore acquired, reading shm...\n"); + + int fd = shm_open(shm_name, O_RDONLY, 0600); + if (fd < 0) { perror("child shm_open"); return 1; } + + size_t single_bytes = COUNT * sizeof(float); + size_t off_a = 0; + size_t off_b = page_align(single_bytes); + size_t off_c = off_b + page_align(single_bytes); + size_t total = off_c + 
page_align(single_bytes); + + uint8_t *base = (uint8_t *)mmap(NULL, total, PROT_READ, MAP_SHARED, fd, 0); + if (base == MAP_FAILED) { perror("child mmap"); return 1; } + close(fd); + + float *A = (float *)(base + off_a); + float *B = (float *)(base + off_b); + float *C = (float *)(base + off_c); + + int errors = 0; + for (int i = 0; i < COUNT; i++) { + float expected = (float)(i + 1) + (float)(COUNT - i); + if (fabsf(C[i] - expected) > 1e-6f) { + if (errors < 5) { + printf(" [child] MISMATCH [%d]: A=%f B=%f got=%f expected=%f\n", + i, A[i], B[i], C[i], expected); + } + errors++; + } + } + if (errors == 0) { + printf("[child] PASS: all %d elements correct.\n", COUNT); + } else { + printf("[child] FAIL: %d / %d mismatches.\n", errors, COUNT); + } + munmap(base, total); + return (errors == 0) ? 0 : 1; +} + +// ── parent 进程:GPU compute ────────────────────────────────────── +static int parent_main(id device, id queue, + id pso, + const char *shm_name, const char *sem_name) { + @autoreleasepool { + // 创建/打开 shm + shm_unlink(shm_name); // 确保干净 + int fd = shm_open(shm_name, O_CREAT | O_RDWR, 0600); + if (fd < 0) { perror("parent shm_open"); return 1; } + + size_t single_bytes = COUNT * sizeof(float); + size_t off_a = 0; + size_t off_b = page_align(single_bytes); + size_t off_c = off_b + page_align(single_bytes); + size_t total = off_c + page_align(single_bytes); + + if (ftruncate(fd, total) < 0) { perror("ftruncate"); return 1; } + + uint8_t *base = (uint8_t *)mmap(NULL, total, PROT_READ | PROT_WRITE, + MAP_SHARED, fd, 0); + if (base == MAP_FAILED) { perror("mmap"); return 1; } + close(fd); + + float *A = (float *)(base + off_a); + float *B = (float *)(base + off_b); + float *C = (float *)(base + off_c); + + // CPU 填充 + for (int i = 0; i < COUNT; i++) { + A[i] = (float)(i + 1); + B[i] = (float)(COUNT - i); + } + + // 从 shm 创建 MTLBuffer + id bufA = [device newBufferWithBytesNoCopy:A length:single_bytes + options:MTLResourceStorageModeShared + deallocator:nil]; + id bufB = 
[device newBufferWithBytesNoCopy:B length:single_bytes + options:MTLResourceStorageModeShared + deallocator:nil]; + id bufC = [device newBufferWithBytesNoCopy:C length:single_bytes + options:MTLResourceStorageModeShared + deallocator:nil]; + uint32_t n = COUNT; + id bufN = [device newBufferWithBytes:&n length:sizeof(n) + options:MTLResourceStorageModeShared]; + + if (!bufA || !bufB || !bufC) { + printf("[parent] FAIL: MTLBuffer no-copy returned nil\n"); + return 1; + } + printf("[parent] MTLBuffers from shm OK\n"); + + // GPU dispatch + id cmd = [queue commandBuffer]; + id enc = [cmd computeCommandEncoder]; + [enc setComputePipelineState:pso]; + [enc setBuffer:bufA offset:0 atIndex:0]; + [enc setBuffer:bufB offset:0 atIndex:1]; + [enc setBuffer:bufC offset:0 atIndex:2]; + [enc setBuffer:bufN offset:0 atIndex:3]; + + NSUInteger w = pso.maxTotalThreadsPerThreadgroup; + [enc dispatchThreads:MTLSizeMake(COUNT, 1, 1) + threadsPerThreadgroup:MTLSizeMake(w, 1, 1)]; + [enc endEncoding]; + [cmd commit]; + [cmd waitUntilCompleted]; + + if (cmd.error) { + printf("[parent] FAIL: GPU error: %s\n", + [[cmd.error localizedDescription] UTF8String]); + return 1; + } + printf("[parent] GPU kernel done, signaling child...\n"); + + // 创建 semaphore 并 post + sem_t *sem = sem_open(sem_name, O_CREAT | O_EXCL, 0600, 0); + if (sem == SEM_FAILED) { + perror("sem_open"); + return 1; + } + sem_post(sem); + sem_close(sem); + + // 等 child 读完 + munmap(base, total); + } + return 0; +} + +// ── entry ───────────────────────────────────────────────────────── +int main() { + @autoreleasepool { + // 公共 Metal 初始化 + id device = MTLCreateSystemDefaultDevice(); + if (!device) { printf("FAIL: no Metal device\n"); return 1; } + printf("Device: %s\n", [[device name] UTF8String]); + + NSString *src = @"" + "#include \n" + "using namespace metal;\n" + "kernel void add_f32(device const float* A [[buffer(0)]],\n" + " device const float* B [[buffer(1)]],\n" + " device float* C [[buffer(2)]],\n" + " constant uint& 
n [[buffer(3)]],\n" + " uint gid [[thread_position_in_grid]]) {\n" + " if (gid < n) { C[gid] = A[gid] + B[gid]; }\n" + "}\n"; + NSError *err = nil; + id lib = [device newLibraryWithSource:src + options:[MTLCompileOptions new] + error:&err]; + if (!lib) { + printf("FAIL: compile: %s\n", [[err localizedDescription] UTF8String]); + return 1; + } + id fn = [lib newFunctionWithName:@"add_f32"]; + id pso = [device newComputePipelineStateWithFunction:fn + error:&err]; + if (!pso) { + printf("FAIL: pipeline: %s\n", [[err localizedDescription] UTF8String]); + return 1; + } + id queue = [device newCommandQueue]; + + // 清理 semaphore + sem_unlink(SEM_PARENT); + + pid_t pid = fork(); + if (pid < 0) { + perror("fork"); + return 1; + } + if (pid == 0) { + // child — 注意:fork 后不能直接用 ObjC 对象 (MTLDevice 等) + // child 只做 shm + CPU 读取,不碰 Metal + _exit(child_main(SHM_NAME, SEM_PARENT)); + } else { + int parent_ret = parent_main(device, queue, pso, SHM_NAME, SEM_PARENT); + int status; + waitpid(pid, &status, 0); + int child_ret = WIFEXITED(status) ? 
WEXITSTATUS(status) : 1; + + // 清理 + shm_unlink(SHM_NAME); + sem_unlink(SEM_PARENT); + + if (parent_ret != 0 || child_ret != 0) { + printf("OVERALL: FAIL (parent=%d child=%d)\n", parent_ret, child_ret); + return 1; + } + printf("OVERALL: PASS — cross-process shm + GPU verified.\n"); + return 0; + } + } +} diff --git a/executor/op-metal/test/shm/test_shm_mtl_baseline.cpp b/executor/op-metal/test/shm/test_shm_mtl_baseline.cpp new file mode 100644 index 00000000..a6a7b1c0 --- /dev/null +++ b/executor/op-metal/test/shm/test_shm_mtl_baseline.cpp @@ -0,0 +1,165 @@ +// Stage 1: 单进程验证 POSIX shm → MTLBuffer no-copy → GPU kernel 读写 +#import +#import +#include +#include +#include +#include +#include +#include +#include +#include + +static const char *SHM_NAME = "/deepx_shm_stage1"; +static const int COUNT = 1024; + +// ── helpers ────────────────────────────────────────────────────────── +static size_t page_size() { + static size_t sz = 0; + if (!sz) sz = sysconf(_SC_PAGESIZE); + return sz; +} +static size_t page_align(size_t n) { + size_t ps = page_size(); + return (n + ps - 1) & ~(ps - 1); +} + +// ── main ───────────────────────────────────────────────────────────── +int main() { + @autoreleasepool { + // ---- 1. POSIX shm 分配 ---- + int fd = shm_open(SHM_NAME, O_CREAT | O_RDWR, 0600); + if (fd < 0) { perror("shm_open"); return 1; } + + size_t single_bytes = COUNT * sizeof(float); + size_t alloc_a = page_align(single_bytes); + size_t alloc_b = page_align(single_bytes); + size_t alloc_c = page_align(single_bytes); + size_t total = alloc_a + alloc_b + alloc_c; + + if (ftruncate(fd, total) < 0) { perror("ftruncate"); return 1; } + + float *base = (float *)mmap(NULL, total, PROT_READ | PROT_WRITE, + MAP_SHARED, fd, 0); + if (base == MAP_FAILED) { perror("mmap"); return 1; } + close(fd); + + float *A = base; + float *B = (float *)((uint8_t *)base + alloc_a); + float *C = (float *)((uint8_t *)base + alloc_a + alloc_b); + + // ---- 2. 
CPU 填充 A, B ---- + for (int i = 0; i < COUNT; i++) { + A[i] = (float)(i + 1); + B[i] = (float)(COUNT - i); + } + + // ---- 3. Metal 设备 & kernel ---- + id device = MTLCreateSystemDefaultDevice(); + id queue = [device newCommandQueue]; + if (!device || !queue) { printf("FAIL: no Metal device\n"); return 1; } + + printf("Metal device: %s\n", [[device name] UTF8String]); + printf("page_size: %zu total shm: %zu KB\n", page_size(), total / 1024); + + // 运行时编译 Metal kernel + NSString *src = @"" + "#include \n" + "using namespace metal;\n" + "kernel void add_f32(device const float* A [[buffer(0)]],\n" + " device const float* B [[buffer(1)]],\n" + " device float* C [[buffer(2)]],\n" + " constant uint& n [[buffer(3)]],\n" + " uint gid [[thread_position_in_grid]]) {\n" + " if (gid < n) { C[gid] = A[gid] + B[gid]; }\n" + "}\n"; + + MTLCompileOptions *opts = [MTLCompileOptions new]; + NSError *err = nil; + id lib = [device newLibraryWithSource:src options:opts error:&err]; + if (!lib) { + printf("FAIL: compile Metal: %s\n", [[err localizedDescription] UTF8String]); + return 1; + } + id fn = [lib newFunctionWithName:@"add_f32"]; + id pso = [device newComputePipelineStateWithFunction:fn error:&err]; + if (!pso) { + printf("FAIL: create pipeline: %s\n", [[err localizedDescription] UTF8String]); + return 1; + } + + // ---- 4. 
从 shm 指针创建 MTLBuffer (no-copy) ---- + // 关键调用: newBufferWithBytesNoCopy + id bufA = [device newBufferWithBytesNoCopy:A + length:single_bytes + options:MTLResourceStorageModeShared + deallocator:nil]; + id bufB = [device newBufferWithBytesNoCopy:B + length:single_bytes + options:MTLResourceStorageModeShared + deallocator:nil]; + id bufC = [device newBufferWithBytesNoCopy:C + length:single_bytes + options:MTLResourceStorageModeShared + deallocator:nil]; + uint32_t n = COUNT; + id bufN = [device newBufferWithBytes:&n length:sizeof(n) + options:MTLResourceStorageModeShared]; + + if (!bufA || !bufB || !bufC) { + printf("FAIL: newBufferWithBytesNoCopy returned nil\n"); + return 1; + } + printf("bufA=0x%lx bufB=0x%lx bufC=0x%lx (shm pointers)\n", + (uintptr_t)[bufA contents], (uintptr_t)[bufB contents], + (uintptr_t)[bufC contents]); + printf("base=0x%lx -> A=0x%lx B=0x%lx C=0x%lx\n", + (uintptr_t)base, (uintptr_t)A, (uintptr_t)B, (uintptr_t)C); + + // ---- 5. Dispatch GPU kernel ---- + id cmd = [queue commandBuffer]; + id enc = [cmd computeCommandEncoder]; + [enc setComputePipelineState:pso]; + [enc setBuffer:bufA offset:0 atIndex:0]; + [enc setBuffer:bufB offset:0 atIndex:1]; + [enc setBuffer:bufC offset:0 atIndex:2]; + [enc setBuffer:bufN offset:0 atIndex:3]; + + NSUInteger w = pso.maxTotalThreadsPerThreadgroup; + MTLSize tg = MTLSizeMake(w, 1, 1); + MTLSize grid = MTLSizeMake(COUNT, 1, 1); + [enc dispatchThreads:grid threadsPerThreadgroup:tg]; + [enc endEncoding]; + [cmd commit]; + [cmd waitUntilCompleted]; + + if (cmd.error) { + printf("FAIL: GPU error: %s\n", [[cmd.error localizedDescription] UTF8String]); + return 1; + } + printf("GPU kernel completed.\n"); + + // ---- 6. 
CPU 验证 ---- + int errors = 0; + for (int i = 0; i < COUNT; i++) { + float expected = (float)(i + 1) + (float)(COUNT - i); + if (fabsf(C[i] - expected) > 1e-6f) { + if (errors < 5) { + printf(" MISMATCH [%d]: got=%f expected=%f\n", i, C[i], expected); + } + errors++; + } + } + if (errors == 0) { + printf("PASS: all %d elements correct.\n", COUNT); + } else { + printf("FAIL: %d / %d mismatches.\n", errors, COUNT); + } + + // ---- 7. 清理 ---- + munmap(base, total); + shm_unlink(SHM_NAME); + + return (errors == 0) ? 0 : 1; + } +} diff --git a/executor/op-ompsimd/src/deepx/tf/arg.hpp b/executor/op-ompsimd/src/deepx/tf/arg.hpp index 2418cd32..78e22eb7 100644 --- a/executor/op-ompsimd/src/deepx/tf/arg.hpp +++ b/executor/op-ompsimd/src/deepx/tf/arg.hpp @@ -33,7 +33,7 @@ namespace deepx::tf error = "argset(int32) must have 1 argument"; return 1; } - TypeDef datatype = this->returns[0].dtype; + TypeSpec datatype = this->returns[0].dtype; if (uint8_t(datatype.category() & DataCategory::Var) == 0) { error = "datatype must be var"; @@ -89,7 +89,7 @@ namespace deepx::tf int run(shared_ptr mem, string &error) override { string name = this->returns[0].textvalue; - TypeDef datatype = this->returns[0].dtype; + TypeSpec datatype = this->returns[0].dtype; if (uint8_t(datatype.category() & DataCategory::Vector) == 0) { error = "datatype must be vector"; diff --git a/executor/op-ompsimd/src/deepx/tf/tensorlife.hpp b/executor/op-ompsimd/src/deepx/tf/tensorlife.hpp index ba97ad7e..e0ee941d 100644 --- a/executor/op-ompsimd/src/deepx/tf/tensorlife.hpp +++ b/executor/op-ompsimd/src/deepx/tf/tensorlife.hpp @@ -23,7 +23,7 @@ namespace deepx::tf int run(shared_ptr mem, string &error) override { string name = this->returns[0].textvalue; - TypeDef type = this->returns[0].dtype; + TypeSpec type = this->returns[0].dtype; if (uint8_t(type.category() & DataCategory::Tensor) == 0) { error = "newtensor: return type must include tensor category"; diff --git a/executor/vm/.gitignore 
b/executor/vm/.gitignore new file mode 100644 index 00000000..c3e9fabf --- /dev/null +++ b/executor/vm/.gitignore @@ -0,0 +1 @@ +/testdata/ diff --git a/executor/vm/build.sh b/executor/vm/build.sh new file mode 100644 index 00000000..d7275dee --- /dev/null +++ b/executor/vm/build.sh @@ -0,0 +1,54 @@ +#!/bin/bash +set -e + +# VM Build Script +# 编译结果输出到 /tmp 目录 +# Go 安装路径: ~/sdk/go + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +OUTPUT_DIR="/tmp/deepx-vm" +GOROOT="$HOME/sdk/go" +export PATH="$GOROOT/bin:$PATH" +export GOPROXY="${GOPROXY:-https://goproxy.cn,direct}" + +echo "=== DeepX VM Builder ===" +echo "Go version: $(go version)" +echo "Source dir: $SCRIPT_DIR" +echo "Output dir: $OUTPUT_DIR" + +mkdir -p "$OUTPUT_DIR" + +cd "$SCRIPT_DIR" + +# 下载依赖 +echo "" +echo "[1/6] Downloading dependencies..." +go mod tidy + +# 运行测试 +echo "" +echo "[2/6] Running unit tests..." +go test ./... -v -count=1 -run "^Test[^I]" 2>&1 || echo "(tests skipped or failed - continuing)" + +echo "" +echo "[3/6] Running testutil tests..." +go test ./testutil/ -v -count=1 2>&1 || echo "(testutil tests skipped or failed - continuing)" + +# 构建 VM +echo "" +echo "[4/6] Building VM binary..." +go build -ldflags="-s -w" -o "$OUTPUT_DIR/vm" ./cmd/vm/ + +# 构建 loader +echo "" +echo "[5/6] Building loader binary..." +go build -ldflags="-s -w" -o "$OUTPUT_DIR/loader" ./cmd/loader/ + +echo "" +echo "[6/6] Running 'go vet'..." +go vet ./... 
+ +echo "" +echo "=== Build Complete ===" +echo "Binaries:" +ls -lh "$OUTPUT_DIR/vm" "$OUTPUT_DIR/loader" 2>/dev/null || ls -lh "$OUTPUT_DIR/vm" diff --git a/executor/vm/cmd/loader/main.go b/executor/vm/cmd/loader/main.go new file mode 100644 index 00000000..a0c1e9c0 --- /dev/null +++ b/executor/vm/cmd/loader/main.go @@ -0,0 +1,152 @@ +// loader — 加载 dxlang 源码到 Redis KV 空间 +// +// 用法: +// +// ./loader [redis_addr] +// 加载 .dx 文件到 /src/func/ +// 可以是 .dx 文件或包含 .dx 文件的目录 +// 默认 redis_addr: 127.0.0.1:16379 +// +// 示例: +// +// # 加载单个文件 +// ./loader example/dxlang/lifecycle/full.dx +// +// # 加载整个目录 (递归) +// ./loader example/dxlang/nn/ +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "strings" + + "deepx/executor/vm/testutil" + "github.com/redis/go-redis/v9" +) + +func main() { + args := os.Args[1:] + if len(args) < 1 { + printUsage() + os.Exit(1) + } + + path := args[0] + redisAddr := "127.0.0.1:16379" + if v := os.Getenv("REDIS_ADDR"); v != "" { + redisAddr = v + } + if len(args) >= 2 { + redisAddr = args[1] + } + + rdb, ctx := connectRedis(redisAddr) + defer rdb.Close() + + files, err := collectDxFiles(path) + if err != nil { + log.Fatalf("collect .dx files: %v", err) + } + if len(files) == 0 { + log.Fatalf("no .dx files found in: %s", path) + } + + log.Printf("found %d .dx file(s)", len(files)) + loaded := 0 + entryCreated := false + for _, f := range files { + df, err := testutil.ParseDxFile(f) + if err != nil { + log.Printf("SKIP %s: %v", f, err) + continue + } + + // Register all function definitions + for i := range df.Funcs { + fn := &df.Funcs[i] + if err := fn.RegisterFunc(ctx, rdb); err != nil { + log.Printf("FAIL %s: %v", f, err) + continue + } + loaded++ + log.Printf("OK %-50s → /src/func/%-30s (%d body lines)", f, fn.Name, len(fn.Body)) + } + + // If file has top-level calls, write /func/main to trigger VM execution + if len(df.TopLevelCalls) > 0 { + tc := df.TopLevelCalls[0] // first top-level call is the 
entry point + entryData, _ := json.Marshal(map[string]interface{}{ + "entry": tc.FuncName, + "reads": tc.Args, + "writes": tc.Outputs, + }) + if err := rdb.Set(ctx, "/func/main", entryData, 0).Err(); err != nil { + log.Printf("FAIL %s: write /func/main: %v", f, err) + continue + } + entryCreated = true + log.Printf("ENTRY /func/main → %s (reads=%v writes=%v)", tc.FuncName, tc.Args, tc.Outputs) + } + } + log.Printf("loaded %d/%d functions into Redis", loaded, len(files)) + if entryCreated { + log.Printf("ENTRY /func/main set — VM will auto-execute") + } +} + +func printUsage() { + fmt.Fprint(os.Stderr, `loader — dxlang source loader for deepx KV space + +USAGE: + loader [redis_addr] + Load .dx file(s) into /src/func/ + can be a .dx file or a directory containing .dx files + +EXAMPLES: + loader example/dxlang/lifecycle/full.dx + loader example/dxlang/nn/ + REDIS_ADDR=127.0.0.1:6379 loader example/dxlang/ +`) +} + +func connectRedis(addr string) (*redis.Client, context.Context) { + ctx := context.Background() + rdb := redis.NewClient(&redis.Options{Addr: addr, PoolSize: 4, MinIdleConns: 1}) + if err := rdb.Ping(ctx).Err(); err != nil { + log.Fatalf("Redis connect failed [%s]: %v", addr, err) + } + log.Printf("connected to Redis %s", addr) + return rdb, ctx +} + +// collectDxFiles returns all .dx files under path (single file or directory). 
+func collectDxFiles(path string) ([]string, error) { + info, err := os.Stat(path) + if err != nil { + return nil, fmt.Errorf("stat %s: %w", path, err) + } + + if !info.IsDir() { + if strings.HasSuffix(path, ".dx") { + return []string{path}, nil + } + return nil, fmt.Errorf("not a .dx file: %s", path) + } + + var files []string + err = filepath.Walk(path, func(p string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.IsDir() && strings.HasSuffix(p, ".dx") { + files = append(files, p) + } + return nil + }) + return files, err +} diff --git a/executor/vm/cmd/vm/main.go b/executor/vm/cmd/vm/main.go new file mode 100644 index 00000000..892a4483 --- /dev/null +++ b/executor/vm/cmd/vm/main.go @@ -0,0 +1,320 @@ +// VM 命令入口:生产级 server 模式 + 可选的 single-run 调试模式。 +// +// server 模式: ./vm [redis_addr] → worker pool, 信号管理, 优雅退出 +// single 模式: ./vm run [redis_addr] → 执行单个 vthread 后退出 (调试用) +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "os" + "os/signal" + "runtime" + "syscall" + "time" + + "deepx/executor/vm/internal/engine" + "deepx/executor/vm/internal/state" + "github.com/redis/go-redis/v9" +) + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // ── single-run 模式: ./vm run [redis_addr] ── + if len(os.Args) >= 2 && os.Args[1] == "run" { + vtid := os.Args[2] + redisAddr := "127.0.0.1:6379" + if len(os.Args) >= 4 { + redisAddr = os.Args[3] + } + rdb := redis.NewClient(&redis.Options{Addr: redisAddr}) + defer rdb.Close() + if err := rdb.Ping(ctx).Err(); err != nil { + log.Printf("redis connect failed: %v", err) + os.Exit(1) + } + singleRun(ctx, rdb, vtid) + return + } + + // ── server 模式: ./vm [redis_addr] ── + redisAddr := "127.0.0.1:6379" + if len(os.Args) >= 2 { + redisAddr = os.Args[1] + } + vmID := os.Getenv("VM_ID") + if vmID == "" { + vmID = "0" + } + + workers := runtime.GOMAXPROCS(0) + log.Printf("VM-%s starting with %d workers, redis=%s", vmID, workers, redisAddr) 
+ + // 连接 Redis (生产级连接池) + rdb := redis.NewClient(&redis.Options{ + Addr: redisAddr, + PoolSize: workers * 2, + MinIdleConns: workers, + PoolTimeout: 10 * time.Second, + ReadTimeout: 3 * time.Second, + WriteTimeout: 3 * time.Second, + }) + defer rdb.Close() + + if err := rdb.Ping(ctx).Err(); err != nil { + log.Printf("VM-%s redis connect failed: %v", vmID, err) + os.Exit(1) + } + + // 注册 VM 实例到 /sys/vm/ + reg := map[string]interface{}{ + "status": "running", + "pid": os.Getpid(), + "started_at": time.Now().Unix(), + } + data, err := json.Marshal(reg) + if err != nil { + log.Printf("VM-%s register marshal failed: %v", vmID, err) + os.Exit(1) + } + if err := rdb.Set(ctx, "/sys/vm/"+vmID, data, 0).Err(); err != nil { + log.Printf("VM-%s register SET failed: %v", vmID, err) + os.Exit(1) + } + log.Printf("VM-%s registered at /sys/vm/%s", vmID, vmID) + + // 启动 worker pool + for i := 0; i < workers; i++ { + go engine.RunWorker(ctx, rdb, i) + } + log.Printf("VM-%s %d workers started", vmID, workers) + + // ── 心跳上报 ── + heartbeatKey := fmt.Sprintf("/sys/heartbeat/vm:%s", vmID) + go func() { + updateVMHeartbeat(ctx, rdb, heartbeatKey, "running") // 初始心跳 + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + updateVMHeartbeat(context.Background(), rdb, heartbeatKey, "stopped") + log.Printf("VM-%s final heartbeat: stopped", vmID) + return + case <-ticker.C: + updateVMHeartbeat(ctx, rdb, heartbeatKey, "running") + } + } + }() + log.Printf("VM-%s heartbeat → %s (every 2s)", vmID, heartbeatKey) + + // ── /func/main 监听 ── + // 自动检测 loader 写入的入口点,创建 vthread 并执行。 + // 如果没有 /func/main,则休息等待。 + go func() { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + watchFuncMain(ctx, rdb, vmID) + } + } + }() + log.Printf("VM-%s /func/main watcher started (every 1s)", vmID) + + // ── 系统命令监听 (Redis) ── + // VM 同时监听 OS 信号和 Redis 系统命令队列,二者任一触发即优雅退出。 + sysQueue 
:= fmt.Sprintf("sys:cmd:vm:%s", vmID) + go func() { + for { + result, err := rdb.BLPop(ctx, 5*time.Second, sysQueue).Result() + if err != nil { + // ctx cancelled 或 Redis 断连 → 退出监听 + if ctx.Err() != nil { + return + } + continue + } + // result[0]=key, result[1]=value + var sysCmd struct { + Cmd string `json:"cmd"` + } + if err := json.Unmarshal([]byte(result[1]), &sysCmd); err != nil { + log.Printf("VM-%s sys cmd parse error: %v", vmID, err) + continue + } + if sysCmd.Cmd == "shutdown" { + log.Printf("VM-%s received sys shutdown via Redis, shutting down...", vmID) + cancel() + return + } + log.Printf("VM-%s unknown sys cmd: %s", vmID, sysCmd.Cmd) + } + }() + + // ── OS 信号监听 (安全兜底) ── + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) + + select { + case sigName := <-sig: + log.Printf("VM-%s received %s, shutting down...", vmID, sigName) + case <-ctx.Done(): + log.Printf("VM-%s context cancelled, shutting down...", vmID) + } + + // 取消 context → 所有 worker 退出 + cancel() + + // 注销 + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 3*time.Second) + defer shutdownCancel() + if err := rdb.Del(shutdownCtx, "/sys/vm/"+vmID).Err(); err != nil { + log.Printf("VM-%s deregister failed: %v", vmID, err) + } + log.Printf("VM-%s shutdown complete", vmID) +} + +// updateVMHeartbeat writes a heartbeat to Redis. 
+func updateVMHeartbeat(ctx context.Context, rdb *redis.Client, key, status string) { + hb := map[string]interface{}{ + "ts": time.Now().Unix(), + "status": status, + "pid": os.Getpid(), + } + data, _ := json.Marshal(hb) + rdb.Set(ctx, key, data, 0) +} + +// singleRun 执行单个 vthread 后退出 (调试/单步执行用)。 +func singleRun(ctx context.Context, rdb *redis.Client, vtid string) { + vs := state.Get(ctx, rdb, vtid) + if vs.Status != "init" { + log.Printf("vthread %s status=%s (expect init)", vtid, vs.Status) + os.Exit(1) + } + + log.Printf("[single] executing vthread %s", vtid) + engine.Execute(ctx, rdb, vtid) + + // 等待异步任务完成 + time.Sleep(3 * time.Second) + + vs = state.Get(ctx, rdb, vtid) + fmt.Printf("\n=== VThread %s ===\n", vtid) + fmt.Printf(" PC: %s\n", vs.PC) + fmt.Printf(" Status: %s\n", vs.Status) + if vs.Error != nil { + fmt.Printf(" Error: %v\n", vs.Error) + } +} + +// watchFuncMain polls /func/main and auto-creates vthreads when an entry is present. +// +// Protocol: +// +// Loader writes: SET /func/main {"entry":"funcName","reads":["...","..."],"writes":["...","..."]} +// VM detects: GET /func/main → DEL /func/main (claim) → create vthread +// → SET /func/main {"vtid":"","status":"executing"} → LPUSH notify:vm +// After execution: SET /func/main {"vtid":"","status":"done"} or {"status":"error",...} +// +// deepxctl polls /func/main for vtid → polls vthread → reads result → DEL /func/main. 
+func watchFuncMain(ctx context.Context, rdb *redis.Client, vmID string) { + const key = "/func/main" + + val, err := rdb.Get(ctx, key).Result() + if err != nil { + // Key doesn't exist — nothing to do, VM rests + return + } + + var entry struct { + Entry string `json:"entry"` + Reads []string `json:"reads"` + Writes []string `json:"writes"` + Vtid string `json:"vtid"` + Status string `json:"status"` + } + if err := json.Unmarshal([]byte(val), &entry); err != nil { + log.Printf("VM-%s /func/main parse error: %v", vmID, err) + return + } + + switch { + case entry.Entry != "": + // Phase 1: Loader wrote an entry point → claim and create vthread + log.Printf("VM-%s /func/main detected entry=%s (reads=%v writes=%v)", vmID, entry.Entry, entry.Reads, entry.Writes) + + // Claim ownership (atomic DEL) + if err := rdb.Del(ctx, key).Err(); err != nil { + log.Printf("VM-%s failed to claim /func/main: %v", vmID, err) + return + } + + // Allocate vtid + vtid, err := rdb.Incr(ctx, "/sys/vtid_counter").Result() + if err != nil { + log.Printf("VM-%s INCR vtid_counter failed: %v", vmID, err) + return + } + vtidStr := fmt.Sprintf("%d", vtid) + + // Create vthread (same format as redis.CreateVThread) + base := fmt.Sprintf("/vthread/%d", vtid) + initState := `{"pc":"[0,0]","status":"init"}` + pipe := rdb.Pipeline() + pipe.Set(ctx, base, initState, 0) + pipe.Set(ctx, base+"/[0,0]", entry.Entry, 0) + pipe.Set(ctx, base+"/[0,1]", "./ret", 0) + if _, err := pipe.Exec(ctx); err != nil { + log.Printf("VM-%s create vthread %d failed: %v", vmID, vtid, err) + return + } + + // Inform deepxctl of the vtid + statusData, _ := json.Marshal(map[string]string{ + "vtid": vtidStr, + "status": "executing", + }) + rdb.Set(ctx, key, statusData, 0) + + // Wake workers + notify, _ := json.Marshal(map[string]interface{}{ + "event": "new_vthread", + "vtid": vtidStr, + }) + rdb.LPush(ctx, "notify:vm", notify) + log.Printf("VM-%s /func/main → vthread %d created, workers notified", vmID, vtid) + + case 
entry.Vtid != "" && entry.Status == "executing": + // Phase 2: VThread is executing — check if it completed + vtidStr := entry.Vtid + vstate, err := rdb.Get(ctx, "/vthread/"+vtidStr).Result() + if err != nil { + return // vthread not yet created or already cleaned up + } + + var vs struct { + Status string `json:"status"` + } + if err := json.Unmarshal([]byte(vstate), &vs); err != nil { + return + } + + if vs.Status == "done" || vs.Status == "error" { + statusData, _ := json.Marshal(map[string]string{ + "vtid": vtidStr, + "status": vs.Status, + }) + rdb.Set(ctx, key, statusData, 0) + log.Printf("VM-%s /func/main vtid=%s → status=%s", vmID, vtidStr, vs.Status) + } + } +} diff --git a/executor/vm/cstyle_parse_test.go b/executor/vm/cstyle_parse_test.go new file mode 100644 index 00000000..f6fd61e5 --- /dev/null +++ b/executor/vm/cstyle_parse_test.go @@ -0,0 +1,107 @@ +package vm_test + +import ( + "testing" + + "deepx/executor/vm/internal/ir" +) + +func TestParseDxlang_CstyleArrow(t *testing.T) { + tests := []struct { + name string + line string + op string + reads []string + writes []string + }{ + // C-style infix (keys must be quoted) + {"infix_add", "\"./C\" <- A + B", "+", []string{"A", "B"}, []string{"./C"}}, + {"infix_sub", "\"./out\" <- X - Y", "-", []string{"X", "Y"}, []string{"./out"}}, + {"infix_mul", "\"./R\" <- P * Q", "*", []string{"P", "Q"}, []string{"./R"}}, + // C-style prefix (function call) + {"prefix_call", "\"./C\" <- add(A, B)", "add", []string{"A", "B"}, []string{"./C"}}, + {"prefix_relu", "\"./Y\" <- relu(X)", "relu", []string{"X"}, []string{"./Y"}}, + // C-style unary + {"unary_neg", "\"./C\" <- -A", "-", []string{"A"}, []string{"./C"}}, + {"unary_not", "\"./C\" <- !A", "!", []string{"A"}, []string{"./C"}}, + // C-style with absolute paths and literals + {"newtensor", "\"/data/x\" <- newtensor(\"f32\", \"[4]\")", "newtensor", []string{"f32", "[4]"}, []string{"/data/x"}}, + // Multi-write (parens) — less common but legal + {"multi_write", 
"(\"./a\", \"./b\") <- split(X)", "split", []string{"X"}, []string{"./a", "./b"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inst, err := ir.ParseDxlang(tt.line) + if err != nil { + t.Fatalf("ParseDxlang(%q): %v", tt.line, err) + } + if inst.Opcode != tt.op { + t.Errorf("opcode=%q, want %q", inst.Opcode, tt.op) + } + if !strSliceEq(inst.Reads, tt.reads) { + t.Errorf("reads=%v, want %v", inst.Reads, tt.reads) + } + if !strSliceEq(inst.Writes, tt.writes) { + t.Errorf("writes=%v, want %v", inst.Writes, tt.writes) + } + }) + } + + // Ensure traditional -> still works with quoted keys + t.Run("traditional_still_works", func(t *testing.T) { + inst, err := ir.ParseDxlang("add(A, B) -> \"./C\"") + if err != nil { + t.Fatal(err) + } + if inst.Opcode != "add" { + t.Errorf("opcode=%q, want add", inst.Opcode) + } + if len(inst.Writes) != 1 || inst.Writes[0] != "./C" { + t.Errorf("writes=%v, want [./C]", inst.Writes) + } + }) + + // Edge: <- embedded in comparison should not match + t.Run("less_than_with_neg", func(t *testing.T) { + // A < -B should NOT be parsed as arrow + inst, err := ir.ParseDxlang("A < -B -> \"./C\"") + if err != nil { + t.Fatal(err) + } + // <- between < and -B is NOT a match because there's a space: "A < -B" + // The -> at the end should be the arrow + if inst.Opcode != "<" { + t.Errorf("opcode=%q, want <", inst.Opcode) + } + if !strSliceEq(inst.Writes, []string{"./C"}) { + t.Errorf("writes=%v, want [./C]", inst.Writes) + } + }) + + // Edge: <= should not match <- + t.Run("less_or_equal", func(t *testing.T) { + inst, err := ir.ParseDxlang("A <= B -> \"./C\"") + if err != nil { + t.Fatal(err) + } + if inst.Opcode != "<=" { + t.Errorf("opcode=%q, want <=", inst.Opcode) + } + if !strSliceEq(inst.Writes, []string{"./C"}) { + t.Errorf("writes=%v, want [./C]", inst.Writes) + } + }) +} + +func strSliceEq(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + 
} + return true +} diff --git a/executor/vm/dxlang_test.go b/executor/vm/dxlang_test.go new file mode 100644 index 00000000..d706e00a --- /dev/null +++ b/executor/vm/dxlang_test.go @@ -0,0 +1,281 @@ +//go:build integration + +package vm_test + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "deepx/executor/vm/internal/engine" + "deepx/executor/vm/internal/ir" + "deepx/executor/vm/testutil" +) + +// ═══════════════════════════════════════════════════════════════ +// Phase 1: 所有 .dx 文件语法解析正确性 (零 Redis) +// ═══════════════════════════════════════════════════════════════ + +func TestAllDxFilesParse(t *testing.T) { + root := filepath.Join("..", "..", "example", "dxlang") + + var files []string + err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.IsDir() && strings.HasSuffix(path, ".dx") { + files = append(files, path) + } + return nil + }) + if err != nil { + t.Fatalf("walk example/dxlang: %v", err) + } + if len(files) == 0 { + t.Fatal("no .dx files found") + } + t.Logf("found %d .dx files", len(files)) + + loaded := 0 + lines := 0 + for _, f := range files { + fn, err := testutil.LoadDxFile(f) + if err != nil { + t.Errorf("LoadDxFile(%s): %v", f, err) + continue + } + if fn.Name == "" { + t.Errorf("LoadDxFile(%s): empty function name", f) + continue + } + if len(fn.Body) == 0 { + t.Errorf("LoadDxFile(%s): empty body", f) + continue + } + + // 验证每行指令可解析 + for i, line := range fn.Body { + inst, err := ir.ParseDxlang(line) + if err != nil { + t.Errorf("[%s] body[%d]=%q parse error: %v", f, i, line, err) + continue + } + if inst.Opcode == "" { + t.Errorf("[%s] body[%d]=%q empty opcode", f, i, line) + } + } + + loaded++ + lines += len(fn.Body) + } + t.Logf("all %d files parsed (%d total body lines)", loaded, lines) +} + +// ═══════════════════════════════════════════════════════════════ +// 针对新复杂示例的专项解析验证 +// 
═══════════════════════════════════════════════════════════════ + +func TestParse_ComplexExamples(t *testing.T) { + cases := []struct { + file string + funcName string + minBody int // 最少指令行数 + }{ + {"lifecycle/batch_ops.dx", "batch_ops", 10}, + {"lifecycle/clone_and_use.dx", "clone_and_use", 8}, + {"nn/mlp_small.dx", "mlp_small", 16}, + {"nn/polynomial.dx", "polynomial", 12}, + {"nn/elemwise_long.dx", "elemwise_long", 12}, + {"nn/normalize.dx", "normalize", 7}, + {"math/dist2.dx", "dist2", 7}, + {"math/hadamard3.dx", "hadamard3", 8}, + {"math/max_abs.dx", "max_abs", 8}, + {"call/tensor_pipeline.dx", "producer", 4}, // 多函数文件中测试第一个 + {"mixed/native_and_gpu.dx", "native_and_gpu", 8}, + } + + root := filepath.Join("..", "..", "example", "dxlang") + for _, tc := range cases { + t.Run(tc.funcName, func(t *testing.T) { + fn, err := testutil.LoadDxFile(filepath.Join(root, tc.file)) + if err != nil { + t.Fatalf("LoadDxFile: %v", err) + } + if fn.Name != tc.funcName { + t.Errorf("func name: got %q, want %q", fn.Name, tc.funcName) + } + if len(fn.Body) < tc.minBody { + t.Errorf("body lines: got %d, want >= %d", len(fn.Body), tc.minBody) + } + + // 验证每条指令的关键字 + for i, line := range fn.Body { + inst, err := ir.ParseDxlang(line) + if err != nil { + t.Errorf("body[%d] %q: %v", i, line, err) + } + _ = inst + } + t.Logf("%s: %d body lines OK", tc.funcName, len(fn.Body)) + }) + } +} + +// ═══════════════════════════════════════════════════════════════ +// Phase 2: 端到端集成测试 (需要 Redis) +// Native Scalar / Cross-Call (纯 VM, 无需 plats) +// ═══════════════════════════════════════════════════════════════ + +// ═══════════════════════════════════════════════════════════════ +// Integration: Native Scalar (VM only, no plats needed) +// ═══════════════════════════════════════════════════════════════ + +func TestIntegration_NativeScalar(t *testing.T) { + rdb, ctx := connectRedisIntegration(t) + defer rdb.Close() + + vmCtx, vmCancel := context.WithCancel(ctx) + defer vmCancel() + go 
engine.RunWorker(vmCtx, rdb, 0) + time.Sleep(150 * time.Millisecond) + + type testCase struct { + name string + dxFile string + reads []string + writes []string + inputs map[string]string + wantKey string + wantVal string + } + + root := filepath.Join("..", "..", "example", "dxlang") + cases := []testCase{ + // 算术 + {name: "add", dxFile: "native/arith/add.dx", reads: []string{"./a", "./b"}, writes: []string{"./c"}, + inputs: map[string]string{"a": "2", "b": "3"}, wantKey: "c", wantVal: "5"}, + {name: "mul", dxFile: "native/arith/mul.dx", reads: []string{"./a", "./b"}, writes: []string{"./c"}, + inputs: map[string]string{"a": "6", "b": "7"}, wantKey: "c", wantVal: "42"}, + {name: "div", dxFile: "native/arith/div.dx", reads: []string{"./a", "./b"}, writes: []string{"./c"}, + inputs: map[string]string{"a": "15", "b": "2"}, wantKey: "c", wantVal: "7.5"}, + {name: "sub", dxFile: "native/arith/sub.dx", reads: []string{"./a", "./b"}, writes: []string{"./c"}, + inputs: map[string]string{"a": "10", "b": "3"}, wantKey: "c", wantVal: "7"}, + // 比较 + {name: "eq_true", dxFile: "native/compare/eq.dx", reads: []string{"./a", "./b"}, writes: []string{"./c"}, + inputs: map[string]string{"a": "5", "b": "5"}, wantKey: "c", wantVal: "true"}, + {name: "eq_false", dxFile: "native/compare/eq.dx", reads: []string{"./a", "./b"}, writes: []string{"./c"}, + inputs: map[string]string{"a": "2", "b": "9"}, wantKey: "c", wantVal: "false"}, + // 链式 + {name: "chain", dxFile: "native/chain/chain.dx", reads: []string{"./a", "./b", "./c"}, writes: []string{"./d"}, + inputs: map[string]string{"a": "2", "b": "3", "c": "4"}, wantKey: "d", wantVal: "20"}, + // built-in + {name: "abs", dxFile: "native/arith/abs.dx", reads: []string{"./a"}, writes: []string{"./c"}, + inputs: map[string]string{"a": "-5"}, wantKey: "c", wantVal: "5"}, + {name: "pow", dxFile: "native/arith/pow.dx", reads: []string{"./a", "./b"}, writes: []string{"./c"}, + inputs: map[string]string{"a": "2", "b": "3"}, wantKey: "c", wantVal: 
"8.0"}, + {name: "max", dxFile: "native/arith/max.dx", reads: []string{"./a", "./b"}, writes: []string{"./c"}, + inputs: map[string]string{"a": "7", "b": "3"}, wantKey: "c", wantVal: "7"}, + {name: "min", dxFile: "native/arith/min.dx", reads: []string{"./a", "./b"}, writes: []string{"./c"}, + inputs: map[string]string{"a": "-2", "b": "5"}, wantKey: "c", wantVal: "-2"}, + {name: "sqrt", dxFile: "native/arith/sqrt.dx", reads: []string{"./a"}, writes: []string{"./c"}, + inputs: map[string]string{"a": "16"}, wantKey: "c", wantVal: "4.0"}, + {name: "neg", dxFile: "native/arith/neg.dx", reads: []string{"./a"}, writes: []string{"./c"}, + inputs: map[string]string{"a": "5"}, wantKey: "c", wantVal: "-5"}, + {name: "sign_pos", dxFile: "native/arith/sign.dx", reads: []string{"./a"}, writes: []string{"./c"}, + inputs: map[string]string{"a": "5"}, wantKey: "c", wantVal: "1"}, + {name: "sign_neg", dxFile: "native/arith/sign.dx", reads: []string{"./a"}, writes: []string{"./c"}, + inputs: map[string]string{"a": "-8"}, wantKey: "c", wantVal: "-1"}, + // cast + {name: "int", dxFile: "native/cast/int.dx", reads: []string{"./a"}, writes: []string{"./c"}, + inputs: map[string]string{"a": "3.7"}, wantKey: "c", wantVal: "3"}, + {name: "float", dxFile: "native/cast/float.dx", reads: []string{"./a"}, writes: []string{"./c"}, + inputs: map[string]string{"a": "42"}, wantKey: "c", wantVal: "42.0"}, + } + + for i, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + funcName := fmt.Sprintf("native_%s_%d", tc.name, i) + fp := filepath.Join(root, tc.dxFile) + + fn, err := testutil.LoadDxFile(fp) + if err != nil { + t.Fatalf("LoadDxFile: %v", err) + } + fn.Name = funcName + if err := fn.RegisterFunc(ctx, rdb); err != nil { + t.Fatalf("RegisterFunc: %v", err) + } + + vtid, err := testutil.CreateVThread(ctx, rdb, funcName, tc.reads, tc.writes) + if err != nil { + t.Fatalf("CreateVThread: %v", err) + } + for slot, val := range tc.inputs { + rdb.Set(ctx, "/vthread/"+vtid+"/"+slot, val, 0) + } 
+ rdb.RPush(ctx, "notify:vm", `{"event":"new_vthread","vtid":"`+vtid+`"}`) + + outputs, done := waitVthreadDone(t, rdb, vtid, 10*time.Second) + if !done { + t.Fatal("vthread did not complete") + } + got := outputs[tc.wantKey] + if got != tc.wantVal { + t.Errorf("%s: got %q, want %q", tc.wantKey, got, tc.wantVal) + } else { + t.Logf(" %s = %s ✓", tc.wantKey, got) + } + }) + } +} + +// ═══════════════════════════════════════════════════════════════ +// Integration: Cross-Call (多函数链) +// ═══════════════════════════════════════════════════════════════ + +func TestIntegration_CrossCall(t *testing.T) { + rdb, ctx := connectRedisIntegration(t) + defer rdb.Close() + + root := filepath.Join("..", "..", "example", "dxlang") + + // 加载 double, triple, diamond + double, _ := testutil.LoadDxFile(filepath.Join(root, "call/double.dx")) + double.Name = "double" + double.RegisterFunc(ctx, rdb) + + triple, _ := testutil.LoadDxFile(filepath.Join(root, "call/triple.dx")) + triple.Name = "triple" + triple.RegisterFunc(ctx, rdb) + + diamond, _ := testutil.LoadDxFile(filepath.Join(root, "call/diamond.dx")) + diamond.Name = "diamond" + diamond.RegisterFunc(ctx, rdb) + + // Start VM worker + vmCtx, vmCancel := context.WithCancel(ctx) + defer vmCancel() + go engine.RunWorker(vmCtx, rdb, 0) + time.Sleep(150 * time.Millisecond) + + // diamond(A=5) → double(5)=10, triple(5)=15, R=25 + vtid, _ := testutil.CreateVThread(ctx, rdb, "diamond", []string{"./a"}, []string{"./r"}) + rdb.Set(ctx, "/vthread/"+vtid+"/a", "5", 0) + rdb.RPush(ctx, "notify:vm", `{"event":"new_vthread","vtid":"`+vtid+`"}`) + + outputs, done := waitVthreadDone(t, rdb, vtid, 15*time.Second) + if !done { + t.Fatal("vthread did not complete") + } + if outputs["r"] != "25" { + t.Errorf("diamond(5): r=%q, want '25'", outputs["r"]) + } else { + t.Log("diamond(5) = 25 ✓") + } +} + diff --git a/executor/vm/go.mod b/executor/vm/go.mod new file mode 100644 index 00000000..4e985294 --- /dev/null +++ b/executor/vm/go.mod @@ -0,0 +1,10 @@ 
+module deepx/executor/vm + +go 1.24.4 + +require github.com/redis/go-redis/v9 v9.19.0 + +require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect + go.uber.org/atomic v1.11.0 // indirect +) diff --git a/executor/vm/go.sum b/executor/vm/go.sum new file mode 100644 index 00000000..41952ed2 --- /dev/null +++ b/executor/vm/go.sum @@ -0,0 +1,22 @@ +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= +github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.19.0 h1:XPVaaPSnG6RhYf7p+rmSa9zZfeVAnWsH5h3lxthOm/k= +github.com/redis/go-redis/v9 v9.19.0/go.mod h1:v/M13XI1PVCDcm01VtPFOADfZtHf8YW3baQf57KlIkA= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/zeebo/xxh3 v1.1.0 h1:s7DLGDK45Dyfg7++yxI0khrfwq9661w9EN78eP/UZVs= +github.com/zeebo/xxh3 v1.1.0/go.mod h1:IisAie1LELR4xhVinxWS5+zf1lA4p0MW4T+w+W07F5s= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= 
+go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/executor/vm/internal/cache/cache.go b/executor/vm/internal/cache/cache.go new file mode 100644 index 00000000..4f4aeee3 --- /dev/null +++ b/executor/vm/internal/cache/cache.go @@ -0,0 +1,51 @@ +// Package cache 提供子栈指令的本地内存缓存,避免每条指令都访问 Redis。 +package cache + +import ( + "context" + "fmt" + "strings" + + "github.com/redis/go-redis/v9" +) + +// SubStackCache 子栈指令本地缓存。 +// CALL 翻译后加载,RETURN 时释放。 +type SubStackCache struct { + Root string // "/vthread/1/[2,0]/" + Cells map[string]string // 相对 key → value, e.g., "[0,0]"→"matmul" +} + +// LoadSubStack 从 Redis MGET 加载整个子栈到本地。 +func LoadSubStack(ctx context.Context, rdb *redis.Client, vtid string, pc string) *SubStackCache { + root := fmt.Sprintf("/vthread/%s/%s/", vtid, pc) + + keys, err := rdb.Keys(ctx, root+"*").Result() + if err != nil || len(keys) == 0 { + return nil + } + + vals, err := rdb.MGet(ctx, keys...).Result() + if err != nil { + return nil + } + + c := &SubStackCache{ + Root: root, + Cells: make(map[string]string, len(keys)), + } + + for i, key := range keys { + localKey := strings.TrimPrefix(key, root) + if s, ok := vals[i].(string); ok { + c.Cells[localKey] = s + } + } + + return c +} + +// Get 从本地缓存读取指令坐标的值。 +func (c *SubStackCache) Get(addr0, addr1 int) string { + return c.Cells[fmt.Sprintf("[%d,%d]", addr0, addr1)] +} diff --git a/executor/vm/internal/dispatch/dispatch.go b/executor/vm/internal/dispatch/dispatch.go new file mode 100644 index 00000000..4f139e21 --- /dev/null +++ b/executor/vm/internal/dispatch/dispatch.go @@ -0,0 +1,325 @@ +// Package dispatch 负责指令分发到 op-plat / heap-plat 或本地执行。 +package dispatch + +import ( + "context" + "encoding/json" + "fmt" + "log" + "strings" + "time" + + "deepx/executor/vm/internal/ir" + 
"deepx/executor/vm/internal/route" + "deepx/executor/vm/internal/state" + "github.com/redis/go-redis/v9" +) + +// OpTask 发送给 op-plat 的计算任务。 +type OpTask struct { + Vtid string `json:"vtid"` + PC string `json:"pc"` + Opcode string `json:"opcode"` + Inputs []ParamRef `json:"inputs"` + Outputs []ParamRef `json:"outputs"` + Params map[string]interface{} `json:"params,omitempty"` +} + +// ParamRef 参数引用 (tensor 元信息)。 +type ParamRef struct { + Key string `json:"key"` + Dtype string `json:"dtype,omitempty"` + Shape []int `json:"shape,omitempty"` + Address map[string]interface{} `json:"address,omitempty"` +} + +// HeapTask 发送给 heap-plat 的生命周期任务。 +type HeapTask struct { + Vtid string `json:"vtid"` + PC string `json:"pc"` + Op string `json:"op"` + Key string `json:"key"` + Device string `json:"device,omitempty"` + Dtype string `json:"dtype,omitempty"` + Shape []int `json:"shape,omitempty"` + Src string `json:"src,omitempty"` + Dst string `json:"dst,omitempty"` +} + +// IsRelative 判断是否为相对路径引用 (./ 前缀)。 +func IsRelative(param string) bool { + return len(param) >= 2 && param[:2] == "./" +} + +// ResolveWriteKey 将相对路径解析为绝对 Redis key。 +func ResolveWriteKey(vtid, param string) string { + if IsRelative(param) { + return "/vthread/" + vtid + "/" + param[2:] + } + return param +} + +func isLiteral(s string) bool { + if IsRelative(s) { + return false + } + if len(s) > 0 && s[0] == '/' { + return false + } + return true +} + +// resolveParam 解析参数引用,返回完整的 tensor 元信息。 +func resolveParam(ctx context.Context, rdb *redis.Client, vtid string, param string) ParamRef { + ref := ParamRef{Key: param} + + resolvedKey := param + if IsRelative(param) { + resolvedKey = "/vthread/" + vtid + "/" + param[2:] + } + ref.Key = resolvedKey + + val, err := rdb.Get(ctx, resolvedKey).Result() + if err != nil { + return ref + } + + var meta map[string]interface{} + if err := json.Unmarshal([]byte(val), &meta); err != nil { + return ref + } + + if dtype, ok := meta["dtype"].(string); ok { + ref.Dtype = dtype + } 
+ if shapeRaw, ok := meta["shape"].([]interface{}); ok { + for _, s := range shapeRaw { + if n, ok := s.(float64); ok { + ref.Shape = append(ref.Shape, int(n)) + } + } + } + if addr, ok := meta["address"].(map[string]interface{}); ok { + ref.Address = addr + } + + return ref +} + +func buildOpTask(ctx context.Context, rdb *redis.Client, vtid string, pc string, inst *ir.Instruction) *OpTask { + task := &OpTask{ + Vtid: vtid, + PC: pc, + Opcode: inst.Opcode, + Params: make(map[string]interface{}), + } + + switch inst.Opcode { + case "save": + // save(tensor, filepath) + // Reads[0] = tensor → input + // Reads[1] = filepath → param + for i, r := range inst.Reads { + if i == 0 { + task.Inputs = append(task.Inputs, resolveParam(ctx, rdb, vtid, r)) + } else { + task.Params[fmt.Sprintf("arg%d", len(task.Params))] = r + } + } + + case "load": + // load(filepath) → output_tensor + // Reads[0] = filepath → param + // Writes[0] = output tensor → output + for _, r := range inst.Reads { + task.Params[fmt.Sprintf("arg%d", len(task.Params))] = r + } + for _, w := range inst.Writes { + task.Outputs = append(task.Outputs, resolveParam(ctx, rdb, vtid, w)) + } + + case "print": + // print(tensor) — all reads are tensor inputs, no outputs + for _, r := range inst.Reads { + task.Inputs = append(task.Inputs, resolveParam(ctx, rdb, vtid, r)) + } + // no outputs + + default: + for _, r := range inst.Reads { + if isLiteral(r) { + task.Params[fmt.Sprintf("arg%d", len(task.Params))] = r + } else { + task.Inputs = append(task.Inputs, resolveParam(ctx, rdb, vtid, r)) + } + } + + for _, w := range inst.Writes { + task.Outputs = append(task.Outputs, resolveParam(ctx, rdb, vtid, w)) + } + } + + return task +} + +func buildHeapTask(vtid string, pc string, inst *ir.Instruction) *HeapTask { + task := &HeapTask{ + Vtid: vtid, + PC: pc, + Op: inst.Opcode, + } + + switch inst.Opcode { + case "newtensor": + // Writes[0] = tensor key (e.g., "/data/x") + // Reads[0] = dtype (e.g., "f32") + // Reads[1] = 
shape (e.g., "[10,10]" or "[100]") + if len(inst.Writes) > 0 { + task.Key = inst.Writes[0] + } + if len(inst.Reads) > 0 { + task.Dtype = inst.Reads[0] + } + if len(inst.Reads) > 1 { + task.Shape = parseShapeParam(inst.Reads[1]) + } + case "deltensor": + if len(inst.Reads) > 0 { + task.Key = inst.Reads[0] + } + case "clonetensor": + if len(inst.Reads) > 0 { + task.Src = inst.Reads[0] + } + if len(inst.Writes) > 0 { + task.Dst = inst.Writes[0] + } + } + + return task +} + +// parseShapeParam converts "[10,10]" or "[100]" to []int. +func parseShapeParam(raw string) []int { + raw = strings.Trim(raw, "[] ") + if raw == "" { + return nil + } + var shape []int + for _, s := range strings.Split(raw, ",") { + s = strings.TrimSpace(s) + if s == "" { + continue + } + var n int + fmt.Sscanf(s, "%d", &n) + shape = append(shape, n) + } + return shape +} + +// Compute 分发张量计算指令到 op-plat。 +func Compute(ctx context.Context, rdb *redis.Client, vtid string, pc string, inst *ir.Instruction) error { + instance, err := route.Select(ctx, rdb, inst.Opcode) + if err != nil { + return fmt.Errorf("route: %w", err) + } + + task := buildOpTask(ctx, rdb, vtid, pc, inst) + cmdQueue := fmt.Sprintf("cmd:op-%s", instance) + + taskJSON, err := json.Marshal(task) + if err != nil { + return fmt.Errorf("marshal task: %w", err) + } + + if err := rdb.RPush(ctx, cmdQueue, taskJSON).Err(); err != nil { + return fmt.Errorf("push task: %w", err) + } + + log.Printf("[%s] PUSH %s → %s", vtid, inst.Opcode, cmdQueue) + + state.Set(ctx, rdb, vtid, pc, "wait") + done, err := state.WaitDone(ctx, rdb, vtid, 30*time.Second) + if err != nil { + state.SetError(ctx, rdb, vtid, pc, fmt.Sprintf("BLPOP timeout: %v", err)) + return err + } + + if status, ok := done["status"].(string); ok && status == "error" { + errInfo := fmt.Sprintf("%v", done["error"]) + state.SetError(ctx, rdb, vtid, pc, errInfo) + return fmt.Errorf("op error: %s", errInfo) + } + + log.Printf("[%s] DONE %s", vtid, inst.Opcode) + state.Set(ctx, rdb, vtid, 
ir.NextPC(pc), "running") + return nil +} + +// Lifecycle 分发生命周期指令到 heap-plat。 +func Lifecycle(ctx context.Context, rdb *redis.Client, vtid string, pc string, inst *ir.Instruction) error { + task := buildHeapTask(vtid, pc, inst) + taskJSON, err := json.Marshal(task) + if err != nil { + return fmt.Errorf("marshal heap task: %w", err) + } + + if err := rdb.RPush(ctx, "cmd:heap-metal:0", taskJSON).Err(); err != nil { + return fmt.Errorf("push heap task: %w", err) + } + + log.Printf("[%s] PUSH %s → cmd:heap-metal:0", vtid, inst.Opcode) + + done, err := state.WaitDone(ctx, rdb, vtid, 5*time.Second) + if err != nil { + state.SetError(ctx, rdb, vtid, pc, fmt.Sprintf("heap op timeout: %v", err)) + return err + } + + if status, ok := done["status"].(string); ok && status == "error" { + errInfo := fmt.Sprintf("%v", done["error"]) + state.SetError(ctx, rdb, vtid, pc, errInfo) + return fmt.Errorf("heap op error: %s", errInfo) + } + + state.Set(ctx, rdb, vtid, ir.NextPC(pc), "running") + return nil +} + +// If 处理 IF 条件分支。 +func If(ctx context.Context, rdb *redis.Client, vtid string, pc string, inst *ir.Instruction) error { + if len(inst.Reads) == 0 { + return fmt.Errorf("if without condition") + } + + condVal := inst.Reads[0] + if IsRelative(condVal) { + resolvedKey := "/vthread/" + vtid + "/" + condVal[2:] + val, err := rdb.Get(ctx, resolvedKey).Result() + if err == nil { + condVal = val + } + } + + cond := isTruthy(condVal) + var branchPC string + if cond { + branchPC = pc + "/true/0" + } else { + branchPC = pc + "/false/0" + } + + log.Printf("[%s] IF %v → %s", vtid, cond, branchPC) + state.Set(ctx, rdb, vtid, branchPC, "running") + return nil +} + +func isTruthy(val string) bool { + switch strings.ToLower(strings.TrimSpace(val)) { + case "true", "1", "yes": + return true + default: + return false + } +} diff --git a/executor/vm/internal/dispatch/native.go b/executor/vm/internal/dispatch/native.go new file mode 100644 index 00000000..325cf9b3 --- /dev/null +++ 
// nativeValue is a dynamically-typed scalar used by the VM's native
// evaluator. Exactly one typed field is meaningful, selected by kind;
// raw always keeps the original textual form as read from the operand.
type nativeValue struct {
	kind string // "bool" | "int" | "float" | "string"
	raw  string
	b    bool
	i    int64
	f    float64
}

// parseNativeValue classifies a raw operand string as bool, int, float
// or string, in that priority order.
//
// Fix: the value is whitespace-trimmed once up front, so " 42 " parses
// as int. Previously only the bool branch trimmed; the numeric branches
// saw the untrimmed text and fell through to "string". raw is kept
// untrimmed, matching the original behavior of the other branches.
func parseNativeValue(raw string) nativeValue {
	v := nativeValue{raw: raw}
	trimmed := strings.TrimSpace(raw)
	switch strings.ToLower(trimmed) {
	case "true":
		v.kind = "bool"
		v.b = true
		return v
	case "false":
		v.kind = "bool"
		v.b = false
		return v
	}
	if i, err := strconv.ParseInt(trimmed, 10, 64); err == nil {
		v.kind = "int"
		v.i = i
		return v
	}
	if f, err := strconv.ParseFloat(trimmed, 64); err == nil {
		v.kind = "float"
		v.f = f
		return v
	}
	v.kind = "string"
	return v
}

// String renders the value back to canonical text. Finite floats that
// format without a fractional part get a ".0" suffix so they stay
// distinguishable from ints on round-trip.
//
// Fix: the ".0" suffix is no longer appended to non-finite values
// (previously NaN rendered as "NaN.0").
func (v nativeValue) String() string {
	switch v.kind {
	case "bool":
		if v.b {
			return "true"
		}
		return "false"
	case "int":
		return strconv.FormatInt(v.i, 10)
	case "float":
		s := strconv.FormatFloat(v.f, 'f', -1, 64)
		if !strings.Contains(s, ".") && !math.IsNaN(v.f) && !math.IsInf(v.f, 0) {
			s += ".0"
		}
		return s
	default:
		return v.raw
	}
}

// asFloat converts int/float to float64; other kinds yield 0.
func (v nativeValue) asFloat() float64 {
	switch v.kind {
	case "int":
		return float64(v.i)
	case "float":
		return v.f
	default:
		return 0
	}
}

// asInt converts int/float to int64 (floats truncate toward zero);
// other kinds yield 0.
func (v nativeValue) asInt() int64 {
	switch v.kind {
	case "int":
		return v.i
	case "float":
		return int64(v.f)
	default:
		return 0
	}
}

// asBool: bools pass through; any other kind is truthy when its raw
// text is non-empty and not exactly "0".
func (v nativeValue) asBool() bool {
	switch v.kind {
	case "bool":
		return v.b
	default:
		return v.raw != "" && v.raw != "0"
	}
}
"/vthread/" + vtid + "/" + r[2:] + val, err := rdb.Get(ctx, key).Result() + if err != nil { + msg := fmt.Sprintf("native read %s: %v", key, err) + state.SetError(ctx, rdb, vtid, pc, msg) + return fmt.Errorf("%s", msg) + } + raw = val + } else { + raw = r + } + inputs = append(inputs, parseNativeValue(raw)) + } + + result, err := evalNative(inst.Opcode, inputs) + if err != nil { + state.SetError(ctx, rdb, vtid, pc, err.Error()) + return err + } + + if len(inst.Writes) > 0 { + outKey := ResolveWriteKey(vtid, inst.Writes[0]) + if err := rdb.Set(ctx, outKey, result.String(), 0).Err(); err != nil { + msg := fmt.Sprintf("native write %s: %v", outKey, err) + state.SetError(ctx, rdb, vtid, pc, msg) + return fmt.Errorf("%s", msg) + } + } + + log.Printf("[%s] NATIVE %s %v = %s", vtid, inst.Opcode, inputs, result.String()) + state.Set(ctx, rdb, vtid, ir.NextPC(pc), "running") + return nil +} + +func evalNative(op string, inputs []nativeValue) (nativeValue, error) { + switch op { + case "+": + return evalBinaryArith(inputs, func(a, b float64) float64 { return a + b }) + case "-": + if len(inputs) == 1 { + return evalNeg(inputs[0]) + } + return evalBinaryArith(inputs, func(a, b float64) float64 { return a - b }) + case "*": + return evalBinaryArith(inputs, func(a, b float64) float64 { return a * b }) + case "/": + return evalDiv(inputs) + case "%": + return evalMod(inputs) + case "==": + return evalCmp(inputs, func(a, b float64) bool { return a == b }, + func(a, b string) bool { return a == b }) + case "!=": + return evalCmp(inputs, func(a, b float64) bool { return a != b }, + func(a, b string) bool { return a != b }) + case "<": + return evalCmpNum(inputs, func(a, b float64) bool { return a < b }) + case ">": + return evalCmpNum(inputs, func(a, b float64) bool { return a > b }) + case "<=": + return evalCmpNum(inputs, func(a, b float64) bool { return a <= b }) + case ">=": + return evalCmpNum(inputs, func(a, b float64) bool { return a >= b }) + case "&&": + return 
evalLogic(inputs, func(a, b bool) bool { return a && b }) + case "||": + return evalLogic(inputs, func(a, b bool) bool { return a || b }) + case "!": + return evalNot(inputs) + case "&": + return evalBinaryInt(inputs, func(a, b int64) int64 { return a & b }) + case "|": + return evalBinaryInt(inputs, func(a, b int64) int64 { return a | b }) + case "^": + return evalBinaryInt(inputs, func(a, b int64) int64 { return a ^ b }) + case "<<": + return evalBinaryInt(inputs, func(a, b int64) int64 { return a << uint64(b) }) + case ">>": + return evalBinaryInt(inputs, func(a, b int64) int64 { return a >> uint64(b) }) + + // ── 数学 built-in ── + case "abs": + return evalAbs(inputs) + case "pow": + return evalPow(inputs) + case "min": + return evalMin(inputs) + case "max": + return evalMax(inputs) + case "sqrt": + return evalSqrt(inputs) + case "exp": + return evalExp(inputs) + case "log": + return evalLog(inputs) + case "neg": + return evalUnaryArith(inputs, func(a float64) float64 { return -a }) + case "sign": + return evalSign(inputs) + + // ── 类型转换 built-in ── + case "int": + return evalToInt(inputs) + case "float": + return evalToFloat(inputs) + case "bool": + return evalToBool(inputs) + + default: + return nativeValue{}, fmt.Errorf("unknown native op: %s", op) + } +} + +func requireBinary(inputs []nativeValue) error { + if len(inputs) != 2 { + return fmt.Errorf("binary op requires 2 inputs, got %d", len(inputs)) + } + return nil +} + +func requireUnary(inputs []nativeValue) error { + if len(inputs) != 1 { + return fmt.Errorf("unary op requires 1 input, got %d", len(inputs)) + } + return nil +} + +func evalBinaryArith(inputs []nativeValue, fn func(float64, float64) float64) (nativeValue, error) { + if err := requireBinary(inputs); err != nil { + return nativeValue{}, err + } + a, b := inputs[0], inputs[1] + result := fn(a.asFloat(), b.asFloat()) + if a.kind == "int" && b.kind == "int" { + return nativeValue{kind: "int", i: int64(result)}, nil + } + return nativeValue{kind: 
"float", f: result}, nil +} + +func evalNeg(v nativeValue) (nativeValue, error) { + switch v.kind { + case "int": + return nativeValue{kind: "int", i: -v.i}, nil + case "float": + return nativeValue{kind: "float", f: -v.f}, nil + default: + return nativeValue{}, fmt.Errorf("cannot negate %s", v.kind) + } +} + +func evalDiv(inputs []nativeValue) (nativeValue, error) { + if err := requireBinary(inputs); err != nil { + return nativeValue{}, err + } + a, b := inputs[0], inputs[1] + bf := b.asFloat() + if bf == 0 { + return nativeValue{}, fmt.Errorf("division by zero") + } + result := a.asFloat() / bf + return nativeValue{kind: "float", f: result}, nil +} + +func evalMod(inputs []nativeValue) (nativeValue, error) { + if err := requireBinary(inputs); err != nil { + return nativeValue{}, err + } + a, b := inputs[0], inputs[1] + if b.asInt() == 0 { + return nativeValue{}, fmt.Errorf("modulo by zero") + } + return nativeValue{kind: "int", i: a.asInt() % b.asInt()}, nil +} + +func evalCmp(inputs []nativeValue, numCmp func(float64, float64) bool, strCmp func(string, string) bool) (nativeValue, error) { + if err := requireBinary(inputs); err != nil { + return nativeValue{}, err + } + a, b := inputs[0], inputs[1] + if (a.kind == "int" || a.kind == "float") && (b.kind == "int" || b.kind == "float") { + return nativeValue{kind: "bool", b: numCmp(a.asFloat(), b.asFloat())}, nil + } + return nativeValue{kind: "bool", b: strCmp(a.raw, b.raw)}, nil +} + +func evalCmpNum(inputs []nativeValue, fn func(float64, float64) bool) (nativeValue, error) { + return evalCmp(inputs, fn, func(a, b string) bool { return a < b }) +} + +func evalLogic(inputs []nativeValue, fn func(bool, bool) bool) (nativeValue, error) { + if err := requireBinary(inputs); err != nil { + return nativeValue{}, err + } + a, b := inputs[0], inputs[1] + return nativeValue{kind: "bool", b: fn(a.asBool(), b.asBool())}, nil +} + +func evalNot(inputs []nativeValue) (nativeValue, error) { + if err := requireUnary(inputs); err 
!= nil { + return nativeValue{}, err + } + return nativeValue{kind: "bool", b: !inputs[0].asBool()}, nil +} + +func evalBinaryInt(inputs []nativeValue, fn func(int64, int64) int64) (nativeValue, error) { + if err := requireBinary(inputs); err != nil { + return nativeValue{}, err + } + return nativeValue{kind: "int", i: fn(inputs[0].asInt(), inputs[1].asInt())}, nil +} + +// ── 数学 built-in evaluators ── + +func evalAbs(inputs []nativeValue) (nativeValue, error) { + if err := requireUnary(inputs); err != nil { + return nativeValue{}, err + } + v := inputs[0] + switch v.kind { + case "int": + if v.i < 0 { + return nativeValue{kind: "int", i: -v.i}, nil + } + return nativeValue{kind: "int", i: v.i}, nil + case "float": + return nativeValue{kind: "float", f: math.Abs(v.f)}, nil + default: + return nativeValue{}, fmt.Errorf("abs requires numeric, got %s", v.kind) + } +} + +func evalPow(inputs []nativeValue) (nativeValue, error) { + if err := requireBinary(inputs); err != nil { + return nativeValue{}, err + } + result := math.Pow(inputs[0].asFloat(), inputs[1].asFloat()) + return nativeValue{kind: "float", f: result}, nil +} + +func evalMin(inputs []nativeValue) (nativeValue, error) { + if err := requireBinary(inputs); err != nil { + return nativeValue{}, err + } + a, b := inputs[0], inputs[1] + if (a.kind == "int" || a.kind == "float") && (b.kind == "int" || b.kind == "float") { + af, bf := a.asFloat(), b.asFloat() + result := math.Min(af, bf) + if a.kind == "int" && b.kind == "int" { + return nativeValue{kind: "int", i: int64(result)}, nil + } + return nativeValue{kind: "float", f: result}, nil + } + // 字符串回退 + if a.raw < b.raw { + return a, nil + } + return b, nil +} + +func evalMax(inputs []nativeValue) (nativeValue, error) { + if err := requireBinary(inputs); err != nil { + return nativeValue{}, err + } + a, b := inputs[0], inputs[1] + if (a.kind == "int" || a.kind == "float") && (b.kind == "int" || b.kind == "float") { + af, bf := a.asFloat(), b.asFloat() + result 
:= math.Max(af, bf) + if a.kind == "int" && b.kind == "int" { + return nativeValue{kind: "int", i: int64(result)}, nil + } + return nativeValue{kind: "float", f: result}, nil + } + if a.raw > b.raw { + return a, nil + } + return b, nil +} + +func evalSqrt(inputs []nativeValue) (nativeValue, error) { + if err := requireUnary(inputs); err != nil { + return nativeValue{}, err + } + x := inputs[0].asFloat() + if x < 0 { + return nativeValue{}, fmt.Errorf("sqrt of negative number: %v", x) + } + return nativeValue{kind: "float", f: math.Sqrt(x)}, nil +} + +func evalExp(inputs []nativeValue) (nativeValue, error) { + if err := requireUnary(inputs); err != nil { + return nativeValue{}, err + } + return nativeValue{kind: "float", f: math.Exp(inputs[0].asFloat())}, nil +} + +func evalLog(inputs []nativeValue) (nativeValue, error) { + if err := requireUnary(inputs); err != nil { + return nativeValue{}, err + } + x := inputs[0].asFloat() + if x <= 0 { + return nativeValue{}, fmt.Errorf("log of non-positive number: %v", x) + } + return nativeValue{kind: "float", f: math.Log(x)}, nil +} + +func evalSign(inputs []nativeValue) (nativeValue, error) { + if err := requireUnary(inputs); err != nil { + return nativeValue{}, err + } + v := inputs[0] + f := v.asFloat() + if f > 0 { + return nativeValue{kind: "int", i: 1}, nil + } else if f < 0 { + return nativeValue{kind: "int", i: -1}, nil + } + return nativeValue{kind: "int", i: 0}, nil +} + +func evalUnaryArith(inputs []nativeValue, fn func(float64) float64) (nativeValue, error) { + if err := requireUnary(inputs); err != nil { + return nativeValue{}, err + } + v := inputs[0] + result := fn(v.asFloat()) + if v.kind == "int" { + return nativeValue{kind: "int", i: int64(result)}, nil + } + return nativeValue{kind: "float", f: result}, nil +} + +// ── 类型转换 built-in evaluators ── + +func evalToInt(inputs []nativeValue) (nativeValue, error) { + if err := requireUnary(inputs); err != nil { + return nativeValue{}, err + } + v := inputs[0] + 
switch v.kind { + case "int": + return v, nil + case "float": + return nativeValue{kind: "int", i: int64(v.f)}, nil + case "bool": + if v.b { + return nativeValue{kind: "int", i: 1}, nil + } + return nativeValue{kind: "int", i: 0}, nil + default: + return nativeValue{kind: "int", i: v.asInt()}, nil + } +} + +func evalToFloat(inputs []nativeValue) (nativeValue, error) { + if err := requireUnary(inputs); err != nil { + return nativeValue{}, err + } + v := inputs[0] + switch v.kind { + case "float": + return v, nil + case "int": + return nativeValue{kind: "float", f: float64(v.i)}, nil + case "bool": + if v.b { + return nativeValue{kind: "float", f: 1.0}, nil + } + return nativeValue{kind: "float", f: 0.0}, nil + default: + return nativeValue{kind: "float", f: v.asFloat()}, nil + } +} + +func evalToBool(inputs []nativeValue) (nativeValue, error) { + if err := requireUnary(inputs); err != nil { + return nativeValue{}, err + } + return nativeValue{kind: "bool", b: inputs[0].asBool()}, nil +} diff --git a/executor/vm/internal/engine/engine.go b/executor/vm/internal/engine/engine.go new file mode 100644 index 00000000..b2729e06 --- /dev/null +++ b/executor/vm/internal/engine/engine.go @@ -0,0 +1,141 @@ +// Package engine 提供 VM 核心执行循环与指令分发。 +// +// engine 是 VM 的编排者,协调 picker/dispatch/translate/state 等子包。 +package engine + +import ( + "context" + "fmt" + "log" + + "deepx/executor/vm/internal/dispatch" + "deepx/executor/vm/internal/ir" + "deepx/executor/vm/internal/picker" + "deepx/executor/vm/internal/state" + "deepx/executor/vm/internal/translate" + "github.com/redis/go-redis/v9" +) + +// RunWorker 单个 worker 的主循环。 +func RunWorker(ctx context.Context, rdb *redis.Client, id int) { + log.Printf("worker-%d started", id) + for { + select { + case <-ctx.Done(): + log.Printf("worker-%d stopped", id) + return + default: + } + + vtid := picker.PickVthread(ctx, rdb) + if vtid == "" { + picker.WaitForVthread(ctx, rdb) + continue + } + + log.Printf("worker-%d picked vthread %s", id, 
// Execute runs one vthread to completion or error.
//
// Loop invariant: each iteration reads the persisted state, decodes the
// instruction at the current PC, dispatches it, and relies on the
// dispatcher (or the default branch) to advance the PC via state.Set.
// Terminates when the status becomes "done"/"error", when decoding
// fails, when no opcode exists at the PC (end of program), or when a
// dispatcher returns an error.
func Execute(ctx context.Context, rdb *redis.Client, vtid string) {
	for {
		s := state.Get(ctx, rdb, vtid)
		if s.Status == "done" || s.Status == "error" {
			return
		}

		pc := s.PC
		inst, err := ir.Decode(ctx, rdb, vtid, pc)
		if err != nil {
			log.Printf("[%s] decode error at %s: %v", vtid, pc, err)
			state.SetError(ctx, rdb, vtid, pc, fmt.Sprintf("decode: %v", err))
			return
		}

		// An empty opcode means no instruction at this coordinate —
		// the program has run off the end, so the vthread is done.
		if inst.Opcode == "" {
			log.Printf("[%s] done (no more instructions at %s)", vtid, pc)
			state.Set(ctx, rdb, vtid, pc, "done")
			return
		}

		log.Printf("[%s] PC=%s OP=%s READS=%v WRITES=%v", vtid, pc, inst.Opcode, inst.Reads, inst.Writes)

		var execErr error

		// Dispatch priority: control flow → native scalar ops →
		// tensor lifecycle → user-defined function call → tensor
		// compute. Order matters: isFunctionCall hits Redis, so it is
		// only consulted after the cheap classifications fail, and
		// before IsComputeOp (which accepts any leftover opcode).
		switch {
		case ir.IsControlOp(inst.Opcode):
			execErr = dispatchControl(ctx, rdb, vtid, pc, inst)

		case ir.IsNativeOp(inst.Opcode):
			execErr = dispatch.Native(ctx, rdb, vtid, pc, inst)

		case ir.IsLifecycleOp(inst.Opcode):
			execErr = dispatch.Lifecycle(ctx, rdb, vtid, pc, inst)

		case isFunctionCall(ctx, rdb, inst.Opcode):
			// A registered function name used as an opcode: rewrite
			// funcName(A, B) -> ./C into an internal
			// call(funcName, A, B) -> ./C.
			inst.Reads = append([]string{inst.Opcode}, inst.Reads...)
			inst.Opcode = "call"
			execErr = dispatchControl(ctx, rdb, vtid, pc, inst)

		case ir.IsComputeOp(inst.Opcode):
			execErr = dispatch.Compute(ctx, rdb, vtid, pc, inst)

		default:
			// Unreachable in practice (IsComputeOp accepts anything
			// non-lifecycle/control); skip to the next instruction.
			state.Set(ctx, rdb, vtid, ir.NextPC(pc), "running")
		}

		if execErr != nil {
			log.Printf("[%s] error: %v", vtid, execErr)
			return
		}
	}
}

// dispatchControl handles control-flow instructions (call / return / if).
// Note: "for" is classified as a control op by ir.IsControlOp but has no
// case here, so it currently lands in the "unknown control opcode" error.
func dispatchControl(ctx context.Context, rdb *redis.Client, vtid string, pc string, inst *ir.Instruction) error {
	switch inst.Opcode {
	case "call":
		// Eager-translate the callee into a substack and jump into it.
		substackPC := translate.HandleCall(ctx, rdb, vtid, pc, inst)
		state.Set(ctx, rdb, vtid, substackPC, "running")
		log.Printf("[%s] CALL → substack %s", vtid, substackPC)
		return nil

	case "return":
		parentPC := translate.HandleReturn(ctx, rdb, vtid, pc)
		log.Printf("[%s] RETURN → parent %s", vtid, parentPC)

		// HandleReturn echoing the same PC means there is no parent
		// frame — the top-level return, so the vthread is done.
		if parentPC == pc {
			state.Set(ctx, rdb, vtid, pc, "done")
			return nil
		}
		state.Set(ctx, rdb, vtid, parentPC, "running")
		return nil

	case "if":
		return dispatch.If(ctx, rdb, vtid, pc, inst)

	default:
		return fmt.Errorf("unknown control opcode: %s", inst.Opcode)
	}
}

// isFunctionCall reports whether opcode names a registered function
// (rather than an operator), by probing /src/func/<name> and each
// backend's /op/<backend>/func/<name> key.
// NOTE(review): up to four Redis EXISTS round-trips per unknown opcode,
// on every execution — a candidate for caching if this shows up hot.
func isFunctionCall(ctx context.Context, rdb *redis.Client, opcode string) bool {
	exists, err := rdb.Exists(ctx, "/src/func/"+opcode).Result()
	if err == nil && exists > 0 {
		return true
	}
	for _, backend := range []string{"op-metal", "op-cuda", "op-cpu"} {
		exists, err := rdb.Exists(ctx, fmt.Sprintf("/op/%s/func/%s", backend, opcode)).Result()
		if err == nil && exists > 0 {
			return true
		}
	}
	return false
}
// Instruction is one decoded instruction from the execution-layer
// [addr0, addr1] coordinate space.
type Instruction struct {
	Opcode string   // value at [addr0, 0], e.g. "+" | "call" | "return"
	Reads  []string // values at [addr0, -1], [addr0, -2], ...
	Writes []string // values at [addr0, 1], [addr0, 2], ...
	PC     string   // coordinate of this instruction, e.g. "[3,0]" or "[2,0]/[1,0]"
}

// maxParams bounds how many read/write slots Decode probes per
// instruction; slots beyond this are never seen.
const maxParams = 10

// Decode reads the instruction at pc from the vthread's execution-layer
// keys using a single MGET (1 opcode key + maxParams read + maxParams
// write keys). Missing slots (redis nil) and empty strings are skipped,
// so Reads/Writes are dense.
// NOTE(review): a sparse parameter list (e.g. slot -1 empty but -2 set)
// would be compacted, losing positional information — presumably the
// translator always fills slots contiguously; confirm.
func Decode(ctx context.Context, rdb *redis.Client, vtid string, pc string) (*Instruction, error) {
	prefix, addr0 := parsePC(pc)
	keyBase := fmt.Sprintf("/vthread/%s/%s", vtid, prefix)

	keys := make([]string, 0, 1+maxParams*2)
	keys = append(keys, fmt.Sprintf("%s[%d,0]", keyBase, addr0))
	for i := 1; i <= maxParams; i++ {
		keys = append(keys, fmt.Sprintf("%s[%d,-%d]", keyBase, addr0, i))
		keys = append(keys, fmt.Sprintf("%s[%d,%d]", keyBase, addr0, i))
	}

	vals, err := rdb.MGet(ctx, keys...).Result()
	if err != nil {
		return nil, fmt.Errorf("decode MGET: %w", err)
	}

	inst := &Instruction{PC: pc}

	// vals[0] corresponds to the opcode key; nil (missing) leaves
	// Opcode empty, which the engine treats as end-of-program.
	if s, ok := vals[0].(string); ok {
		inst.Opcode = s
	}

	// Slots were interleaved read/write when building keys, so the
	// i-th read lives at (i-1)*2+1 and the i-th write just after it.
	for i := 1; i <= maxParams; i++ {
		readIdx := (i-1)*2 + 1
		writeIdx := readIdx + 1
		if readIdx < len(vals) {
			if s, ok := vals[readIdx].(string); ok && s != "" {
				inst.Reads = append(inst.Reads, s)
			}
		}
		if writeIdx < len(vals) {
			if s, ok := vals[writeIdx].(string); ok && s != "" {
				inst.Writes = append(inst.Writes, s)
			}
		}
	}

	return inst, nil
}

// DecodeFromCache decodes an instruction from a local map of
// coordinate→value (substack scenario: zero Redis round-trips). Same
// slot layout as Decode, but keys carry no vthread/prefix.
func DecodeFromCache(cache map[string]string, pc string) *Instruction {
	_, addr0 := parsePC(pc)
	inst := &Instruction{PC: pc}
	inst.Opcode = cache[fmt.Sprintf("[%d,0]", addr0)]

	for i := 1; i <= maxParams; i++ {
		key := fmt.Sprintf("[%d,-%d]", addr0, i)
		if v, ok := cache[key]; ok && v != "" {
			inst.Reads = append(inst.Reads, v)
		}
		key = fmt.Sprintf("[%d,%d]", addr0, i)
		if v, ok := cache[key]; ok && v != "" {
			inst.Writes = append(inst.Writes, v)
		}
	}
	return inst
}
// IsComputeOp reports whether opcode should be dispatched to op-plat.
// Anything that is neither a lifecycle nor a control opcode counts as
// compute.
func IsComputeOp(opcode string) bool {
	return !isLifecycleOrControl(opcode)
}

// IsLifecycleOp reports whether opcode is a tensor lifecycle operation
// handled by heap-plat.
func IsLifecycleOp(opcode string) bool {
	switch opcode {
	case "newtensor", "deltensor", "clonetensor":
		return true
	default:
		return false
	}
}

// IsControlOp reports whether opcode is a control-flow instruction
// handled by the VM itself.
func IsControlOp(opcode string) bool {
	switch opcode {
	case "call", "return", "if", "for":
		return true
	default:
		return false
	}
}

// isLifecycleOrControl is the union of the two classifications above.
func isLifecycleOrControl(opcode string) bool {
	return IsControlOp(opcode) || IsLifecycleOp(opcode)
}

// NextPC advances the last coordinate segment of pc to the next
// instruction: "[3,0]" → "[4,0]", "[2,0]/[1,0]" → "[2,0]/[2,0]".
// The second coordinate is always reset to 0.
func NextPC(pc string) string {
	head := ""
	last := pc
	if i := strings.LastIndex(pc, "/"); i >= 0 {
		head, last = pc[:i+1], pc[i+1:]
	}
	return head + fmt.Sprintf("[%d,0]", extractAddr0(last)+1)
}

// ParentPC drops the last segment of a substack PC and advances the
// parent: "[2,0]/[1,0]" → "[3,0]". A PC with no parent is returned
// unchanged.
func ParentPC(pc string) string {
	i := strings.LastIndex(pc, "/")
	if i < 0 {
		return pc
	}
	return NextPC(pc[:i])
}

// extractAddr0 parses the first coordinate out of "[a,b]"; anything
// unparseable yields 0.
func extractAddr0(coord string) int {
	inner := strings.Trim(coord, "[]")
	first, _, _ := strings.Cut(inner, ",")
	n, err := strconv.Atoi(strings.TrimSpace(first))
	if err != nil {
		return 0
	}
	return n
}
分离输出 (支持 -> 和 <- 两种箭头) + var expr string + if larrow := findArrow(line, "<-"); larrow >= 0 { + // C风格: "./C" <- A + B → 输出在左, 表达式在右 + writesStr := strings.TrimSpace(line[:larrow]) + expr = strings.TrimSpace(line[larrow+2:]) + if strings.HasPrefix(writesStr, "(") && strings.HasSuffix(writesStr, ")") { + writesStr = writesStr[1 : len(writesStr)-1] + } + if err := validateKeys(parseParamListRaw(writesStr), writesStr, "write"); err != nil { + return nil, err + } + inst.Writes = parseParamList(writesStr) + } else if arrow := strings.Index(line, "->"); arrow >= 0 { + // 传统风格: add(A, B) -> "./C" → 表达式在左, 输出在右 + expr = strings.TrimSpace(line[:arrow]) + writesStr := strings.TrimSpace(line[arrow+2:]) + if strings.HasPrefix(writesStr, "(") && strings.HasSuffix(writesStr, ")") { + writesStr = writesStr[1 : len(writesStr)-1] + } + if err := validateKeys(parseParamListRaw(writesStr), writesStr, "write"); err != nil { + return nil, err + } + inst.Writes = parseParamList(writesStr) + } else { + expr = line + } + + // 2. 尝试中缀解析: "A + B", "!A", "A == B" + if op, left, right, ok := parseInfix(expr); ok { + inst.Opcode = op + if left != "" { + // 剥离引号后再验证 key 引用 + rawLeft := left + left = stripQuotes(left) + if isKeyRef(left) && !isQuoted(rawLeft) { + return nil, fmt.Errorf("read %q must be quoted (e.g. %q) in: %s", left, "\""+left+"\"", line) + } + inst.Reads = append(inst.Reads, left) + } + if right != "" { + rawRight := right + right = stripQuotes(right) + if isKeyRef(right) && !isQuoted(rawRight) { + return nil, fmt.Errorf("read %q must be quoted (e.g. %q) in: %s", right, "\""+right+"\"", line) + } + inst.Reads = append(inst.Reads, right) + } + return inst, nil + } + + // 3. 
// indexOutsideQuotes returns the byte index of the first occurrence of
// sub in s that is not inside a double-quoted region, or -1.
func indexOutsideQuotes(s, sub string) int {
	inQuote := false
	for i := 0; i+len(sub) <= len(s); i++ {
		if s[i] == '"' {
			inQuote = !inQuote
			continue
		}
		if inQuote {
			continue
		}
		if s[i:i+len(sub)] == sub {
			return i
		}
	}
	return -1
}

// findArrow locates the assignment arrow (e.g. "<-"), distinguishing it
// from <= and << by requiring the exact two-byte sequence.
// Returns the index of the arrow's first byte, or -1.
//
// Fix: the scan now skips double-quoted regions, so a quoted key that
// happens to contain the arrow characters cannot be misread as the
// assignment arrow.
func findArrow(s, arrow string) int {
	inQuote := false
	for i := 0; i < len(s)-1; i++ {
		if s[i] == '"' {
			inQuote = !inQuote
			continue
		}
		if inQuote {
			continue
		}
		if s[i] == arrow[0] && s[i+1] == arrow[1] {
			return i
		}
	}
	return -1
}

// parseInfix attempts infix parsing of binary and unary symbolic
// operators. Expressions containing '(' are skipped (they fall back to
// prefix parsing).
//
// Fix: operator search is now quote-aware. Previously strings.Index
// matched operator characters inside quoted key references, so
// `"./a" < "./b"` was parsed as a "/" expression (the slash inside the
// quoted path matched first). Operators are still tried in fixed list
// order, preserving the original precedence behavior for unquoted text.
func parseInfix(expr string) (op, left, right string, ok bool) {
	expr = strings.TrimSpace(expr)
	if expr == "" {
		return
	}

	// '(' → prefix form such as add(A, B); let the prefix parser run.
	if strings.IndexByte(expr, '(') >= 0 {
		return
	}

	// Multi-character operators first, so "==" is not seen as two "=".
	multiOps := []string{"==", "!=", "<=", ">=", "&&", "||", "<<", ">>"}
	for _, candidate := range multiOps {
		if idx := indexOutsideQuotes(expr, candidate); idx > 0 {
			return candidate,
				strings.TrimSpace(expr[:idx]),
				strings.TrimSpace(expr[idx+len(candidate):]),
				true
		}
	}

	// Single-character binary operators.
	singleOps := []string{"+", "*", "/", "%", "<", ">", "&", "|", "^"}
	for _, candidate := range singleOps {
		if idx := indexOutsideQuotes(expr, candidate); idx > 0 {
			return candidate,
				strings.TrimSpace(expr[:idx]),
				strings.TrimSpace(expr[idx+1:]),
				true
		}
	}

	// Binary '-' (idx > 0, so a leading '-' is not consumed here).
	if idx := indexOutsideQuotes(expr, "-"); idx > 0 {
		return "-", strings.TrimSpace(expr[:idx]), strings.TrimSpace(expr[idx+1:]), true
	}

	// Unary operators at position 0.
	if expr[0] == '!' {
		return "!", strings.TrimSpace(expr[1:]), "", true
	}
	if expr[0] == '-' {
		return "-", strings.TrimSpace(expr[1:]), "", true
	}

	return
}

// parseParamList splits a comma-separated parameter list, respecting
// nested [] () {} and double-quoted strings, and strips one pair of
// surrounding quotes from each parameter.
func parseParamList(s string) []string {
	rawParams := parseParamListRaw(s)
	if rawParams == nil {
		return nil
	}
	// Local unquote keeps this block self-contained; behavior matches
	// the package's stripQuotes helper.
	unquote := func(p string) string {
		if len(p) >= 2 && p[0] == '"' && p[len(p)-1] == '"' {
			return p[1 : len(p)-1]
		}
		return p
	}
	params := make([]string, len(rawParams))
	for i, p := range rawParams {
		params[i] = unquote(p)
	}
	return params
}

// parseParamListRaw splits like parseParamList but keeps quotes
// attached, for key-reference validation.
func parseParamListRaw(s string) []string {
	s = strings.TrimSpace(s)
	if s == "" {
		return nil
	}
	var params []string
	depth := 0
	inQuote := false
	start := 0
	for i := 0; i < len(s); i++ {
		if s[i] == '"' {
			inQuote = !inQuote
			continue
		}
		if inQuote {
			continue
		}
		switch s[i] {
		case '[', '(', '{':
			depth++
		case ']', ')', '}':
			if depth > 0 {
				depth--
			}
		case ',':
			// Only a top-level, unquoted comma separates parameters.
			if depth == 0 {
				p := strings.TrimSpace(s[start:i])
				if p != "" {
					params = append(params, p)
				}
				start = i + 1
			}
		}
	}
	if start < len(s) {
		p := strings.TrimSpace(s[start:])
		if p != "" {
			params = append(params, p)
		}
	}
	return params
}
// stripQuotes removes one pair of surrounding double quotes, if present.
func stripQuotes(s string) string {
	if isQuoted(s) {
		return s[1 : len(s)-1]
	}
	return s
}

// validateKeys enforces that every key reference (a parameter beginning
// with "/" or "./") is wrapped in double quotes. rawParams must come
// from parseParamListRaw, i.e. with quotes still attached.
// Returns the first violation as an error, or nil if all are valid.
func validateKeys(rawParams []string, rawExpr string, role string) error {
	for _, param := range rawParams {
		if isQuoted(param) || !isKeyRef(param) {
			continue
		}
		quoted := `"` + param + `"`
		return fmt.Errorf("%s %q must be quoted (e.g. %s) in: %s", role, param, quoted, rawExpr)
	}
	return nil
}

// isKeyRef reports whether a parameter denotes a key path — absolute
// ("/...") or relative ("./...").
func isKeyRef(s string) bool {
	return strings.HasPrefix(s, "/") || strings.HasPrefix(s, "./")
}

// isQuoted reports whether s is wrapped in double quotes.
func isQuoted(s string) bool {
	return len(s) >= 2 && strings.HasPrefix(s, `"`) && strings.HasSuffix(s, `"`)
}
// IsUnaryNativeOp reports whether opcode is a single-operand native
// operator (symbolic or built-in style).
func IsUnaryNativeOp(opcode string) bool {
	switch opcode {
	case "!", "-":
		return true // unary symbolic operators
	case "abs", "sqrt", "exp", "log", "neg", "sign":
		return true // unary math built-ins
	case "int", "float", "bool":
		return true // type-conversion built-ins
	default:
		return false
	}
}
+ if containsAny(vtid, "/") { + continue + } + + if tryPick(ctx, rdb, vtid) { + return vtid + } + } + return "" +} + +func containsAny(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +func tryPick(ctx context.Context, rdb *redis.Client, vtid string) bool { + key := "/vthread/" + vtid + + err := rdb.Watch(ctx, func(tx *redis.Tx) error { + val, err := tx.Get(ctx, key).Result() + if err == redis.Nil { + return errSkip + } + if err != nil { + return err + } + + var s state.VThreadState + if err := json.Unmarshal([]byte(val), &s); err != nil { + return err + } + if s.Status != "init" { + return errSkip + } + + s.Status = "running" + data, err := json.Marshal(s) + if err != nil { + return err + } + + _, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.Set(ctx, key, data, 0) + return nil + }) + return err + }, key) + + if err != nil && err != errSkip { + log.Printf("tryPick %s error (vthread will be marked error): %v", vtid, err) + // 标记 vthread 为 error 状态,避免被反复尝试 + errData, _ := json.Marshal(state.VThreadState{PC: "[0,0]", Status: "error"}) + rdb.Set(ctx, key, errData, 0) + } + return err == nil +} + +func extractVtid(key string) string { + const prefix = "/vthread/" + if len(key) > len(prefix) { + return key[len(prefix):] + } + return "" +} + +// WaitForVthread 阻塞等待新 vthread 创建通知。 +func WaitForVthread(ctx context.Context, rdb *redis.Client) { + rdb.BLPop(ctx, 5*time.Second, "notify:vm") +} diff --git a/executor/vm/internal/route/router.go b/executor/vm/internal/route/router.go new file mode 100644 index 00000000..6b21fc09 --- /dev/null +++ b/executor/vm/internal/route/router.go @@ -0,0 +1,110 @@ +package route + +import ( + "context" + "encoding/json" + "fmt" + "log" + "math" + "strings" + + "github.com/redis/go-redis/v9" +) + +// Select 根据 opcode 选择负载最低的 op-plat 实例 +// 返回实例标识符, e.g., "metal:0", "cuda:1" +func Select(ctx context.Context, rdb 
*redis.Client, opcode string) (string, error) { + // 1. 找到支持该算子的所有程序 + programs, err := rdb.Keys(ctx, "/op/*/list").Result() + if err != nil { + return "", fmt.Errorf("list op programs: %w", err) + } + + var chosenProgram string + for _, progKey := range programs { + list, err := rdb.LRange(ctx, progKey, 0, -1).Result() + if err != nil { + continue + } + for _, op := range list { + if op == opcode { + parts := strings.Split(progKey, "/") + // "/op/op-cuda/list" → "op-cuda" + if len(parts) >= 3 { + chosenProgram = parts[2] + } + break + } + } + if chosenProgram != "" { + break + } + } + + if chosenProgram == "" { + return "", fmt.Errorf("no op-plat supports opcode: %s", opcode) + } + + // 2. 选择该程序下负载最低的进程实例 + instances, err := rdb.Keys(ctx, "/sys/op-plat/*").Result() + if err != nil { + return "", fmt.Errorf("list op-plat instances: %w", err) + } + + type instInfo struct { + Program string `json:"program"` + Status string `json:"status"` + Load float64 `json:"load"` + } + + bestLoad := math.MaxFloat64 + bestInstance := "" + + for _, instKey := range instances { + if !strings.Contains(instKey, chosenProgram) { + continue + } + + val, err := rdb.Get(ctx, instKey).Result() + if err != nil { + continue + } + var info instInfo + if err := json.Unmarshal([]byte(val), &info); err != nil { + log.Printf("route.Select: unmarshal instance info %s: %v", instKey, err) + continue + } + + if info.Status != "running" { + continue + } + if info.Load < bestLoad { + bestLoad = info.Load + // "/sys/op-plat/op-metal:0" → "metal:0" + parts := strings.Split(instKey, "/") + lastPart := parts[len(parts)-1] + bestInstance = strings.TrimPrefix(lastPart, "op-") + } + } + + if bestInstance == "" { + return "", fmt.Errorf("no running op-plat instance for %s (program %s)", opcode, chosenProgram) + } + + return bestInstance, nil +} + +// DetermineBackend 判断 func 的编译后端 (按优先级) +func DetermineBackend(ctx context.Context, rdb *redis.Client, funcName string) string { + for _, b := range 
[]string{"op-metal", "op-cuda", "op-cpu"} { + exists, err := rdb.Exists(ctx, fmt.Sprintf("/op/%s/func/%s", b, funcName)).Result() + if err != nil { + log.Printf("route.DetermineBackend: EXISTS error for %s: %v", b, err) + continue + } + if exists > 0 { + return b + } + } + return "op-metal" +} diff --git a/executor/vm/internal/state/state.go b/executor/vm/internal/state/state.go new file mode 100644 index 00000000..72e752b4 --- /dev/null +++ b/executor/vm/internal/state/state.go @@ -0,0 +1,77 @@ +// Package state 提供 vthread 状态管理与 Redis 持久化。 +package state + +import ( + "context" + "encoding/json" + "fmt" + "log" + "time" + + "github.com/redis/go-redis/v9" +) + +// VThreadState 存储在 /vthread/ 中,表示运行时状态。 +type VThreadState struct { + PC string `json:"pc"` + Status string `json:"status"` + Mode string `json:"mode,omitempty"` // "single" | "batch", 默认 "single" + Error map[string]string `json:"error,omitempty"` +} + +// Get 读取 vthread 当前状态。 +func Get(ctx context.Context, rdb *redis.Client, vtid string) VThreadState { + val, err := rdb.Get(ctx, "/vthread/"+vtid).Result() + if err != nil { + return VThreadState{Status: "error"} + } + var s VThreadState + if err := json.Unmarshal([]byte(val), &s); err != nil { + log.Printf("state.Get: unmarshal vthread %s: %v", vtid, err) + return VThreadState{Status: "error"} + } + return s +} + +// Set 更新 vthread 的 PC 和 status。 +func Set(ctx context.Context, rdb *redis.Client, vtid string, pc, status string) { + s := VThreadState{PC: pc, Status: status} + data, err := json.Marshal(s) + if err != nil { + log.Printf("state.Set: marshal vthread %s: %v", vtid, err) + return + } + rdb.Set(ctx, "/vthread/"+vtid, data, 0) +} + +// SetError 标记 vthread 为 error 状态。 +func SetError(ctx context.Context, rdb *redis.Client, vtid string, pc string, errMsg string) { + s := map[string]interface{}{ + "pc": pc, + "status": "error", + "error": map[string]string{"code": "VM_ERROR", "message": errMsg}, + } + data, err := json.Marshal(s) + if err != nil { + 
log.Printf("state.SetError: marshal vthread %s: %v", vtid, err) + return + } + rdb.Set(ctx, "/vthread/"+vtid, data, 0) +} + +// WaitDone 阻塞等待 op-plat / heap-plat 完成通知。 +func WaitDone(ctx context.Context, rdb *redis.Client, vtid string, timeout time.Duration) (map[string]interface{}, error) { + doneKey := "done:" + vtid + result, err := rdb.BLPop(ctx, timeout, doneKey).Result() + if err != nil { + return nil, fmt.Errorf("waitDone timeout for %s: %w", doneKey, err) + } + var done map[string]interface{} + if len(result) > 1 { + if err := json.Unmarshal([]byte(result[1]), &done); err != nil { + log.Printf("state.WaitDone: unmarshal done result for %s: %v", vtid, err) + return nil, fmt.Errorf("unmarshal done result: %w", err) + } + } + return done, nil +} diff --git a/executor/vm/internal/translate/translate.go b/executor/vm/internal/translate/translate.go new file mode 100644 index 00000000..f3a54680 --- /dev/null +++ b/executor/vm/internal/translate/translate.go @@ -0,0 +1,292 @@ +// Package translate 负责 CALL eager 翻译与 RETURN 处理。 +// +// CALL 时一次性将编译层 dxlang 指令翻译为执行层 [i,j] 坐标, +// 后续逐条执行时零解析开销。 +package translate + +import ( + "context" + "fmt" + "log" + "sort" + "strconv" + "strings" + + "deepx/executor/vm/internal/ir" + "deepx/executor/vm/internal/route" + "deepx/executor/vm/internal/state" + "github.com/redis/go-redis/v9" +) + +// handleCall 执行 CALL 指令的 eager 翻译。 +// 返回子栈第一条指令的 PC;致命错误时设置 error 状态并返回当前 pc。 +func HandleCall(ctx context.Context, rdb *redis.Client, vtid string, pc string, inst *ir.Instruction) string { + funcName := inst.Reads[0] + + // 1. 确定 backend + backend := route.DetermineBackend(ctx, rdb, funcName) + + // 2. 
读取编译层函数签名 + sig, err := rdb.Get(ctx, fmt.Sprintf("/op/%s/func/%s", backend, funcName)).Result() + if err != nil { + sig, err = rdb.Get(ctx, "/src/func/"+funcName).Result() + if err != nil { + msg := fmt.Sprintf("func %s not found in /op/%s/func/ or /src/func/", funcName, backend) + log.Printf("[%s] CALL error: %s", vtid, msg) + state.SetError(ctx, rdb, vtid, pc, msg) + return pc + } + } + + // 3. 解析签名 → 形参列表 + formalParams := parseSignature(sig) + + // 4. 建立形参→实参映射 + bindings := make(map[string]string) + for i, param := range formalParams.Reads { + if i+1 < len(inst.Reads) { + bindings[param] = inst.Reads[i+1] + } + } + for i, param := range formalParams.Writes { + if i < len(inst.Writes) { + bindings[param] = inst.Writes[i] + } + } + + // 5. 批量 MGET 编译层所有指令 + compiled := mgetAll(ctx, rdb, fmt.Sprintf("/op/%s/func/%s", backend, funcName)) + if len(compiled) == 0 { + compiled = mgetAll(ctx, rdb, "/src/func/"+funcName) + } + + // 6. 逐条翻译 → Pipeline 批量写入子栈 + substackRoot := fmt.Sprintf("/vthread/%s/%s/", vtid, pc) + pipe := rdb.Pipeline() + + bodyCount := len(compiled) + for i, dxlangLine := range compiled { + parsed, err := ir.ParseDxlang(dxlangLine) + if err != nil { + msg := fmt.Sprintf("parse error at body[%d]: %v", i, err) + log.Printf("[%s] CALL translate error: %s", vtid, msg) + state.SetError(ctx, rdb, vtid, pc, msg) + return pc + } + + replaceParams(parsed.Reads, bindings) + replaceParams(parsed.Writes, bindings) + + pipe.Set(ctx, fmt.Sprintf("%s[%d,0]", substackRoot, i), parsed.Opcode, 0) + for j, r := range parsed.Reads { + pipe.Set(ctx, fmt.Sprintf("%s[%d,-%d]", substackRoot, i, j+1), r, 0) + } + for j, w := range parsed.Writes { + pipe.Set(ctx, fmt.Sprintf("%s[%d,%d]", substackRoot, i, j+1), w, 0) + } + } + + // 7. 
追加隐式 return 指令 (将最后一个输出形参的值回传父栈) + if len(formalParams.Writes) > 0 { + retIdx := bodyCount + retSlot := formalParams.Writes[0] // e.g., "C" + retRef := retSlot + // 绝对路径直接使用, 形参名称加 ./前缀 (相对 vthread 空间) + if !strings.HasPrefix(retSlot, "/") { + retRef = "./" + retSlot + } + pipe.Set(ctx, fmt.Sprintf("%s[%d,0]", substackRoot, retIdx), "return", 0) + pipe.Set(ctx, fmt.Sprintf("%s[%d,-1]", substackRoot, retIdx), retRef, 0) + } + + _, err = pipe.Exec(ctx) + if err != nil { + msg := fmt.Sprintf("CALL translate pipeline failed: %v", err) + log.Printf("[%s] CALL error: %s", vtid, msg) + state.SetError(ctx, rdb, vtid, pc, msg) + return pc + } + + // 8. 返回子栈第一条指令的 PC + return pc + "/[0,0]" +} + +// HandleReturn 处理 RETURN 指令。 +// 返回父栈 CALL 指令的下一条 PC。 +func HandleReturn(ctx context.Context, rdb *redis.Client, vtid string, pc string) string { + lastSlash := strings.LastIndex(pc, "/") + if lastSlash < 0 { + return pc // 根栈 return → vthread 即将 done + } + + parentPC := pc[:lastSlash] + + // 1. 读取返回值写入父 CALL 的返回值槽位 + inst, err := ir.Decode(ctx, rdb, vtid, pc) + if err == nil { + parentInst, err := ir.Decode(ctx, rdb, vtid, parentPC) + if err == nil && len(parentInst.Writes) > 0 && len(inst.Reads) > 0 { + retSlot := parentInst.Writes[0] // e.g., "./c" + retRef := inst.Reads[0] // e.g., "./C" (slot reference) + // 解析 retRef: 如果是相对路径则读取其实际值 + retVal := retRef + if strings.HasPrefix(retRef, "./") { + srcKey := "/vthread/" + vtid + "/" + retRef[2:] + if v, e := rdb.Get(ctx, srcKey).Result(); e == nil { + retVal = v + } + } + if strings.HasPrefix(retSlot, "./") { + slotKey := "/vthread/" + vtid + "/" + retSlot[2:] + rdb.Set(ctx, slotKey, retVal, 0) + } + } + } + + // 2. 删除当前子栈 + keys, err := rdb.Keys(ctx, "/vthread/"+vtid+"/"+parentPC+"/*").Result() + if err != nil { + log.Printf("[%s] RETURN KEYS error: %v", vtid, err) + } else if len(keys) > 0 { + if err := rdb.Del(ctx, keys...).Err(); err != nil { + log.Printf("[%s] RETURN DEL error: %v", vtid, err) + } + } + + // 3. 
PC 恢复到父栈 CALL 的下一条 + return ir.NextPC(parentPC) +} + +// FormalParams 函数形参列表 +type FormalParams struct { + Reads []string + Writes []string +} + +// parseSignature 解析函数签名 +// +// "def add_test(A:int, B:int) -> (C:int)" → Reads:["A","B"], Writes:["C"] +// "(add_test(A, B) -> (C))" → Reads:["A","B"], Writes:["C"] (legacy) +func parseSignature(sig string) FormalParams { + var fp FormalParams + sig = strings.TrimSpace(sig) + + // strip "def " prefix (new format) + if strings.HasPrefix(sig, "def ") { + sig = strings.TrimSpace(sig[4:]) + } + + // strip outer parens (legacy format) + if len(sig) >= 2 && sig[0] == '(' && sig[len(sig)-1] == ')' { + sig = sig[1 : len(sig)-1] + } + + arrow := strings.Index(sig, "->") + if arrow < 0 { + return fp + } + + left := strings.TrimSpace(sig[:arrow]) + right := strings.TrimSpace(sig[arrow+2:]) + + if lp := strings.Index(left, "("); lp >= 0 { + rp := strings.LastIndex(left, ")") + if rp > lp { + fp.Reads = extractParamNames(left[lp+1 : rp]) + } + } + + right = strings.TrimSpace(right) + if len(right) >= 2 && right[0] == '(' && right[len(right)-1] == ')' { + fp.Writes = extractParamNames(right[1 : len(right)-1]) + } else { + fp.Writes = extractParamNames(right) + } + + return fp +} + +// extractParamNames 从 "A:tensor, B:tensor, alpha:f32" 提取 ["A", "B", "alpha"] +func extractParamNames(s string) []string { + var names []string + for _, p := range strings.Split(s, ",") { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if colon := strings.Index(p, ":"); colon >= 0 { + p = p[:colon] + } + p = strings.TrimSpace(p) + if p != "" { + names = append(names, p) + } + } + return names +} + +// mgetAll 批量读取指定 base 路径下的所有编译层/源码层指令。 +// base 格式: "/op/op-metal/func/gemm" 或 "/src/func/gemm" +// KEYS 返回顺序不确定,按数字后缀排序以保证指令顺序。 +func mgetAll(ctx context.Context, rdb *redis.Client, base string) []string { + keys, err := rdb.Keys(ctx, base+"/*").Result() + if err != nil { + log.Printf("mgetAll KEYS error for %s: %v", base, err) + return nil + } + if 
len(keys) == 0 { + return nil + } + + type indexedKey struct { + key string + index int + } + var sorted []indexedKey + basePrefix := base + "/" + for _, k := range keys { + if !strings.HasPrefix(k, basePrefix) { + continue + } + suffix := k[len(basePrefix):] + n, err := strconv.Atoi(suffix) + if err != nil { + log.Printf("mgetAll skip non-numeric key: %s", k) + continue + } + sorted = append(sorted, indexedKey{key: k, index: n}) + } + sort.Slice(sorted, func(i, j int) bool { return sorted[i].index < sorted[j].index }) + + orderedKeys := make([]string, len(sorted)) + for i, sk := range sorted { + orderedKeys[i] = sk.key + } + + if len(orderedKeys) == 0 { + return nil + } + + vals, err := rdb.MGet(ctx, orderedKeys...).Result() + if err != nil { + log.Printf("mgetAll MGET error for %s: %v", base, err) + return nil + } + + result := make([]string, 0, len(vals)) + for _, v := range vals { + if s, ok := v.(string); ok { + result = append(result, s) + } + } + return result +} + +// replaceParams 将形参替换为实参 (原地修改)。 +func replaceParams(params []string, bindings map[string]string) { + for i, p := range params { + if v, ok := bindings[p]; ok { + params[i] = v + } + } +} diff --git a/executor/vm/ir_test.go b/executor/vm/ir_test.go new file mode 100644 index 00000000..7de83455 --- /dev/null +++ b/executor/vm/ir_test.go @@ -0,0 +1,101 @@ +package vm_test + +import ( + "context" + "testing" + + "deepx/executor/vm/internal/ir" + "deepx/executor/vm/internal/route" + "github.com/redis/go-redis/v9" +) + +// ── PC navigation (salvaged from deleted engine_test.go) ── + +func TestNextPC(t *testing.T) { + tests := []struct { + pc string + want string + }{ + {"[0,0]", "[1,0]"}, + {"[3,0]", "[4,0]"}, + {"[0,0]/[0,0]", "[0,0]/[1,0]"}, + {"[2,0]/[3,0]", "[2,0]/[4,0]"}, + } + + for _, tc := range tests { + if got := ir.NextPC(tc.pc); got != tc.want { + t.Errorf("NextPC(%q) = %q, want %q", tc.pc, got, tc.want) + } + } +} + +func TestParentPC(t *testing.T) { + tests := []struct { + pc string + 
want string + }{ + {"[2,0]/[1,0]", "[3,0]"}, + {"[0,0]/[5,0]", "[1,0]"}, + {"[0,0]/[3,0]/[2,0]", "[0,0]/[4,0]"}, + } + + for _, tc := range tests { + if got := ir.ParentPC(tc.pc); got != tc.want { + t.Errorf("ParentPC(%q) = %q, want %q", tc.pc, got, tc.want) + } + } +} + +func TestIsComputeOp(t *testing.T) { + compute := []string{"add", "sub", "mul", "div", "matmul", "relu", "sigmoid", "tanh"} + control := []string{"call", "return", "if", "for"} + lifecycle := []string{"newtensor", "deltensor", "clonetensor"} + + for _, op := range compute { + if !ir.IsComputeOp(op) { + t.Errorf("IsComputeOp(%q) = false, want true", op) + } + } + for _, op := range control { + if ir.IsComputeOp(op) { + t.Errorf("IsComputeOp(%q) = true, want false", op) + } + } + for _, op := range lifecycle { + if ir.IsLifecycleOp(op) && ir.IsComputeOp(op) { + t.Errorf("IsLifecycleOp(%q) should not also be IsComputeOp", op) + } + } +} + +func TestDecodeFromCache(t *testing.T) { + cache := map[string]string{ + "[3,0]": "add", + "[3,-1]": "./a", + "[3,-2]": "./b", + "[3,1]": "./c", + } + + inst := ir.DecodeFromCache(cache, "[3,0]") + if inst.Opcode != "add" { + t.Errorf("opcode = %q, want 'add'", inst.Opcode) + } + if len(inst.Reads) != 2 || inst.Reads[0] != "./a" || inst.Reads[1] != "./b" { + t.Errorf("reads = %v, want [./a ./b]", inst.Reads) + } + if len(inst.Writes) != 1 || inst.Writes[0] != "./c" { + t.Errorf("writes = %v, want [./c]", inst.Writes) + } +} + +// ── Route: error handling (no live Redis needed) ── + +func TestRouteSelect_NoRedis(t *testing.T) { + rdb := redis.NewClient(&redis.Options{Addr: "127.0.0.1:9999"}) + ctx := context.Background() + _, err := route.Select(ctx, rdb, "add") + if err == nil { + t.Error("expected error when Redis is not available") + } + t.Logf("expected error: %v", err) +} diff --git a/executor/vm/parse_dx_test.go b/executor/vm/parse_dx_test.go new file mode 100644 index 00000000..a6d07837 --- /dev/null +++ b/executor/vm/parse_dx_test.go @@ -0,0 +1,232 @@ 
+package vm_test + +import ( + "testing" + + "deepx/executor/vm/internal/ir" + "deepx/executor/vm/testutil" +) + +// wantInst 定义期望的指令结构 +type wantInst struct { + op string + reads []string + writes []string +} + +// verifyInst 验证一条指令的解析结果 +func verifyInst(t *testing.T, dxFile string, lineIdx int, inst *ir.Instruction, want wantInst) { + t.Helper() + if inst.Opcode != want.op { + t.Errorf("[%s] line[%d] opcode=%s, want %s", dxFile, lineIdx, inst.Opcode, want.op) + } + if len(inst.Reads) != len(want.reads) { + t.Errorf("[%s] line[%d] reads len=%d, want %d (%v vs %v)", dxFile, lineIdx, len(inst.Reads), len(want.reads), inst.Reads, want.reads) + return + } + for i := range inst.Reads { + if inst.Reads[i] != want.reads[i] { + t.Errorf("[%s] line[%d] reads[%d]=%s, want %s", dxFile, lineIdx, i, inst.Reads[i], want.reads[i]) + } + } + if len(inst.Writes) != len(want.writes) { + t.Errorf("[%s] line[%d] writes len=%d, want %d (%v vs %v)", dxFile, lineIdx, len(inst.Writes), len(want.writes), inst.Writes, want.writes) + return + } + for i := range inst.Writes { + if inst.Writes[i] != want.writes[i] { + t.Errorf("[%s] line[%d] writes[%d]=%s, want %s", dxFile, lineIdx, i, inst.Writes[i], want.writes[i]) + } + } +} + +// checkDx 加载 .dx 文件并逐行验证解析结果 +func checkDx(t *testing.T, dxFile string, wants []wantInst) { + t.Helper() + fn, err := testutil.LoadDxFile(dxFile) + if err != nil { + t.Fatalf("LoadDxFile(%s): %v", dxFile, err) + } + if len(fn.Body) != len(wants) { + t.Fatalf("[%s] body has %d lines, want %d:\n got: %v\n want: %v", dxFile, len(fn.Body), len(wants), fn.Body, wants) + } + for i, w := range wants { + inst, err := ir.ParseDxlang(fn.Body[i]) + if err != nil { + t.Errorf("[%s] line[%d] parse error: %v", dxFile, i, err) + continue + } + verifyInst(t, dxFile, i, inst, w) + } +} + +// ── Lifecycle ─────────────────────────────────────────────── + +func TestParse_Lifecycle(t *testing.T) { + t.Run("newtensor", func(t *testing.T) { + checkDx(t, 
"../../example/dxlang/tensor/lifecycle/newtensor.dx", []wantInst{ + {op: "newtensor", reads: []string{"f32", "[16]"}, writes: []string{"/data/x"}}, + }) + }) + t.Run("del", func(t *testing.T) { + checkDx(t, "../../example/dxlang/tensor/lifecycle/del.dx", []wantInst{ + {op: "newtensor", reads: []string{"f32", "[8]"}, writes: []string{"/data/tmp"}}, + {op: "deltensor", reads: []string{"/data/tmp"}, writes: nil}, + }) + }) + t.Run("compute_small", func(t *testing.T) { + checkDx(t, "../../example/dxlang/tensor/lifecycle/compute.dx", []wantInst{ + {op: "newtensor", reads: []string{"f32", "[8]"}, writes: []string{"/data/a"}}, + {op: "newtensor", reads: []string{"f32", "[8]"}, writes: []string{"/data/b"}}, + {op: "newtensor", reads: []string{"f32", "[8]"}, writes: []string{"/data/c"}}, + {op: "zeros", reads: nil, writes: []string{"/data/a"}}, + {op: "zeros", reads: nil, writes: []string{"/data/b"}}, + {op: "add", reads: []string{"/data/a", "/data/b"}, writes: []string{"/data/c"}}, + {op: "deltensor", reads: []string{"/data/a"}, writes: nil}, + {op: "deltensor", reads: []string{"/data/b"}, writes: nil}, + }) + }) + +} + +// ── Call / Function Nesting ───────────────────────────────── + +func TestParse_Call(t *testing.T) { + t.Run("add_test", func(t *testing.T) { + checkDx(t, "../../example/dxlang/builtin/call/add_test.dx", []wantInst{ + {op: "add", reads: []string{"A", "B"}, writes: []string{"./C"}}, + }) + }) + t.Run("callee", func(t *testing.T) { + checkDx(t, "../../example/dxlang/builtin/call/callee.dx", []wantInst{ + {op: "+", reads: []string{"X", "Y"}, writes: []string{"./Z"}}, + }) + }) + t.Run("caller", func(t *testing.T) { + checkDx(t, "../../example/dxlang/builtin/call/caller.dx", []wantInst{ + {op: "callee", reads: []string{"A", "B"}, writes: []string{"./C"}}, + }) + }) + t.Run("middle", func(t *testing.T) { + checkDx(t, "../../example/dxlang/builtin/call/middle.dx", []wantInst{ + {op: "leaf", reads: []string{"X"}, writes: []string{"./tmp"}}, + {op: "+", reads: 
[]string{"./tmp", "1"}, writes: []string{"./Y"}}, + }) + }) + t.Run("deep3", func(t *testing.T) { + checkDx(t, "../../example/dxlang/builtin/call/deep3.dx", []wantInst{ + {op: "middle", reads: []string{"X"}, writes: []string{"./Y"}}, + }) + }) + t.Run("diamond", func(t *testing.T) { + checkDx(t, "../../example/dxlang/builtin/call/diamond.dx", []wantInst{ + {op: "double", reads: []string{"A"}, writes: []string{"./d"}}, + {op: "triple", reads: []string{"A"}, writes: []string{"./t"}}, + {op: "+", reads: []string{"./d", "./t"}, writes: []string{"./R"}}, + }) + }) + t.Run("double", func(t *testing.T) { + checkDx(t, "../../example/dxlang/builtin/call/double.dx", []wantInst{ + {op: "*", reads: []string{"X", "2"}, writes: []string{"./Y"}}, + }) + }) + t.Run("triple", func(t *testing.T) { + checkDx(t, "../../example/dxlang/builtin/call/triple.dx", []wantInst{ + {op: "*", reads: []string{"X", "3"}, writes: []string{"./Y"}}, + }) + }) +} + +// ── Native: Arithmetic ────────────────────────────────────── + +func TestParse_NativeArith(t *testing.T) { + tests := []struct { + name string + file string + op string + }{ + {"add", "../../example/dxlang/builtin/native/arith/add.dx", "+"}, + {"sub", "../../example/dxlang/builtin/native/arith/sub.dx", "-"}, + {"mul", "../../example/dxlang/builtin/native/arith/mul.dx", "*"}, + {"div", "../../example/dxlang/builtin/native/arith/div.dx", "/"}, + {"neg", "../../example/dxlang/builtin/native/arith/neg.dx", "neg"}, + {"abs", "../../example/dxlang/builtin/native/arith/abs.dx", "abs"}, + {"sign", "../../example/dxlang/builtin/native/arith/sign.dx", "sign"}, + {"pow", "../../example/dxlang/builtin/native/arith/pow.dx", "pow"}, + {"exp", "../../example/dxlang/builtin/native/arith/exp.dx", "exp"}, + {"log", "../../example/dxlang/builtin/native/arith/log.dx", "log"}, + {"sqrt", "../../example/dxlang/builtin/native/arith/sqrt.dx", "sqrt"}, + {"max", "../../example/dxlang/builtin/native/arith/max.dx", "max"}, + {"min", 
"../../example/dxlang/builtin/native/arith/min.dx", "min"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, err := testutil.LoadDxFile(tt.file) + if err != nil { + t.Fatal(err) + } + if len(fn.Body) == 0 { + t.Fatal("empty body") + } + inst, err := ir.ParseDxlang(fn.Body[0]) + if err != nil { + t.Fatal(err) + } + if inst.Opcode != tt.op { + t.Errorf("opcode=%s, want %s", inst.Opcode, tt.op) + } + }) + } +} + +// ── Native: Compare / Logic / Cast / Chain ────────────────── + +func TestParse_NativeOther(t *testing.T) { + t.Run("compare", func(t *testing.T) { + t.Run("eq", func(t *testing.T) { + checkDx(t, "../../example/dxlang/builtin/native/compare/eq.dx", []wantInst{ + {op: "==", reads: []string{"A", "B"}, writes: []string{"./C"}}, + }) + }) + t.Run("lt", func(t *testing.T) { + checkDx(t, "../../example/dxlang/builtin/native/compare/lt.dx", []wantInst{ + {op: "<", reads: []string{"A", "B"}, writes: []string{"./C"}}, + }) + }) + }) + t.Run("logic", func(t *testing.T) { + t.Run("and", func(t *testing.T) { + checkDx(t, "../../example/dxlang/builtin/native/logic/and.dx", []wantInst{ + {op: "&&", reads: []string{"A", "B"}, writes: []string{"./C"}}, + }) + }) + t.Run("not", func(t *testing.T) { + checkDx(t, "../../example/dxlang/builtin/native/logic/not.dx", []wantInst{ + {op: "!", reads: []string{"A"}, writes: []string{"./C"}}, + }) + }) + t.Run("bool", func(t *testing.T) { + checkDx(t, "../../example/dxlang/builtin/native/logic/bool.dx", []wantInst{ + {op: "bool", reads: []string{"A"}, writes: []string{"./C"}}, + }) + }) + }) + t.Run("cast", func(t *testing.T) { + t.Run("int", func(t *testing.T) { + checkDx(t, "../../example/dxlang/builtin/native/cast/int.dx", []wantInst{ + {op: "int", reads: []string{"A"}, writes: []string{"./C"}}, + }) + }) + t.Run("float", func(t *testing.T) { + checkDx(t, "../../example/dxlang/builtin/native/cast/float.dx", []wantInst{ + {op: "float", reads: []string{"A"}, writes: []string{"./C"}}, + }) + }) + }) + 
t.Run("chain", func(t *testing.T) { + checkDx(t, "../../example/dxlang/builtin/native/chain/chain.dx", []wantInst{ + {op: "+", reads: []string{"A", "B"}, writes: []string{"./tmp"}}, + {op: "*", reads: []string{"./tmp", "C"}, writes: []string{"./D"}}, + }) + }) +} diff --git a/executor/vm/testdata/call/add_test.dx b/executor/vm/testdata/call/add_test.dx new file mode 100644 index 00000000..6386675e --- /dev/null +++ b/executor/vm/testdata/call/add_test.dx @@ -0,0 +1,4 @@ +# add_test: basic element-wise addition +def add_test(A:tensor, B:tensor) -> (C:tensor) { + add(A, B) -> ./C +} diff --git a/executor/vm/testdata/call/callee.dx b/executor/vm/testdata/call/callee.dx new file mode 100644 index 00000000..59b98e95 --- /dev/null +++ b/executor/vm/testdata/call/callee.dx @@ -0,0 +1,4 @@ +# callee: a simple function called by caller +def callee(X:int, Y:int) -> (Z:int) { + X + Y -> ./Z +} diff --git a/executor/vm/testdata/call/caller.dx b/executor/vm/testdata/call/caller.dx new file mode 100644 index 00000000..da161395 --- /dev/null +++ b/executor/vm/testdata/call/caller.dx @@ -0,0 +1,4 @@ +# caller: calls callee to add two numbers +def caller(A:int, B:int) -> (C:int) { + callee(A, B) -> ./C +} diff --git a/executor/vm/testdata/call/cstyle_call.dx b/executor/vm/testdata/call/cstyle_call.dx new file mode 100644 index 00000000..db1a9af0 --- /dev/null +++ b/executor/vm/testdata/call/cstyle_call.dx @@ -0,0 +1,4 @@ +# cstyle_call: C-style assignment with function call +def cstyle_call(A:int, B:int) -> (C:int) { + ./C <- add(A, B) +} diff --git a/executor/vm/testdata/call/deep3.dx b/executor/vm/testdata/call/deep3.dx new file mode 100644 index 00000000..c5bcc7c0 --- /dev/null +++ b/executor/vm/testdata/call/deep3.dx @@ -0,0 +1,4 @@ +# deep3 calls middle, middle calls leaf — 3 levels deep +def deep3(X:int) -> (Y:int) { + middle(X) -> ./Y +} diff --git a/executor/vm/testdata/call/diamond.dx b/executor/vm/testdata/call/diamond.dx new file mode 100644 index 00000000..2573f868 --- 
/dev/null +++ b/executor/vm/testdata/call/diamond.dx @@ -0,0 +1,6 @@ +# diamond: splits into double+triple, then sums results +def diamond(A:int) -> (R:int) { + double(A) -> ./d + triple(A) -> ./t + ./d + ./t -> ./R +} diff --git a/executor/vm/testdata/call/double.dx b/executor/vm/testdata/call/double.dx new file mode 100644 index 00000000..f525d9e7 --- /dev/null +++ b/executor/vm/testdata/call/double.dx @@ -0,0 +1,4 @@ +# double: multiply by 2 +def double(X:int) -> (Y:int) { + X * 2 -> ./Y +} diff --git a/executor/vm/testdata/call/leaf.dx b/executor/vm/testdata/call/leaf.dx new file mode 100644 index 00000000..7cb954ad --- /dev/null +++ b/executor/vm/testdata/call/leaf.dx @@ -0,0 +1,4 @@ +# leaf: multiplies input by 2 +def leaf(X:int) -> (Y:int) { + X * 2 -> ./Y +} diff --git a/executor/vm/testdata/call/middle.dx b/executor/vm/testdata/call/middle.dx new file mode 100644 index 00000000..e9d74fe2 --- /dev/null +++ b/executor/vm/testdata/call/middle.dx @@ -0,0 +1,5 @@ +# middle: calls leaf then adds 1 +def middle(X:int) -> (Y:int) { + leaf(X) -> ./tmp + ./tmp + 1 -> ./Y +} diff --git a/executor/vm/testdata/call/triple.dx b/executor/vm/testdata/call/triple.dx new file mode 100644 index 00000000..09d2d947 --- /dev/null +++ b/executor/vm/testdata/call/triple.dx @@ -0,0 +1,4 @@ +# triple: multiply by 3 +def triple(X:int) -> (Y:int) { + X * 3 -> ./Y +} diff --git a/executor/vm/testdata/lifecycle/compute.dx b/executor/vm/testdata/lifecycle/compute.dx new file mode 100644 index 00000000..39232691 --- /dev/null +++ b/executor/vm/testdata/lifecycle/compute.dx @@ -0,0 +1,9 @@ +# compute: create 2 tensors, compute their sum into a 3rd, cleanup inputs +def compute() -> (/data/c) { + newtensor("f32", "[8]") -> /data/a + newtensor("f32", "[8]") -> /data/b + newtensor("f32", "[8]") -> /data/c + add(/data/a, /data/b) -> /data/c + deltensor(/data/a) + deltensor(/data/b) +} diff --git a/executor/vm/testdata/lifecycle/del.dx b/executor/vm/testdata/lifecycle/del.dx new file mode 100644 
index 00000000..1684e1de --- /dev/null +++ b/executor/vm/testdata/lifecycle/del.dx @@ -0,0 +1,5 @@ +# lifecycle_del: create then delete a heap tensor +def lifecycle_del() -> () { + newtensor("f32", "[8]") -> /data/tmp + deltensor(/data/tmp) +} diff --git a/executor/vm/testdata/lifecycle/full.dx b/executor/vm/testdata/lifecycle/full.dx new file mode 100644 index 00000000..f2f6b856 --- /dev/null +++ b/executor/vm/testdata/lifecycle/full.dx @@ -0,0 +1,9 @@ +# lifecycle_full: create tensors, compute, then cleanup +def lifecycle_full() -> (/data/c) { + newtensor("f32", "[4]") -> /data/a + newtensor("f32", "[4]") -> /data/b + newtensor("f32", "[4]") -> /data/c + add(/data/a, /data/b) -> /data/c + deltensor(/data/a) + deltensor(/data/b) +} diff --git a/executor/vm/testdata/lifecycle/newtensor.dx b/executor/vm/testdata/lifecycle/newtensor.dx new file mode 100644 index 00000000..c2edde7b --- /dev/null +++ b/executor/vm/testdata/lifecycle/newtensor.dx @@ -0,0 +1,4 @@ +# lifecycle_newtensor: create a heap tensor and store its reference +def lifecycle_newtensor() -> (/data/x) { + newtensor("f32", "[16]") -> /data/x +} diff --git a/executor/vm/testdata/native/arith/abs.dx b/executor/vm/testdata/native/arith/abs.dx new file mode 100644 index 00000000..c23e0c9f --- /dev/null +++ b/executor/vm/testdata/native/arith/abs.dx @@ -0,0 +1,3 @@ +def native_abs(A:int) -> (C:int) { + abs(A) -> ./C +} diff --git a/executor/vm/testdata/native/arith/add.dx b/executor/vm/testdata/native/arith/add.dx new file mode 100644 index 00000000..3f29f9de --- /dev/null +++ b/executor/vm/testdata/native/arith/add.dx @@ -0,0 +1,3 @@ +def native_arith(A:int, B:int) -> (C:int) { + A + B -> ./C +} diff --git a/executor/vm/testdata/native/arith/cstyle_add.dx b/executor/vm/testdata/native/arith/cstyle_add.dx new file mode 100644 index 00000000..d05a2c88 --- /dev/null +++ b/executor/vm/testdata/native/arith/cstyle_add.dx @@ -0,0 +1,4 @@ +# cstyle_add: C-style assignment with infix operator +def cstyle_add(A:int, 
B:int) -> (C:int) { + ./C <- A + B +} diff --git a/executor/vm/testdata/native/arith/div.dx b/executor/vm/testdata/native/arith/div.dx new file mode 100644 index 00000000..6b5b6b9e --- /dev/null +++ b/executor/vm/testdata/native/arith/div.dx @@ -0,0 +1,3 @@ +def native_div(A:int, B:int) -> (C:float) { + A / B -> ./C +} diff --git a/executor/vm/testdata/native/arith/exp.dx b/executor/vm/testdata/native/arith/exp.dx new file mode 100644 index 00000000..fa1ef8ab --- /dev/null +++ b/executor/vm/testdata/native/arith/exp.dx @@ -0,0 +1,3 @@ +def native_exp(A:int) -> (C:float) { + exp(A) -> ./C +} diff --git a/executor/vm/testdata/native/arith/log.dx b/executor/vm/testdata/native/arith/log.dx new file mode 100644 index 00000000..dc949799 --- /dev/null +++ b/executor/vm/testdata/native/arith/log.dx @@ -0,0 +1,3 @@ +def native_log(A:int) -> (C:float) { + log(A) -> ./C +} diff --git a/executor/vm/testdata/native/arith/max.dx b/executor/vm/testdata/native/arith/max.dx new file mode 100644 index 00000000..22d0429e --- /dev/null +++ b/executor/vm/testdata/native/arith/max.dx @@ -0,0 +1,3 @@ +def native_max(A:int, B:int) -> (C:int) { + max(A, B) -> ./C +} diff --git a/executor/vm/testdata/native/arith/min.dx b/executor/vm/testdata/native/arith/min.dx new file mode 100644 index 00000000..f5b4d5a2 --- /dev/null +++ b/executor/vm/testdata/native/arith/min.dx @@ -0,0 +1,3 @@ +def native_min(A:int, B:int) -> (C:int) { + min(A, B) -> ./C +} diff --git a/executor/vm/testdata/native/arith/mul.dx b/executor/vm/testdata/native/arith/mul.dx new file mode 100644 index 00000000..c9e69e72 --- /dev/null +++ b/executor/vm/testdata/native/arith/mul.dx @@ -0,0 +1,3 @@ +def native_mul(A:int, B:int) -> (C:int) { + A * B -> ./C +} diff --git a/executor/vm/testdata/native/arith/neg.dx b/executor/vm/testdata/native/arith/neg.dx new file mode 100644 index 00000000..0ce68dc7 --- /dev/null +++ b/executor/vm/testdata/native/arith/neg.dx @@ -0,0 +1,3 @@ +def native_neg(A:int) -> (C:int) { + neg(A) -> ./C 
+} diff --git a/executor/vm/testdata/native/arith/pow.dx b/executor/vm/testdata/native/arith/pow.dx new file mode 100644 index 00000000..74eb0f30 --- /dev/null +++ b/executor/vm/testdata/native/arith/pow.dx @@ -0,0 +1,3 @@ +def native_pow(A:int, B:int) -> (C:float) { + pow(A, B) -> ./C +} diff --git a/executor/vm/testdata/native/arith/sign.dx b/executor/vm/testdata/native/arith/sign.dx new file mode 100644 index 00000000..95120dfb --- /dev/null +++ b/executor/vm/testdata/native/arith/sign.dx @@ -0,0 +1,3 @@ +def native_sign(A:int) -> (C:int) { + sign(A) -> ./C +} diff --git a/executor/vm/testdata/native/arith/sqrt.dx b/executor/vm/testdata/native/arith/sqrt.dx new file mode 100644 index 00000000..2a488966 --- /dev/null +++ b/executor/vm/testdata/native/arith/sqrt.dx @@ -0,0 +1,3 @@ +def native_sqrt(A:int) -> (C:float) { + sqrt(A) -> ./C +} diff --git a/executor/vm/testdata/native/arith/sub.dx b/executor/vm/testdata/native/arith/sub.dx new file mode 100644 index 00000000..a24e1634 --- /dev/null +++ b/executor/vm/testdata/native/arith/sub.dx @@ -0,0 +1,3 @@ +def native_sub(A:int, B:int) -> (C:int) { + A - B -> ./C +} diff --git a/executor/vm/testdata/native/cast/float.dx b/executor/vm/testdata/native/cast/float.dx new file mode 100644 index 00000000..0834c630 --- /dev/null +++ b/executor/vm/testdata/native/cast/float.dx @@ -0,0 +1,3 @@ +def native_float(A:int) -> (C:float) { + float(A) -> ./C +} diff --git a/executor/vm/testdata/native/cast/int.dx b/executor/vm/testdata/native/cast/int.dx new file mode 100644 index 00000000..6fb7a496 --- /dev/null +++ b/executor/vm/testdata/native/cast/int.dx @@ -0,0 +1,3 @@ +def native_int(A:float) -> (C:int) { + int(A) -> ./C +} diff --git a/executor/vm/testdata/native/chain/chain.dx b/executor/vm/testdata/native/chain/chain.dx new file mode 100644 index 00000000..d30c1cfa --- /dev/null +++ b/executor/vm/testdata/native/chain/chain.dx @@ -0,0 +1,4 @@ +def native_chain(A:int, B:int, C:int) -> (D:int) { + A + B -> ./tmp + ./tmp * C 
-> ./D +} diff --git a/executor/vm/testdata/native/compare/eq.dx b/executor/vm/testdata/native/compare/eq.dx new file mode 100644 index 00000000..5001e26c --- /dev/null +++ b/executor/vm/testdata/native/compare/eq.dx @@ -0,0 +1,3 @@ +def native_eq(A:int, B:int) -> (C:bool) { + A == B -> ./C +} diff --git a/executor/vm/testdata/native/compare/lt.dx b/executor/vm/testdata/native/compare/lt.dx new file mode 100644 index 00000000..d8dda666 --- /dev/null +++ b/executor/vm/testdata/native/compare/lt.dx @@ -0,0 +1,3 @@ +def native_lt(A:int, B:int) -> (C:bool) { + A < B -> ./C +} diff --git a/executor/vm/testdata/native/logic/and.dx b/executor/vm/testdata/native/logic/and.dx new file mode 100644 index 00000000..6fce519d --- /dev/null +++ b/executor/vm/testdata/native/logic/and.dx @@ -0,0 +1,3 @@ +def native_and(A:bool, B:bool) -> (C:bool) { + A && B -> ./C +} diff --git a/executor/vm/testdata/native/logic/bool.dx b/executor/vm/testdata/native/logic/bool.dx new file mode 100644 index 00000000..303d1e49 --- /dev/null +++ b/executor/vm/testdata/native/logic/bool.dx @@ -0,0 +1,3 @@ +def native_bool(A:int) -> (C:bool) { + bool(A) -> ./C +} diff --git a/executor/vm/testdata/native/logic/not.dx b/executor/vm/testdata/native/logic/not.dx new file mode 100644 index 00000000..f25c2976 --- /dev/null +++ b/executor/vm/testdata/native/logic/not.dx @@ -0,0 +1,3 @@ +def native_not(A:bool) -> (C:bool) { + !A -> ./C +} diff --git a/executor/vm/testhelpers_test.go b/executor/vm/testhelpers_test.go new file mode 100644 index 00000000..617a18ab --- /dev/null +++ b/executor/vm/testhelpers_test.go @@ -0,0 +1,77 @@ +package vm_test + +import ( + "context" + "encoding/json" + "os" + "testing" + "time" + + "github.com/redis/go-redis/v9" +) + +// connectRedisIntegration connects to Redis for integration tests. +// Uses REDIS_ADDR env or defaults to 127.0.0.1:16379.
+func connectRedisIntegration(t *testing.T) (*redis.Client, context.Context) { + t.Helper() + addr := os.Getenv("REDIS_ADDR") + if addr == "" { + addr = "127.0.0.1:16379" + } + ctx := context.Background() + rdb := redis.NewClient(&redis.Options{Addr: addr, PoolSize: 10, MinIdleConns: 2}) + if err := rdb.Ping(ctx).Err(); err != nil { + t.Fatalf("Redis not available at %s: %v (set REDIS_ADDR or start Redis)", addr, err) + } + rdb.FlushDB(ctx) + return rdb, ctx +} + +// waitVthreadDone polls the vthread state until it reaches "done" or "error". +// Returns named slot values on success. +func waitVthreadDone(t *testing.T, rdb *redis.Client, vtid string, timeout time.Duration) (map[string]string, bool) { + t.Helper() + ctx := context.Background() + ticker := time.NewTicker(30 * time.Millisecond) + defer ticker.Stop() + deadline := time.Now().Add(timeout) + + for time.Now().Before(deadline) { + <-ticker.C + val, err := rdb.Get(ctx, "/vthread/"+vtid).Result() + if err == redis.Nil { + continue + } + if err != nil { + continue + } + var s struct { + Status string `json:"status"` + PC string `json:"pc"` + Error map[string]string `json:"error,omitempty"` + } + json.Unmarshal([]byte(val), &s) + + switch s.Status { + case "done": + // read named slots + keys, _ := rdb.Keys(ctx, "/vthread/"+vtid+"/*").Result() + outputs := make(map[string]string) + prefix := "/vthread/" + vtid + "/" + for _, k := range keys { + if v, err := rdb.Get(ctx, k).Result(); err == nil { + slot := k[len(prefix):] + if len(slot) > 0 && slot[0] != '[' { + outputs[slot] = v + } + } + } + return outputs, true + case "error": + t.Logf("vtid=%s error: %v", vtid, s.Error) + return nil, false + } + } + t.Logf("vtid=%s timeout after %v", vtid, timeout) + return nil, false +} \ No newline at end of file diff --git a/executor/vm/testutil/dxloader.go b/executor/vm/testutil/dxloader.go new file mode 100644 index 00000000..b320ffdf --- /dev/null +++ b/executor/vm/testutil/dxloader.go @@ -0,0 +1,348 @@ +// Package 
testutil provides helpers for VM integration testing, +// including loading .dx function source files and registering them in Redis. +package testutil + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "os" + "strings" + "time" + + "github.com/redis/go-redis/v9" +) + +// DxFunc represents a parsed dxlang function from a .dx file. +type DxFunc struct { + Name string // e.g., "add_test" + Signature string // e.g., "def add_test(A:int, B:int) -> (C:int)" + Body []string // dxlang instruction lines +} + +// TopLevelCall represents a function call at the file's outermost scope +// (outside any def { } block). When present, the loader writes /func/main +// to trigger VM execution. +type TopLevelCall struct { + FuncName string // e.g., "add_test" + Args []string // e.g., ["./a", "./b"] + Outputs []string // e.g., ["./c"] +} + +// DxFile represents a fully parsed .dx file with function definitions +// and optional top-level call expressions. +type DxFile struct { + Funcs []DxFunc + TopLevelCalls []TopLevelCall +} + +// LoadDxFile reads a .dx file and returns the first parsed function. +// Deprecated: Use ParseDxFile for multi-function files with top-level call support. +func LoadDxFile(path string) (*DxFunc, error) { + df, err := ParseDxFile(path) + if err != nil { + return nil, err + } + if len(df.Funcs) == 0 { + return nil, fmt.Errorf("%s: no function definitions found", path) + } + return &df.Funcs[0], nil +} + +// ParseDxFile reads a .dx file and returns all function definitions and +// any top-level call expressions. +// +// File format: +// +// # comment lines (ignored) +// def funcName(param1:type, ...) -> (out1:type, ...) { +// instruction1 +// instruction2 +// } +// # more def blocks ... +// +// # optional top-level calls (outside any def { } block): +// funcName(arg1, arg2) -> "./output" +// +// Top-level calls trigger automatic vthread creation via /func/main. 
+func ParseDxFile(path string) (*DxFile, error) { + f, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("open %s: %w", path, err) + } + defer f.Close() + + var lines []string + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + lines = append(lines, line) + } + if err := scanner.Err(); err != nil { + return nil, err + } + + if len(lines) == 0 { + return nil, fmt.Errorf("%s: file is empty (no content lines)", path) + } + + df := &DxFile{} + + // ── Pass 1: Parse all def blocks ── + i := 0 + for i < len(lines) { + if !strings.HasPrefix(lines[i], "def ") { + i++ + continue + } + + defLine := lines[i] + name := extractFuncName(defLine) + if name == "" { + return nil, fmt.Errorf("%s: cannot extract function name from: %s", path, defLine) + } + + // Extract body + var body []string + bodyEnd := len(lines) // default: rest of file + + if strings.HasSuffix(defLine, "{") { + // Braced format: find matching '}' + for j := i + 1; j < len(lines); j++ { + if lines[j] == "}" { + bodyEnd = j + break + } + body = append(body, lines[j]) + } + if bodyEnd == len(lines) { + return nil, fmt.Errorf("%s: unclosed brace in function %s", path, name) + } + i = bodyEnd + 1 // skip past '}' + } else { + // No-brace format: scan until next 'def' or top-level call + for j := i + 1; j < len(lines); j++ { + if strings.HasPrefix(lines[j], "def ") { + bodyEnd = j + break + } + // Stop at top-level call expressions (detected by `->` outside braces) + if isTopLevelCall(lines[j]) { + bodyEnd = j + break + } + body = append(body, lines[j]) + } + if bodyEnd == len(lines) { + // consumed rest of file + i = len(lines) + } else { + i = bodyEnd + } + } + + if len(body) == 0 { + return nil, fmt.Errorf("%s: function %s body is empty", path, name) + } + + df.Funcs = append(df.Funcs, DxFunc{ + Name: name, + Signature: strings.TrimSuffix(defLine, " {"), + Body: body, + }) + } + + if 
len(df.Funcs) == 0 { + return nil, fmt.Errorf("%s: no 'def' function definition found", path) + } + + // ── Pass 2: Parse top-level calls (lines not consumed by def blocks) ── + for _, line := range lines { + // Skip lines that are part of def blocks (they start with def or are inside braces) + if strings.HasPrefix(line, "def ") || line == "}" { + continue + } + // Check if this line is inside a known function body + inBody := false + for _, fn := range df.Funcs { + for _, bodyLine := range fn.Body { + if bodyLine == line { + inBody = true + break + } + } + if inBody { + break + } + } + if inBody { + continue + } + + // Try to parse as top-level call + if tc, ok := parseTopLevelCall(line); ok { + df.TopLevelCalls = append(df.TopLevelCalls, tc) + } + } + + return df, nil +} + +// parseTopLevelCall attempts to parse a top-level call expression like: +// +// funcName(arg1, arg2) -> "./output" +// funcName() -> ("./a", "./b") +// +// Returns the parsed call and true if successful. +func parseTopLevelCall(line string) (TopLevelCall, bool) { + // Must contain '->' (call operator) + arrowIdx := strings.Index(line, "->") + if arrowIdx < 0 { + return TopLevelCall{}, false + } + + left := strings.TrimSpace(line[:arrowIdx]) + right := strings.TrimSpace(line[arrowIdx+2:]) + + // Left side: funcName(args) + parenOpen := strings.Index(left, "(") + if parenOpen < 0 { + return TopLevelCall{}, false + } + parenClose := strings.LastIndex(left, ")") + if parenClose < 0 || parenClose <= parenOpen { + return TopLevelCall{}, false + } + + funcName := strings.TrimSpace(left[:parenOpen]) + if funcName == "" { + return TopLevelCall{}, false + } + + // Parse args + argsStr := strings.TrimSpace(left[parenOpen+1 : parenClose]) + var args []string + if argsStr != "" { + for _, a := range strings.Split(argsStr, ",") { + a = strings.TrimSpace(a) + if a != "" { + args = append(args, a) + } + } + } + + // Parse outputs (right side) + var outputs []string + right = strings.Trim(right, "()") + if 
right != "" { + for _, o := range strings.Split(right, ",") { + o = strings.TrimSpace(o) + // Strip quotes if present + o = strings.Trim(o, `"`) + if o != "" { + outputs = append(outputs, o) + } + } + } + + return TopLevelCall{ + FuncName: funcName, + Args: args, + Outputs: outputs, + }, true +} + +// isTopLevelCall checks if a line looks like a function call expression +// (contains '->' and is not a def line or brace). +func isTopLevelCall(line string) bool { + if strings.HasPrefix(line, "def ") || line == "}" || line == "{" { + return false + } + return strings.Contains(line, "->") +} + +// extractFuncName extracts the function name from a def line like +// "def add_test(A:int, B:int) -> (C:int) {" or legacy "(add_test(A, B) -> (C))". +func extractFuncName(sig string) string { + sig = strings.TrimSpace(sig) + // Strip "def " prefix + if strings.HasPrefix(sig, "def ") { + sig = strings.TrimSpace(sig[4:]) + } + // Strip outer parens (legacy format) + if len(sig) >= 2 && sig[0] == '(' && sig[len(sig)-1] == ')' { + sig = sig[1 : len(sig)-1] + } + // Strip trailing " {" (braced def) + sig = strings.TrimSuffix(sig, " {") + sig = strings.TrimSpace(sig) + + // Isolate left side of "->" + left := sig + if idx := strings.Index(sig, "->"); idx >= 0 { + left = strings.TrimSpace(sig[:idx]) + } + + // "add_test(A, B)" → "add_test" + if idx := strings.Index(left, "("); idx >= 0 { + return strings.TrimSpace(left[:idx]) + } + return left +} + +// RegisterFunc registers a DxFunc in Redis at /src/func/. 
+func (f *DxFunc) RegisterFunc(ctx context.Context, rdb *redis.Client) error { + if err := rdb.Set(ctx, "/src/func/"+f.Name, f.Signature, 0).Err(); err != nil { + return fmt.Errorf("register sig: %w", err) + } + for i, line := range f.Body { + key := fmt.Sprintf("/src/func/%s/%d", f.Name, i) + if err := rdb.Set(ctx, key, line, 0).Err(); err != nil { + return fmt.Errorf("register body[%d]: %w", i, err) + } + } + return nil +} + +// VThreadState mirrors state.VThreadState for test usage. +type VThreadState struct { + PC string `json:"pc"` + Status string `json:"status"` + Mode string `json:"mode,omitempty"` +} + +// CreateVThread creates a new vthread with initial state and entry instruction. +// The entry instruction uses the function name as the opcode directly: +// +// opcode = funcName (e.g., "add_test"), reads = args, writes = outputs +// +// The engine detects function names at runtime via isFunctionCall() and +// converts them to internal CALL instructions automatically. +func CreateVThread(ctx context.Context, rdb *redis.Client, funcName string, reads, writes []string) (string, error) { + vtid := fmt.Sprintf("test-%d", time.Now().UnixNano()) + + st := VThreadState{PC: "[0,0]", Status: "init", Mode: "single"} + data, _ := json.Marshal(st) + if err := rdb.Set(ctx, "/vthread/"+vtid, data, 0).Err(); err != nil { + return "", fmt.Errorf("set state: %w", err) + } + + pipe := rdb.Pipeline() + // 直接使用函数名作为 opcode (不再使用 "call" 关键字) + pipe.Set(ctx, "/vthread/"+vtid+"/[0,0]", funcName, 0) + for i, r := range reads { + pipe.Set(ctx, fmt.Sprintf("/vthread/%s/[0,-%d]", vtid, i+1), r, 0) + } + for i, w := range writes { + pipe.Set(ctx, fmt.Sprintf("/vthread/%s/[0,%d]", vtid, i+1), w, 0) + } + if _, err := pipe.Exec(ctx); err != nil { + return "", fmt.Errorf("pipeline: %w", err) + } + + return vtid, nil +} diff --git a/executor/vm/testutil/dxloader_test.go b/executor/vm/testutil/dxloader_test.go new file mode 100644 index 00000000..a5c9ef14 --- /dev/null +++ 
b/executor/vm/testutil/dxloader_test.go @@ -0,0 +1,87 @@ +package testutil + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadDxFile(t *testing.T) { + dxPath := filepath.Join("..", "..", "..", "example", "dxlang", "call", "add_test.dx") + + fn, err := LoadDxFile(dxPath) + if err != nil { + t.Fatalf("LoadDxFile: %v", err) + } + + if fn.Name != "add_test" { + t.Errorf("Name = %q, want 'add_test'", fn.Name) + } + if fn.Signature != "def add_test(A:tensor, B:tensor) -> (C:tensor)" { + t.Errorf("Signature = %q", fn.Signature) + } + if len(fn.Body) != 1 { + t.Fatalf("Body len = %d, want 1", len(fn.Body)) + } + if fn.Body[0] != `add(A, B) -> "./C"` { + t.Errorf("Body[0] = %q, want `add(A, B) -> \"./C\"`", fn.Body[0]) + } +} + +func TestExtractFuncName(t *testing.T) { + tests := []struct { + sig string + want string + }{ + // def format + {"def add_test(A:int, B:int) -> (C:int) {", "add_test"}, + {"def add_test(A:int, B:int) -> (C:int)", "add_test"}, + {"def gemm(A, B, alpha, beta, C) -> (Y)", "gemm"}, + {"def relu(X:tensor) -> (Y:tensor)", "relu"}, + // legacy format + {"(add_test(A:tensor, B:tensor) -> (C:tensor))", "add_test"}, + {"(gemm(A, B, alpha, beta, C) -> (Y))", "gemm"}, + {"add_test -> (C)", "add_test"}, + {"simple_func", "simple_func"}, + } + + for _, tc := range tests { + got := extractFuncName(tc.sig) + if got != tc.want { + t.Errorf("extractFuncName(%q) = %q, want %q", tc.sig, got, tc.want) + } + } +} + +func TestLoadDxFile_Errors(t *testing.T) { + _, err := LoadDxFile("/nonexistent/path.dx") + if err == nil { + t.Error("expected error for nonexistent file") + } + + tmpDir := t.TempDir() + + // Only comments + emptyPath := filepath.Join(tmpDir, "empty.dx") + os.WriteFile(emptyPath, []byte("# only comments\n"), 0644) + _, err = LoadDxFile(emptyPath) + if err == nil { + t.Error("expected error for file with only comments") + } + + // No def prefix + noDef := filepath.Join(tmpDir, "nodef.dx") + os.WriteFile(noDef, []byte("(foo(A) -> 
(B))\nadd(A) -> \"./B\"\n"), 0644) + _, err = LoadDxFile(noDef) + if err == nil { + t.Error("expected error for file without 'def' prefix") + } + + // Def with no body + noBody := filepath.Join(tmpDir, "nobody.dx") + os.WriteFile(noBody, []byte("def foo(A) -> (B) {\n}\n"), 0644) + _, err = LoadDxFile(noBody) + if err == nil { + t.Error("expected error for function with empty body") + } +} diff --git a/front/go/example/1/1_app.go b/front/go/example/1/1_app.go index 23826fd8..debc55de 100644 --- a/front/go/example/1/1_app.go +++ b/front/go/example/1/1_app.go @@ -3,7 +3,7 @@ package main import ( "os" - "github.com/array2d/deepx/front/go/deepx" + "deepx/front/go/deepx" ) type Module1 struct { diff --git a/front/go/example/3/3_transformer_app.go b/front/go/example/3/3_transformer_app.go index 63ee9f26..fd8339bd 100644 --- a/front/go/example/3/3_transformer_app.go +++ b/front/go/example/3/3_transformer_app.go @@ -3,7 +3,7 @@ package main import ( "os" - "github.com/array2d/deepx/front/go/deepx" + "deepx/front/go/deepx" ) func main() { diff --git a/front/go/go.mod b/front/go/go.mod index 19edd067..fe21c916 100644 --- a/front/go/go.mod +++ b/front/go/go.mod @@ -1,3 +1,3 @@ -module github.com/array2d/deepx/front/go +module deepx/front/go go 1.23.2 diff --git a/front/py/deepx/nn/functional/leaffunc_changeshape.py b/front/py/deepx/nn/functional/leaffunc_changeshape.py index 4be47c30..e2410243 100644 --- a/front/py/deepx/nn/functional/leaffunc_changeshape.py +++ b/front/py/deepx/nn/functional/leaffunc_changeshape.py @@ -109,59 +109,3 @@ def repeat(input:Tensor,repeats:tuple[int,...],out:Union[Tensor,str]=''): from .rtf_changeshape import rtf_repeat rtf_repeat(input,repeats,outtensor,defaultauthor['repeat']) return outtensor - -# def unsqueeze(t:Tensor,dim:int)->Tensor: -# # 确保dim是有效的 -# if dim < -t.ndim-1 or dim > t.ndim: -# raise ValueError(f"维度超出范围,当前张量维度为{t.ndim},dim={dim}") - -# # 处理负数索引 -# if dim < 0: -# dim = t.ndim + dim + 1 - -# new_shape = list(t.shape) -# 
new_shape.insert(dim, 1) - -# return reshape(t, new_shape) - -# OpNode.register("expand") -# def expand(t:Tensor,shape:tuple[int,...],out:Union[Tensor,str]='')->Tensor: -# outtensor=None -# if isinstance(out,str) or out is None: -# outtensor=Tensor(shape=shape, dtype=t.dtype, device=t.device) -# outtensor.addtograph(out) -# else: -# outtensor=out - -# opnode=t.graph.add_op("expand") -# opnode.add_input(t.node) -# opnode.add_input(t.graph.add_vector("",shape)) -# outtensor.node.add_input(opnode) -# if t.graph.eager: -# ir=DeepxIR("expand",'',[t.node.name,*map(str, shape)], [outtensor.node.name]) -# send(ir) -# return outtensor - -# def broadcast_to(a: Tensor, shape: tuple,out:Union[Tensor,str]='') -> Tensor: -# # 计算广播后的形状 -# try: -# target_shape = broadcast_shape(a.shape, shape) -# if target_shape!=shape: -# raise ValueError(f"广播失败:{a.shape} 无法广播为 {shape} ") -# except ValueError as e: -# raise ValueError(f"广播失败:{e}") from e - -# # 为每个张量添加前导维度 -# if a.shape != target_shape: -# a_reshape = [1] * (len(target_shape) - a.ndimension) + list(a.shape) -# a_reshaped = reshape(a,a_reshape) -# else: -# a_reshaped=a - -# # 执行实际广播 -# if a_reshaped.shape != target_shape: -# a_broadcasted = expand(a_reshaped,target_shape,out) -# else: -# a_broadcasted=a_reshaped - -# return a_broadcasted \ No newline at end of file diff --git a/tool/deepxctl/cmd/boot.go b/tool/deepxctl/cmd/boot.go new file mode 100644 index 00000000..4a38442b --- /dev/null +++ b/tool/deepxctl/cmd/boot.go @@ -0,0 +1,229 @@ +// Package cmd implements the "boot" subcommand for deepxctl. +// +// deepxctl boot [flags] +// +// Boots the full deepx runtime: Redis reset → build → launch op-metal + heap-metal + VM. +// Writes PID state to /tmp/deepx-boot.json for later shutdown. 
+package cmd + +import ( + "encoding/json" + "flag" + "fmt" + "log" + "os" + "syscall" + "time" + + "deepx/tool/deepxctl/internal/builder" + "deepx/tool/deepxctl/internal/process" + "deepx/tool/deepxctl/internal/redis" +) + +// BootPIDFile is the path where boot writes process PIDs. +const BootPIDFile = "/tmp/deepx-boot.json" + +// BootState holds the PIDs of booted services. +type BootState struct { + OpMetal int `json:"op-metal"` + HeapMetal int `json:"heap-metal"` + VM int `json:"vm"` + RedisAddr string `json:"redis_addr"` +} + +// BootFlags holds the parsed flags for the boot command. +type BootFlags struct { + RedisAddr string + ForceBuild bool + NoReset bool + Verbose bool +} + +// Boot is the entry point for the "boot" subcommand. +func Boot(args []string) { + flags := parseBootFlags(args) + + if err := boot(flags); err != nil { + fmt.Fprintf(os.Stderr, "\nERROR: %v\n", err) + os.Exit(1) + } + + fmt.Println() + printSeparator() + fmt.Println("Boot complete. Services are running.") + fmt.Printf("PID file: %s\n", BootPIDFile) + fmt.Println("Run 'deepxctl run <file.dx>' to execute, 'deepxctl shutdown' to stop.") + printSeparator() +} + +func parseBootFlags(args []string) BootFlags { + fs := flag.NewFlagSet("boot", flag.ExitOnError) + + var flags BootFlags + fs.StringVar(&flags.RedisAddr, "r", redis.DefaultAddr, "Redis address") + fs.StringVar(&flags.RedisAddr, "redis", redis.DefaultAddr, "Redis address") + fs.BoolVar(&flags.ForceBuild, "b", false, "Force rebuild all binaries") + fs.BoolVar(&flags.ForceBuild, "build", false, "Force rebuild all binaries") + fs.BoolVar(&flags.NoReset, "no-reset", false, "Skip Redis FLUSHDB") + fs.BoolVar(&flags.Verbose, "v", false, "Verbose output") + fs.BoolVar(&flags.Verbose, "verbose", false, "Verbose output") + + fs.Parse(args) + return flags +} + +func boot(flags BootFlags) error { + printHeader(flags.RedisAddr) + + repoRoot, err := builder.RepoRoot() + if err != nil { + return fmt.Errorf("find repo root: %w", err) + } + + // ── [1/3] 
Redis ── + step(1, 3, "Redis") + rdb, err := redis.Connect(flags.RedisAddr) + if err != nil { + errorX("Redis connection failed: %v", err) + return err + } + defer rdb.Close() + + if !flags.NoReset { + if err := redis.FlushDB(rdb); err != nil { + errorX("FLUSHDB: %v", err) + return err + } + } + ok() + + // ── [2/3] Build ── + step(2, 3, "Build") + if err := builder.All(repoRoot, flags.ForceBuild); err != nil { + errorX("Build failed: %v", err) + return err + } + ok() + + // ── [3/3] Start services ── + step(3, 3, "Start services") + fmt.Println() + + mgr := process.NewManager(flags.Verbose) + mgr.SetWorkDir(repoRoot) + mgr.SetLogDir("/tmp/deepx-logs") + + redisHost, redisPort := splitRedisAddr(flags.RedisAddr) + + // ① op-plat + if _, err := mgr.Start("op-metal", builder.OpMetal, redisHost, redisPort); err != nil { + errorX("op-plat: %v", err) + mgr.StopAll(5 * time.Second) + return err + } + fmt.Print(" op-plat .....................") + if err := redis.WaitForInstance(rdb, "/sys/op-plat/op-metal:0", 30*time.Second); err != nil { + errorX("op-plat not ready: %v", err) + mgr.StopAll(5 * time.Second) + return err + } + okInline() + + // ② heap-plat + if _, err := mgr.Start("heap-metal", builder.HeapMetal, redisHost, redisPort); err != nil { + errorX("heap-plat: %v", err) + mgr.StopAll(5 * time.Second) + return err + } + fmt.Print(" heap-plat ...................") + if err := redis.WaitForInstance(rdb, "/sys/heap-plat/heap-metal:0", 30*time.Second); err != nil { + errorX("heap-plat not ready: %v", err) + mgr.StopAll(5 * time.Second) + return err + } + okInline() + + // ③ VM + if _, err := mgr.Start("vm", builder.VM, flags.RedisAddr); err != nil { + errorX("VM: %v", err) + mgr.StopAll(5 * time.Second) + return err + } + fmt.Print(" VM ..........................") + if err := redis.WaitForInstance(rdb, "/sys/vm/0", 30*time.Second); err != nil { + errorX("VM not ready: %v", err) + mgr.StopAll(5 * time.Second) + return err + } + okInline() + + // ── Write PID file ── + 
state := BootState{ + OpMetal: mgr.PID("op-metal"), + HeapMetal: mgr.PID("heap-metal"), + VM: mgr.PID("vm"), + RedisAddr: flags.RedisAddr, + } + if err := writeBootState(state); err != nil { + errorX("write PID file: %v", err) + mgr.StopAll(5 * time.Second) + return err + } + + fmt.Printf("\n PID file written: %s\n", BootPIDFile) + ok() + + // Detach manager — processes stay running after boot exits. + // The PID file is the authoritative record for shutdown. + mgr.Detach() + log.Printf("[boot] services running, deepxctl boot exiting") + return nil +} + +// writeBootState writes the boot state to BootPIDFile. +func writeBootState(state BootState) error { + data, err := json.MarshalIndent(state, "", " ") + if err != nil { + return fmt.Errorf("marshal boot state: %w", err) + } + if err := os.WriteFile(BootPIDFile, data, 0644); err != nil { + return fmt.Errorf("write %s: %w", BootPIDFile, err) + } + return nil +} + +// ReadBootState reads the boot state from BootPIDFile. +// Returns nil if the file does not exist. +func ReadBootState() (*BootState, error) { + data, err := os.ReadFile(BootPIDFile) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("read %s: %w", BootPIDFile, err) + } + var state BootState + if err := json.Unmarshal(data, &state); err != nil { + return nil, fmt.Errorf("parse %s: %w", BootPIDFile, err) + } + return &state, nil +} + +// IsBooted checks whether the booted services are still running. +// Returns true if BootPIDFile exists and all PIDs are alive. +func IsBooted() bool { + state, err := ReadBootState() + if err != nil || state == nil { + return false + } + return pidAlive(state.OpMetal) && pidAlive(state.HeapMetal) && pidAlive(state.VM) +} + +// pidAlive checks if a process with the given PID is running (Unix). +func pidAlive(pid int) bool { + if pid <= 0 { + return false + } + // Signal 0 is the null signal — used to check process existence. 
+ return syscall.Kill(pid, 0) == nil +} diff --git a/tool/deepxctl/cmd/common.go b/tool/deepxctl/cmd/common.go new file mode 100644 index 00000000..3a5fef91 --- /dev/null +++ b/tool/deepxctl/cmd/common.go @@ -0,0 +1,55 @@ +// Package cmd provides shared utilities used by boot, run, and shutdown subcommands. +package cmd + +import ( + "fmt" + "os" + "strings" +) + +// ── Output helpers ── + +func printHeader(redisAddr string) { + fmt.Println() + fmt.Printf(" deepxctl | redis: %s\n", redisAddr) + printSeparator() + fmt.Println() +} + +func printSeparator() { + fmt.Println("─────────────────────────────────────────") +} + +func step(n, total int, label string) { + fmt.Printf("[%d/%d] %-28s", n, total, label) +} + +func ok() { + fmt.Println("✓") +} + +func okInline() { + fmt.Println("✓") +} + +func greenCheck() { + fmt.Print(" ✓ ") +} + +func errorX(format string, args ...interface{}) { + fmt.Println("✗") + fmt.Fprintf(os.Stderr, "\n─────────────────────────────────────────\n") + fmt.Fprintf(os.Stderr, "ERROR "+format+"\n", args...) + fmt.Fprintf(os.Stderr, "─────────────────────────────────────────\n") +} + +// splitRedisAddr splits "host:port" into host, port. +func splitRedisAddr(addr string) (host, port string) { + host = "127.0.0.1" + port = "16379" + if idx := strings.LastIndex(addr, ":"); idx > 0 { + host = addr[:idx] + port = addr[idx+1:] + } + return +} diff --git a/tool/deepxctl/cmd/run.go b/tool/deepxctl/cmd/run.go new file mode 100644 index 00000000..bb81f18f --- /dev/null +++ b/tool/deepxctl/cmd/run.go @@ -0,0 +1,357 @@ +// Package cmd implements the "run" subcommand for deepxctl. +// +// deepxctl run [flags] +// +// Requires a prior "deepxctl boot". Run loads .dx source via the loader binary. +// +// Execution semantics: +// - If the .dx file has top-level call expressions (outside any def block), +// the loader writes /func/main and the VM auto-executes. deepxctl polls for the result. 
+// - If the .dx file only has function definitions (no top-level call), +// the loader only registers them. deepxctl reports the loaded functions and exits. +// - Use --entry to manually specify an entry function (writes /func/main even +// when the file has no top-level call). +// +// Services are left running after completion (unless --rm). +package cmd + +import ( + "bytes" + "context" + "encoding/json" + "flag" + "fmt" + "log" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "time" + + goredis "github.com/redis/go-redis/v9" + + "deepx/tool/deepxctl/internal/builder" + "deepx/tool/deepxctl/internal/redis" +) + +// RunFlags holds the parsed flags for the run command. +type RunFlags struct { + RedisAddr string + Entry string + Timeout int + FilePath string + Rm bool +} + +// Run is the entry point for the "run" subcommand. +func Run(args []string) { + flags := parseRunFlags(args) + + if flags.FilePath == "" { + fmt.Fprintln(os.Stderr, "Usage: deepxctl run <file.dx> [flags]") + flag.PrintDefaults() + os.Exit(1) + } + + if err := run(flags); err != nil { + fmt.Fprintf(os.Stderr, "\nERROR: %v\n", err) + os.Exit(1) + } +} + +func parseRunFlags(args []string) RunFlags { + fs := flag.NewFlagSet("run", flag.ExitOnError) + + var flags RunFlags + fs.StringVar(&flags.RedisAddr, "r", redis.DefaultAddr, "Redis address") + fs.StringVar(&flags.RedisAddr, "redis", redis.DefaultAddr, "Redis address") + fs.StringVar(&flags.Entry, "entry", "", "Manual entry function (overrides top-level call detection)") + fs.IntVar(&flags.Timeout, "timeout", 60, "Execution timeout in seconds (0=no limit)") + fs.BoolVar(&flags.Rm, "rm", false, "After execution, flush Redis and shutdown all services") + + fs.Parse(args) + + if fs.NArg() > 0 { + flags.FilePath = fs.Arg(0) + } + + return flags +} + +func run(flags RunFlags) error { + printHeader(flags.RedisAddr) + + // ── [1/3] Verify boot ── + step(1, 3, "Check services") + if !IsBooted() { + errorX("Services not booted. 
Run 'deepxctl boot' first.") + fmt.Fprintf(os.Stderr, "\n Expected boot state at: %s\n", BootPIDFile) + fmt.Fprintf(os.Stderr, " If you believe services are running, check with 'make status'.\n") + return fmt.Errorf("services not booted") + } + ok() + + // Verify each service is registered in Redis + rdb, err := redis.Connect(flags.RedisAddr) + if err != nil { + errorX("Redis connection failed: %v", err) + return err + } + defer rdb.Close() + + services := map[string]string{ + "op-plat": "/sys/op-plat/op-metal:0", + "heap-plat": "/sys/heap-plat/heap-metal:0", + "vm": "/sys/vm/0", + } + for name, key := range services { + if err := redis.WaitForInstance(rdb, key, 5*time.Second); err != nil { + errorX("%s not ready (%s): %v", name, key, err) + return fmt.Errorf("service %s not ready — re-run 'deepxctl boot'", name) + } + } + log.Printf("[run] all services verified") + + // ── [2/3] Load dx ── + step(2, 3, "Load dx") + dxPath, _ := normalizePath(flags.FilePath) + funcs, entryCreated, err := loadDx(builder.Loader, dxPath, flags.RedisAddr) + if err != nil { + errorX("Load: %v", err) + return err + } + if len(funcs) == 0 { + errorX("No functions loaded from %s", flags.FilePath) + return fmt.Errorf("no functions loaded from %s", flags.FilePath) + } + ok() + + // ── Manual entry override ── + // If --entry is specified, write /func/main directly + if flags.Entry != "" { + log.Printf("[run] --entry=%s → writing /func/main", flags.Entry) + entryData, _ := json.Marshal(map[string]interface{}{ + "entry": flags.Entry, + "reads": []string{}, + "writes": []string{}, + }) + if err := rdb.Set(context.Background(), "/func/main", entryData, 0).Err(); err != nil { + errorX("write /func/main: %v", err) + return err + } + entryCreated = true + fmt.Printf(" entry: %s (manual override)\n", flags.Entry) + } + + // ── [3/3] Execute (only if /func/main was created) ── + if !entryCreated { + // No entry point — just loaded definitions + fmt.Println() + printSeparator() + fmt.Printf("Loaded %d 
function(s) into KV Space.\n", len(funcs)) + fmt.Println("(no top-level call found — VM is waiting for /func/main)") + fmt.Println("Use --entry to execute a loaded function.") + printSeparator() + return nil + } + + step(3, 3, "Execute") + timeout := time.Duration(flags.Timeout) * time.Second + if flags.Timeout == 0 { + timeout = 5 * time.Minute + } + + result, err := pollFuncMain(rdb, timeout) + if err != nil { + errorX("Execute: %v", err) + return err + } + + if result.Success { + greenCheck() + fmt.Printf(" vtid=%s status=%s %v\n", result.Vtid, result.Status, result.Duration) + } else { + errorX("vtid=%s status=%s", result.Vtid, result.Status) + if result.ErrCode != "" { + fmt.Fprintf(os.Stderr, " code: %s\n", result.ErrCode) + fmt.Fprintf(os.Stderr, " message: %s\n", result.ErrMsg) + } + return fmt.Errorf("execution failed") + } + + // ── Final summary ── + fmt.Println() + printSeparator() + fmt.Printf("SUCCESS vtid=%s status=%s %v\n", result.Vtid, result.Status, result.Duration) + if !flags.Rm { + fmt.Println("(services left running — use 'deepxctl shutdown' to stop)") + } + printSeparator() + + // ── [--rm] Cleanup ── + if flags.Rm { + fmt.Println() + fmt.Println("── Cleanup (--rm): flushing Redis, shutting down services ──") + if err := redis.FlushDB(rdb); err != nil { + errorX("FLUSHDB: %v", err) + return err + } + fmt.Println(" Redis FLUSHDB ✓") + if err := ExecShutdown(); err != nil { + errorX("Shutdown: %v", err) + return err + } + } + + return nil +} + +// ── Loader helpers ── + +// normalizePath resolves the .dx file path. Relative paths are resolved against CWD. +func normalizePath(path string) (string, error) { + if filepath.IsAbs(path) { + return path, nil + } + return filepath.Abs(path) +} + +// loadDx exec's the loader binary to load .dx files into /src/func/. +// Returns the set of function names loaded, and whether an entry point (/func/main) was created. 
+func loadDx(loaderBin, path, redisAddr string) (funcs []string, entryCreated bool, err error) { + log.Printf("[loader] loading %s ...", path) + + cmd := exec.Command(loaderBin, path, redisAddr) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return nil, false, fmt.Errorf("loader failed: %w\noutput: %s", err, stderr.String()) + } + + // Parse function names and entry info from loader output + output := stdout.String() + stderr.String() + for _, line := range strings.Split(output, "\n") { + line = strings.TrimSpace(line) + + // Parse: "OK /path/file.dx → /src/func/compute (N body lines)" + if idx := strings.Index(line, "/src/func/"); idx >= 0 && strings.Contains(line, "OK") { + namePart := line[idx+len("/src/func/"):] + name := strings.SplitN(namePart, " ", 2)[0] + funcs = append(funcs, name) + } + + // Parse: "ENTRY /func/main → funcName" + if strings.Contains(line, "ENTRY /func/main →") { + entryCreated = true + } + } + + log.Printf("[loader] loaded %d functions: %v, entryCreated=%v", len(funcs), funcs, entryCreated) + return funcs, entryCreated, nil +} + +// ── /func/main execution polling ── + +// funcMainResult holds the execution result from the /func/main protocol. +type funcMainResult struct { + Success bool + Vtid string + Status string + ErrCode string + ErrMsg string + Duration time.Duration +} + +// pollFuncMain waits for the VM to pick up /func/main, execute, and report completion. 
+func pollFuncMain(rdb *goredis.Client, timeout time.Duration) (*funcMainResult, error) { + startTime := time.Now() + deadline := time.Now().Add(timeout) + ctx := context.Background() + const key = "/func/main" + + // Phase 1: Wait for VM to claim /func/main and write vtid + var vtid string + for time.Now().Before(deadline) { + val, err := rdb.Get(ctx, key).Result() + if err != nil { + // Key may not exist yet, or VM already processed it + time.Sleep(200 * time.Millisecond) + continue + } + + var entry struct { + Vtid string `json:"vtid"` + Status string `json:"status"` + } + if err := json.Unmarshal([]byte(val), &entry); err != nil { + time.Sleep(200 * time.Millisecond) + continue + } + + if entry.Vtid != "" { + vtid = entry.Vtid + log.Printf("[run] VM picked up /func/main, vtid=%s", vtid) + break + } + + // Still waiting for VM to pick up (value has "entry" but not yet "vtid") + time.Sleep(200 * time.Millisecond) + } + + if vtid == "" { + return nil, fmt.Errorf("timeout waiting for VM to pick up /func/main") + } + + // Phase 2: Poll vthread status + pollInterval := 100 * time.Millisecond + for time.Now().Before(deadline) { + status, err := redis.GetVThreadStatus(rdb, parseVtid(vtid)) + if err != nil { + time.Sleep(pollInterval) + continue + } + + switch status.Status { + case "done": + // Clean up /func/main + rdb.Del(ctx, key) + return &funcMainResult{ + Success: true, + Vtid: vtid, + Status: status.Status, + Duration: time.Since(startTime), + }, nil + + case "error": + r := &funcMainResult{ + Success: false, + Vtid: vtid, + Status: status.Status, + Duration: time.Since(startTime), + } + if status.Error != nil { + r.ErrCode = status.Error.Code + r.ErrMsg = status.Error.Message + } + rdb.Del(ctx, key) + return r, nil + + case "init", "running", "wait": + time.Sleep(pollInterval) + + default: + time.Sleep(pollInterval) + } + } + + return nil, fmt.Errorf("vthread %s execution timeout after %v", vtid, timeout) +} + +// parseVtid converts a string vtid to int64 for 
compatibility with redis helpers. +func parseVtid(s string) int64 { + n, _ := strconv.ParseInt(s, 10, 64) + return n +} diff --git a/tool/deepxctl/cmd/shutdown.go b/tool/deepxctl/cmd/shutdown.go new file mode 100644 index 00000000..b0758dba --- /dev/null +++ b/tool/deepxctl/cmd/shutdown.go @@ -0,0 +1,329 @@ +// Package cmd implements the "shutdown" subcommand for deepxctl. +// +// deepxctl shutdown +// +// Ordered shutdown via Redis system commands: +// 1. plats (op-metal, heap-metal) — send sys:shutdown, wait for stopped heartbeat +// 2. VM — send sys:shutdown, wait for stopped heartbeat +// 3. Verify all heartbeats, log final values +// 4. Clean up PID file +// +// OS SIGKILL is only used as last-resort fallback if Redis is unreachable. +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "log" + "os" + "syscall" + "time" + + goredis "github.com/redis/go-redis/v9" + + "deepx/tool/deepxctl/internal/redis" +) + +// Shutdown is the entry point for the "shutdown" subcommand. +func Shutdown() { + if err := ExecShutdown(); err != nil { + fmt.Fprintf(os.Stderr, "\nERROR: %v\n", err) + os.Exit(1) + } + + fmt.Println() + printSeparator() + fmt.Println("Shutdown complete. All services stopped.") + printSeparator() +} + +// ExecShutdown performs the ordered shutdown of all booted services. +// It sends sys:shutdown commands via Redis (with OS signal fallback) and +// removes the PID file. Exported so the run command can reuse it with --rm. +func ExecShutdown() error { + return shutdown() +} + +// heartbeatVal represents a heartbeat entry from Redis. +type heartbeatVal struct { + Ts int64 `json:"ts"` + Status string `json:"status"` + Pid int `json:"pid"` +} + +type platInfo struct { + name string + sysQueue string + hbKey string + pid int +} + +func shutdown() error { + state, err := ReadBootState() + if err != nil { + return fmt.Errorf("read boot state: %w", err) + } + if state == nil { + fmt.Println("No boot state found. 
Nothing to shut down.") + fmt.Printf("(expected %s — was 'deepxctl boot' run?)\n", BootPIDFile) + return nil + } + + fmt.Printf("Ordered shutdown via Redis sys commands (redis: %s)\n", state.RedisAddr) + + // Connect to Redis + rdb, err := redis.Connect(state.RedisAddr) + if err != nil { + fmt.Printf(" Redis not reachable (%v), falling back to OS signals...\n", err) + return forceKill(state) + } + defer rdb.Close() + + ctx := context.Background() + shutdownCmd, _ := json.Marshal(map[string]string{"cmd": "shutdown"}) + + // ═══════════════════════════════════════════════════════════════ + // Phase 1: Shutdown plats (op-metal → heap-metal) + // ═══════════════════════════════════════════════════════════════ + fmt.Println("\n── Phase 1: Stopping plats (op-metal, heap-metal) ──") + + plats := []platInfo{ + {"op-metal", "sys:cmd:op-metal:0", "/sys/heartbeat/op-metal:0", state.OpMetal}, + {"heap-metal", "sys:cmd:heap-metal:0", "/sys/heartbeat/heap-metal:0", state.HeapMetal}, + } + + for _, p := range plats { + if !pidAlive(p.pid) { + fmt.Printf(" %-15s pid=%-6d already stopped\n", p.name, p.pid) + continue + } + fmt.Printf(" %-15s pid=%-6d sending sys:shutdown → %s...", p.name, p.pid, p.sysQueue) + if err := rdb.LPush(ctx, p.sysQueue, shutdownCmd).Err(); err != nil { + fmt.Printf(" LPUSH failed: %v\n", err) + } else { + fmt.Println(" sent") + } + } + + // Wait for plats heartbeats to show "stopped" + fmt.Print(" waiting for plats to stop...") + if !waitHeartbeats(rdb, plats, 10*time.Second) { + fmt.Println(" timeout") + } else { + fmt.Println(" done") + } + + // ═══════════════════════════════════════════════════════════════ + // Phase 2: Shutdown VM + // ═══════════════════════════════════════════════════════════════ + fmt.Println("\n── Phase 2: Stopping VM ──") + + vmPlats := []platInfo{ + {"vm", "sys:cmd:vm:0", "/sys/heartbeat/vm:0", state.VM}, + } + + if !pidAlive(state.VM) { + fmt.Printf(" VM pid=%-6d already stopped\n", state.VM) + } else { + fmt.Printf(" VM 
pid=%-6d sending sys:shutdown → sys:cmd:vm:0...", state.VM) + if err := rdb.LPush(ctx, "sys:cmd:vm:0", shutdownCmd).Err(); err != nil { + fmt.Printf(" LPUSH failed: %v\n", err) + } else { + fmt.Println(" sent") + } + + fmt.Print(" waiting for VM to stop...") + if !waitHeartbeats(rdb, vmPlats, 10*time.Second) { + fmt.Println(" timeout") + } else { + fmt.Println(" done") + } + } + + // ═══════════════════════════════════════════════════════════════ + // Phase 3: Verify all final heartbeats + // ═══════════════════════════════════════════════════════════════ + fmt.Println("\n── Phase 3: Final heartbeat verification ──") + + allHbKeys := []string{ + "/sys/heartbeat/op-metal:0", + "/sys/heartbeat/heap-metal:0", + "/sys/heartbeat/vm:0", + } + + for _, key := range allHbKeys { + val, err := rdb.Get(ctx, key).Result() + if err != nil { + fmt.Printf(" %-40s ── (cleaned)\n", key) + continue + } + var hb heartbeatVal + if err := json.Unmarshal([]byte(val), &hb); err != nil { + fmt.Printf(" %-40s parse error: %v\n", key, err) + continue + } + ts := time.Unix(hb.Ts, 0).Format("15:04:05") + icon := "✓" + if hb.Status != "stopped" { + icon = "✗" + } + fmt.Printf(" %s %-40s status=%-8s pid=%-6d ts=%s\n", icon, key, hb.Status, hb.Pid, ts) + } + + // ═══════════════════════════════════════════════════════════════ + // Grace period — lets processes finish exiting after heartbeat stop + // ═══════════════════════════════════════════════════════════════ + time.Sleep(500 * time.Millisecond) + + // ═══════════════════════════════════════════════════════════════ + // Phase 4: Force kill any remaining processes (fallback) + // ═══════════════════════════════════════════════════════════════ + needForce := false + for _, r := range []struct { + name string + pid int + }{ + {"op-metal", state.OpMetal}, + {"heap-metal", state.HeapMetal}, + {"vm", state.VM}, + } { + if pidAlive(r.pid) { + needForce = true + break + } + } + + if needForce { + fmt.Println("\n── Force killing remaining processes 
(fallback) ──") + for _, r := range []struct { + name string + pid int + }{ + {"op-metal", state.OpMetal}, + {"heap-metal", state.HeapMetal}, + {"vm", state.VM}, + } { + if pidAlive(r.pid) { + fmt.Printf(" %-15s pid=%-6d SIGKILL...", r.name, r.pid) + syscall.Kill(r.pid, syscall.SIGKILL) + time.Sleep(100 * time.Millisecond) + if pidAlive(r.pid) { + fmt.Println(" still alive!") + } else { + fmt.Println(" killed") + } + } + } + } + + // Remove PID file + if err := os.Remove(BootPIDFile); err != nil && !os.IsNotExist(err) { + log.Printf("[shutdown] could not remove %s: %v", BootPIDFile, err) + } else { + log.Printf("[shutdown] removed %s", BootPIDFile) + } + + return nil +} + +// waitHeartbeats polls heartbeat keys until all show "stopped" or PID dies, or timeout. +func waitHeartbeats(rdb *goredis.Client, plats []platInfo, timeout time.Duration) bool { + ctx := context.Background() + deadline := time.Now().Add(timeout) + + remaining := make(map[string]bool) + for _, p := range plats { + if pidAlive(p.pid) { + remaining[p.name] = true + } + } + if len(remaining) == 0 { + return true + } + + for len(remaining) > 0 && time.Now().Before(deadline) { + for _, p := range plats { + if !remaining[p.name] { + continue + } + // Check 1: PID dead = component exited + if !pidAlive(p.pid) { + delete(remaining, p.name) + continue + } + // Check 2: Heartbeat shows "stopped" + val, err := rdb.Get(ctx, p.hbKey).Result() + if err != nil { + continue + } + var hb heartbeatVal + if json.Unmarshal([]byte(val), &hb) == nil && hb.Status == "stopped" { + delete(remaining, p.name) + } + } + if len(remaining) > 0 { + time.Sleep(300 * time.Millisecond) + } + } + return len(remaining) == 0 +} + +// forceKill sends SIGTERM → wait → SIGKILL to all booted processes. +// Used as fallback when Redis is unreachable. 
+func forceKill(state *BootState) error {
+	// Name → PID table for every process recorded at boot time.
+	targets := map[string]int{
+		"op-metal":   state.OpMetal,
+		"heap-metal": state.HeapMetal,
+		"vm":         state.VM,
+	}
+
+	for name, pid := range targets {
+		// Skip anything that is already gone.
+		if !pidAlive(pid) {
+			fmt.Printf(" %-15s pid=%-6d already stopped\n", name, pid)
+			continue
+		}
+
+		// Polite first: SIGTERM, then give the process a grace window.
+		fmt.Printf(" %-15s pid=%-6d SIGTERM...", name, pid)
+		syscall.Kill(pid, syscall.SIGTERM)
+		if waitPID(pid, 5*time.Second) {
+			fmt.Println(" stopped")
+			continue
+		}
+
+		// Grace window elapsed — escalate to SIGKILL and re-check shortly after.
+		fmt.Print(" SIGKILL...")
+		syscall.Kill(pid, syscall.SIGKILL)
+		time.Sleep(200 * time.Millisecond)
+		if pidAlive(pid) {
+			fmt.Println(" still alive!")
+		} else {
+			fmt.Println(" killed")
+		}
+	}
+
+	os.Remove(BootPIDFile)
+	return nil
+}
+
+// pidFromState returns the PID for a named component from boot state.
+func pidFromState(state *BootState, name string) int {
+	byName := map[string]int{
+		"op-metal":   state.OpMetal,
+		"heap-metal": state.HeapMetal,
+		"vm":         state.VM,
+	}
+	if pid, ok := byName[name]; ok {
+		return pid
+	}
+	return -1
+}
+
+// waitPID polls until the process exits or timeout elapses.
+func waitPID(pid int, timeout time.Duration) bool { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if !pidAlive(pid) { + return true + } + time.Sleep(200 * time.Millisecond) + } + return false +} diff --git a/tool/deepxctl/cmd/tensor/print.go b/tool/deepxctl/cmd/tensor/print.go index f755bfcd..9f796f9a 100644 --- a/tool/deepxctl/cmd/tensor/print.go +++ b/tool/deepxctl/cmd/tensor/print.go @@ -5,7 +5,7 @@ import ( "fmt" "os" - coretensor "github.com/array2d/deepx/tool/deepxctl/tensor" + coretensor "deepx/tool/deepxctl/tensor" ) func PrintCmd() { diff --git a/tool/deepxctl/go.mod b/tool/deepxctl/go.mod index 42c0efe6..8ea063ea 100644 --- a/tool/deepxctl/go.mod +++ b/tool/deepxctl/go.mod @@ -1,5 +1,15 @@ -module github.com/array2d/deepx/tool/deepxctl +module deepx/tool/deepxctl -go 1.23.2 +go 1.24 -require gopkg.in/yaml.v2 v2.4.0 // indirect +toolchain go1.24.4 + +require ( + github.com/redis/go-redis/v9 v9.19.0 + gopkg.in/yaml.v2 v2.4.0 +) + +require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect + go.uber.org/atomic v1.11.0 // indirect +) diff --git a/tool/deepxctl/go.sum b/tool/deepxctl/go.sum index 75346616..c56e52a1 100644 --- a/tool/deepxctl/go.sum +++ b/tool/deepxctl/go.sum @@ -1,3 +1,26 @@ +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/klauspost/cpuid/v2 v2.2.10 
h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= +github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.19.0 h1:XPVaaPSnG6RhYf7p+rmSa9zZfeVAnWsH5h3lxthOm/k= +github.com/redis/go-redis/v9 v9.19.0/go.mod h1:v/M13XI1PVCDcm01VtPFOADfZtHf8YW3baQf57KlIkA= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/zeebo/xxh3 v1.1.0 h1:s7DLGDK45Dyfg7++yxI0khrfwq9661w9EN78eP/UZVs= +github.com/zeebo/xxh3 v1.1.0/go.mod h1:IisAie1LELR4xhVinxWS5+zf1lA4p0MW4T+w+W07F5s= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/tool/deepxctl/internal/builder/builder.go b/tool/deepxctl/internal/builder/builder.go new file mode 100644 index 00000000..beb992f4 --- /dev/null +++ b/tool/deepxctl/internal/builder/builder.go @@ -0,0 +1,154 @@ +// Package builder handles building deepx components by exec'ing existing build.sh scripts. 
+// +// Allowed operations (per doc/deepxctl/CLAUDE.md): +// +// exec executor/*/build.sh +// detect existing binary +// +// Prohibited: +// +// modifying build scripts +// modifying CMakeLists.txt +package builder + +import ( + "fmt" + "log" + "os" + "os/exec" + "path/filepath" +) + +// Binary paths for metal platform binaries. +var ( + OpMetal = "/tmp/deepx/op-metal/build/deepx-op-metal" + HeapMetal = "/tmp/deepx/heap-metal/build/deepx-heap-metal" + VM = "/tmp/deepx-vm/vm" + Loader = "/tmp/deepx-vm/loader" +) + +// Script paths relative to repo root. +type Scripts struct { + OpMetal string + HeapMetal string + VM string +} + +// DefaultScripts returns the standard build script locations. +func DefaultScripts(repoRoot string) Scripts { + return Scripts{ + OpMetal: filepath.Join(repoRoot, "executor/op-metal/build.sh"), + HeapMetal: filepath.Join(repoRoot, "executor/heap-metal/build.sh"), + VM: filepath.Join(repoRoot, "executor/vm/build.sh"), + } +} + +// binaryExists checks if a binary exists on disk. +func binaryExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +// Missing returns the list of missing binaries. +func Missing() []string { + var missing []string + if !binaryExists(OpMetal) { + missing = append(missing, "op-metal") + } + if !binaryExists(HeapMetal) { + missing = append(missing, "heap-metal") + } + if !binaryExists(VM) { + missing = append(missing, "vm") + } + if !binaryExists(Loader) { + missing = append(missing, "loader") + } + return missing +} + +// All builds all components by exec'ing their build.sh scripts. +// repoRoot is the path to the deepx repository root. 
+func All(repoRoot string, force bool) error { + scripts := DefaultScripts(repoRoot) + + components := []struct { + name string + script string + bin string + }{ + {"op-metal", scripts.OpMetal, OpMetal}, + {"heap-metal", scripts.HeapMetal, HeapMetal}, + {"vm (+loader)", scripts.VM, VM}, + } + + for _, c := range components { + if !force && binaryExists(c.bin) { + log.Printf("[build] %s binary exists, skipping", c.name) + continue + } + log.Printf("[build] building %s ...", c.name) + cmd := exec.Command("bash", c.script) + cmd.Dir = repoRoot + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("build %s failed: %w", c.name, err) + } + if !binaryExists(c.bin) { + return fmt.Errorf("build %s succeeded but binary not found at %s", c.name, c.bin) + } + log.Printf("[build] %s → %s", c.name, c.bin) + } + + // loader is built as part of vm build.sh + if !binaryExists(Loader) { + return fmt.Errorf("loader binary not found at %s (should be built by vm/build.sh)", Loader) + } + + return nil +} + +// RepoRoot attempts to find the repository root by walking up from the +// executable's directory, looking for go.mod or executor/ directory. +func RepoRoot() (string, error) { + // Start from executable path, or current working directory. 
+ exe, err := os.Executable() + if err != nil { + exe, _ = os.Getwd() + } + dir := filepath.Dir(exe) + + // Walk up to find repo root (look for executor/ or go.mod at top level) + for { + if _, err := os.Stat(filepath.Join(dir, "executor")); err == nil { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir, nil + } + } + parent := filepath.Dir(dir) + if parent == dir { + break + } + dir = parent + } + + // Fallback: use cwd + cwd, err := os.Getwd() + if err != nil { + return "", fmt.Errorf("cannot determine repo root") + } + // Walk up from cwd too + dir = cwd + for { + if _, err := os.Stat(filepath.Join(dir, "executor")); err == nil { + return dir, nil + } + parent := filepath.Dir(dir) + if parent == dir { + break + } + dir = parent + } + return "", fmt.Errorf("cannot find repo root (no executor/ directory found)") +} diff --git a/tool/deepxctl/internal/executor/executor.go b/tool/deepxctl/internal/executor/executor.go new file mode 100644 index 00000000..5fb61dd4 --- /dev/null +++ b/tool/deepxctl/internal/executor/executor.go @@ -0,0 +1,109 @@ +// Package executor handles vthread creation, VM wake-up, and status polling. +// +// Allowed operations (per doc/deepxctl/CLAUDE.md): +// +// INCR /sys/vtid_counter +// SET /vthread/ (init state) +// SET /vthread//[0,0] (entry CALL) +// SET /vthread//[0,1] (return slot) +// LPUSH notify:vm +// GET /vthread/ (status polling) +// +// Prohibited: +// +// writing to /vthread//[*,*] beyond the initial entry CALL +// consuming done: queue +// writing to cmd:* queues +package executor + +import ( + "fmt" + "log" + "time" + + goredis "github.com/redis/go-redis/v9" + + "deepx/tool/deepxctl/internal/redis" +) + +// Result holds the execution result. +type Result struct { + Success bool + Vtid int64 + Status string + PC string + ErrCode string + ErrMsg string + Duration time.Duration +} + +// Run creates a vthread for entryFunc, wakes the VM, and polls until done/error/timeout. 
+func Run(rdb *goredis.Client, entryFunc string, timeout time.Duration) (*Result, error) { + startTime := time.Now() + + // 1. Allocate vtid + vtid, err := redis.AllocVtid(rdb) + if err != nil { + return nil, fmt.Errorf("alloc vtid: %w", err) + } + + // 2. Create vthread with single CALL instruction + if err := redis.CreateVThread(rdb, vtid, entryFunc); err != nil { + return nil, fmt.Errorf("create vthread: %w", err) + } + + // 3. Wake VM + if err := redis.WakeVM(rdb, vtid); err != nil { + return nil, fmt.Errorf("wake vm: %w", err) + } + + log.Printf("[executor] vthread %d started, entry=%s, waiting...", vtid, entryFunc) + + // 4. Poll until done/error/timeout + deadline := time.Now().Add(timeout) + pollInterval := 100 * time.Millisecond + + for time.Now().Before(deadline) { + status, err := redis.GetVThreadStatus(rdb, vtid) + if err != nil { + // VThread key might not exist yet (VM hasn't picked it up) + time.Sleep(pollInterval) + continue + } + + switch status.Status { + case "done": + return &Result{ + Success: true, + Vtid: vtid, + Status: status.Status, + PC: status.PC, + Duration: time.Since(startTime), + }, nil + + case "error": + r := &Result{ + Success: false, + Vtid: vtid, + Status: status.Status, + PC: status.PC, + Duration: time.Since(startTime), + } + if status.Error != nil { + r.ErrCode = status.Error.Code + r.ErrMsg = status.Error.Message + } + return r, nil + + case "init", "running", "wait": + // Still executing, continue polling + time.Sleep(pollInterval) + + default: + log.Printf("[executor] unexpected vthread status: %s", status.Status) + time.Sleep(pollInterval) + } + } + + return nil, fmt.Errorf("vthread %d execution timeout after %v", vtid, timeout) +} diff --git a/tool/deepxctl/internal/process/manager.go b/tool/deepxctl/internal/process/manager.go new file mode 100644 index 00000000..798772f2 --- /dev/null +++ b/tool/deepxctl/internal/process/manager.go @@ -0,0 +1,238 @@ +// Package process manages the lifecycle of deepx subprocesses: +// 
op-plat, heap-plat, and VM. +// +// Allowed operations (per doc/deepxctl/CLAUDE.md): +// +// exec.Command start +// pass args (redis addr) +// capture stdout/stderr +// SIGTERM / SIGKILL +// detect exit status +package process + +import ( + "bytes" + "fmt" + "io" + "log" + "os" + "os/exec" + "path/filepath" + "sync" + "syscall" + "time" +) + +// Proc represents a managed subprocess. +type Proc struct { + Name string + cmd *exec.Cmd + stdout bytes.Buffer + stderr bytes.Buffer + logFile *os.File // if set, stdout+stderr are also written to this file + done chan error +} + +// Manager tracks all subprocesses started by deepxctl. +type Manager struct { + mu sync.Mutex + procs []*Proc + verbose bool + workDir string // if set, all subprocesses run with this CWD + logDir string // if set, each subprocess logs to /.log +} + +// NewManager creates a process manager. +// If verbose is true, subprocess stdout/stderr are also streamed to os.Stdout/os.Stderr. +func NewManager(verbose bool) *Manager { + return &Manager{verbose: verbose} +} + +// SetWorkDir sets the working directory for all subprocesses started by this manager. +func (m *Manager) SetWorkDir(dir string) { + m.mu.Lock() + defer m.mu.Unlock() + m.workDir = dir +} + +// SetLogDir sets a directory for per-process log files. +// Each subprocess gets /.log with combined stdout+stderr. +// Files persist after the manager exits (safe for boot → detach → run workflows). +func (m *Manager) SetLogDir(dir string) { + m.mu.Lock() + defer m.mu.Unlock() + m.logDir = dir +} + +// Start launches a subprocess. +// +// binPath: path to the compiled binary +// args: arguments passed to the binary +// +// If logDir is set, stdout+stderr are written directly to /.log. +// This ensures logs survive after the manager exits (safe for boot → detach workflow). +// If logDir is not set, stdout+stderr are captured in memory via pipes. 
+func (m *Manager) Start(name, binPath string, args ...string) (*Proc, error) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	cmd := exec.Command(binPath, args...)
+	if m.workDir != "" {
+		cmd.Dir = m.workDir
+	}
+
+	p := &Proc{
+		Name: name,
+		cmd:  cmd,
+		done: make(chan error, 1),
+	}
+
+	// Wire stdout/stderr before launching; see wireOutputs for the two modes.
+	if err := m.wireOutputs(p, cmd); err != nil {
+		return nil, err
+	}
+
+	if err := cmd.Start(); err != nil {
+		if p.logFile != nil {
+			p.logFile.Close()
+		}
+		return nil, fmt.Errorf("start %s: %w", name, err)
+	}
+
+	log.Printf("[process] %s started pid=%d", name, cmd.Process.Pid)
+
+	// Reap the child in the background so Wait is collected exactly once
+	// and the log file is closed when the process exits.
+	go func() {
+		p.done <- cmd.Wait()
+		if p.logFile != nil {
+			p.logFile.Close()
+		}
+	}()
+
+	m.procs = append(m.procs, p)
+	return p, nil
+}
+
+// wireOutputs connects the subprocess stdout/stderr either to a persistent
+// per-process log file (when logDir is set — survives parent exit) or to
+// in-memory buffers, optionally teed to the terminal when verbose.
+// Caller must hold m.mu.
+func (m *Manager) wireOutputs(p *Proc, cmd *exec.Cmd) error {
+	if m.logDir == "" {
+		// In-memory capture via pipes (for run command where parent stays alive).
+		if m.verbose {
+			cmd.Stdout = io.MultiWriter(&p.stdout, os.Stdout)
+			cmd.Stderr = io.MultiWriter(&p.stderr, os.Stderr)
+		} else {
+			cmd.Stdout = &p.stdout
+			cmd.Stderr = &p.stderr
+		}
+		return nil
+	}
+
+	// Log file redirection (survives parent exit).
+	if err := os.MkdirAll(m.logDir, 0755); err != nil {
+		return fmt.Errorf("create log dir %s: %w", m.logDir, err)
+	}
+	logPath := filepath.Join(m.logDir, p.Name+".log")
+	f, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+	if err != nil {
+		return fmt.Errorf("open log file %s: %w", logPath, err)
+	}
+	p.logFile = f
+	// Write directly to file — no pipes, so survives parent exit.
+	cmd.Stdout = f
+	cmd.Stderr = f
+	log.Printf("[process] %s logging to %s", p.Name, logPath)
+	return nil
+}
+
+// PID returns the PID of a named process, or -1 if not found.
+func (m *Manager) PID(name string) int {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	for _, proc := range m.procs {
+		if proc.Name != name || proc.cmd.Process == nil {
+			continue
+		}
+		return proc.cmd.Process.Pid
+	}
+	return -1
+}
+
+// StopAll sends SIGTERM to all processes, waits up to shutdownTimeout, then SIGKILL.
+func (m *Manager) StopAll(shutdownTimeout time.Duration) {
+	m.mu.Lock()
+	procs := make([]*Proc, len(m.procs))
+	copy(procs, m.procs)
+	m.mu.Unlock()
+
+	if len(procs) == 0 {
+		return
+	}
+
+	log.Printf("[process] stopping %d subprocesses...", len(procs))
+
+	// Phase 1: SIGTERM
+	for _, p := range procs {
+		if p.cmd.Process != nil {
+			p.cmd.Process.Signal(syscall.SIGTERM)
+		}
+	}
+
+	// Phase 2: Wait for graceful shutdown, then collect exit statuses.
+	//
+	// BUG FIX: previously a single time.After channel was shared across the
+	// loop. That channel fires exactly once; after the first timed-out
+	// process consumed it, every later iteration blocked forever on
+	// <-p.done if its process ignored SIGTERM. Use an absolute deadline
+	// and give each process the remaining budget instead.
+	deadline := time.Now().Add(shutdownTimeout)
+	allExited := true
+	for _, p := range procs {
+		remaining := time.Until(deadline)
+		if remaining < 0 {
+			remaining = 0 // budget exhausted: time.After(0) fires immediately
+		}
+		select {
+		case err := <-p.done:
+			if err != nil {
+				log.Printf("[process] %s exited: %v", p.Name, err)
+			} else {
+				log.Printf("[process] %s exited ok", p.Name)
+			}
+		case <-time.After(remaining):
+			allExited = false
+		}
+	}
+
+	if !allExited {
+		log.Printf("[process] timeout, sending SIGKILL...")
+		for _, p := range procs {
+			if p.cmd.Process != nil {
+				p.cmd.Process.Signal(syscall.SIGKILL)
+			}
+		}
+		time.Sleep(500 * time.Millisecond)
+		// Drain remaining done channels after kill
+		for _, p := range procs {
+			select {
+			case err := <-p.done:
+				if err != nil {
+					log.Printf("[process] %s killed: %v", p.Name, err)
+				}
+			default:
+			}
+		}
+	}
+}
+
+// Stdout returns captured stdout for a named process.
+func (m *Manager) Stdout(name string) string {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	for _, p := range m.procs {
+		if p.Name == name {
+			return p.stdout.String()
+		}
+	}
+	return ""
+}
+
+// Detach clears the internal process list without stopping any processes.
+// Use this when the manager should exit but processes must keep running
+// (e.g., after boot, managed by PID file instead).
+func (m *Manager) Detach() {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	log.Printf("[process] detaching %d subprocesses", len(m.procs))
+	m.procs = nil
+}
+
+// Stderr returns captured stderr for a named process.
+func (m *Manager) Stderr(name string) string { + m.mu.Lock() + defer m.mu.Unlock() + for _, p := range m.procs { + if p.Name == name { + return p.stderr.String() + } + } + return "" +} diff --git a/tool/deepxctl/internal/redis/redis.go b/tool/deepxctl/internal/redis/redis.go new file mode 100644 index 00000000..b72dad45 --- /dev/null +++ b/tool/deepxctl/internal/redis/redis.go @@ -0,0 +1,205 @@ +// Package redis provides Redis connection, FLUSHDB, and system key status checks +// for deepxctl process orchestration. +// +// Allowed operations (per doc/deepxctl/CLAUDE.md): +// +// PING, FLUSHDB, DBSIZE +// GET /sys/op-plat/*, /sys/heap-plat/*, /sys/vm/* +// GET /vthread/ (status polling) +// SET /vthread/ (vthread creation) +// SET /vthread//[0,0], /vthread//[0,1] (entry CALL) +// INCR /sys/vtid_counter +// LPUSH notify:vm +// GET /src/func/ (verification) +// KEYS /src/func/* (function listing) +package redis + +import ( + "context" + "encoding/json" + "fmt" + "log" + "time" + + goredis "github.com/redis/go-redis/v9" +) + +// DefaultAddr is the default Redis address for development. +const DefaultAddr = "127.0.0.1:16379" + +// Connect dials Redis with a short timeout and verifies with PING. +func Connect(addr string) (*goredis.Client, error) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + rdb := goredis.NewClient(&goredis.Options{ + Addr: addr, + PoolSize: 4, + MinIdleConns: 1, + }) + if err := rdb.Ping(ctx).Err(); err != nil { + rdb.Close() + return nil, fmt.Errorf("redis PING failed [%s]: %w", addr, err) + } + log.Printf("[redis] connected to %s", addr) + return rdb, nil +} + +// FlushDB resets the current Redis database. +// Only call this in development (port 16379). 
+func FlushDB(rdb *goredis.Client) error { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + if err := rdb.FlushDB(ctx).Err(); err != nil { + return fmt.Errorf("FLUSHDB failed: %w", err) + } + // Verify + size, err := rdb.DBSize(ctx).Result() + if err != nil { + return fmt.Errorf("DBSIZE after FLUSHDB failed: %w", err) + } + log.Printf("[redis] FLUSHDB done, dbsize=%d", size) + return nil +} + +// WaitForInstance polls a /sys/ key until it contains status="running" or timeout. +func WaitForInstance(rdb *goredis.Client, key string, timeout time.Duration) error { + ctx := context.Background() + deadline := time.Now().Add(timeout) + + for time.Now().Before(deadline) { + val, err := rdb.Get(ctx, key).Result() + if err != nil { + time.Sleep(200 * time.Millisecond) + continue + } + var m map[string]interface{} + if err := json.Unmarshal([]byte(val), &m); err != nil { + time.Sleep(200 * time.Millisecond) + continue + } + if s, ok := m["status"].(string); ok && s == "running" { + log.Printf("[redis] %s is running", key) + return nil + } + time.Sleep(200 * time.Millisecond) + } + return fmt.Errorf("timeout waiting for %s to be running (%.0fs)", key, timeout.Seconds()) +} + +// AllocVtid atomically increments the vthread counter and returns the new ID. +func AllocVtid(rdb *goredis.Client) (int64, error) { + ctx := context.Background() + id, err := rdb.Incr(ctx, "/sys/vtid_counter").Result() + if err != nil { + return 0, fmt.Errorf("INCR /sys/vtid_counter: %w", err) + } + return id, nil +} + +// VThreadStatus represents the status of a vthread from Redis. +type VThreadStatus struct { + PC string `json:"pc"` + Status string `json:"status"` + Error *struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error,omitempty"` +} + +// GetVThreadStatus reads the vthread JSON from Redis. 
+func GetVThreadStatus(rdb *goredis.Client, vtid int64) (*VThreadStatus, error) { + ctx := context.Background() + key := fmt.Sprintf("/vthread/%d", vtid) + val, err := rdb.Get(ctx, key).Result() + if err != nil { + return nil, fmt.Errorf("GET %s: %w", key, err) + } + var s VThreadStatus + if err := json.Unmarshal([]byte(val), &s); err != nil { + return nil, fmt.Errorf("parse %s: %w", key, err) + } + return &s, nil +} + +// CreateVThread writes a vthread with a single top-level CALL instruction. +// +// Writes: +// +// /vthread/ = {"pc":"[0,0]","status":"init"} +// /vthread//[0,0] = "" +// /vthread//[0,1] = "./ret" +func CreateVThread(rdb *goredis.Client, vtid int64, entryFunc string) error { + ctx := context.Background() + base := fmt.Sprintf("/vthread/%d", vtid) + + status := fmt.Sprintf(`{"pc":"[0,0]","status":"init"}`) + + pipe := rdb.Pipeline() + pipe.Set(ctx, base, status, 0) + pipe.Set(ctx, base+"/[0,0]", entryFunc, 0) + pipe.Set(ctx, base+"/[0,1]", "./ret", 0) + if _, err := pipe.Exec(ctx); err != nil { + return fmt.Errorf("create vthread %d: %w", vtid, err) + } + log.Printf("[redis] created vthread %d entry=%s", vtid, entryFunc) + return nil +} + +// WakeVM pushes a new_vthread notification to the VM wake queue. +func WakeVM(rdb *goredis.Client, vtid int64) error { + ctx := context.Background() + notify := map[string]interface{}{ + "event": "new_vthread", + "vtid": fmt.Sprintf("%d", vtid), + } + data, _ := json.Marshal(notify) + if err := rdb.LPush(ctx, "notify:vm", data).Err(); err != nil { + return fmt.Errorf("LPUSH notify:vm: %w", err) + } + log.Printf("[redis] notified VM: vtid=%d", vtid) + return nil +} + +// SrcFuncKeys returns all registered function names under /src/func/. 
+func SrcFuncKeys(rdb *goredis.Client) ([]string, error) { + ctx := context.Background() + keys, err := rdb.Keys(ctx, "/src/func/*").Result() + if err != nil { + return nil, err + } + // Filter out sub-keys like /src/func/name/0, return unique names + seen := make(map[string]bool) + var names []string + for _, k := range keys { + // /src/func/name → name + // /src/func/name/0 → name + name := k + if len(k) > 11 { // len("/src/func/") + rest := k[10:] // after "/src/func/" + // find first / + for i, c := range rest { + if c == '/' { + name = "/src/func/" + rest[:i] + break + } + } + } + if !seen[name] { + seen[name] = true + names = append(names, name[len("/src/func/"):]) + } + } + return names, nil +} + +// SrcFuncExists returns true if /src/func/ exists (non-empty). +func SrcFuncExists(rdb *goredis.Client, name string) bool { + ctx := context.Background() + val, err := rdb.Get(ctx, "/src/func/"+name).Result() + if err != nil { + return false + } + return val != "" +} diff --git a/tool/deepxctl/main.go b/tool/deepxctl/main.go index 1b5ffc3a..c6fa05a0 100644 --- a/tool/deepxctl/main.go +++ b/tool/deepxctl/main.go @@ -1,63 +1,65 @@ package main import ( - "flag" "fmt" "os" "path/filepath" - "github.com/array2d/deepx/tool/deepxctl/cmd/tensor" + "deepx/tool/deepxctl/cmd" + "deepx/tool/deepxctl/cmd/tensor" ) -var version = "0.1.0" +var version = "0.2.0" func printUsage() { execName := filepath.Base(os.Args[0]) - fmt.Printf("用法: %s [命令] [参数]\n\n", execName) - fmt.Println("可用命令:") - fmt.Println(" tensor 张量操作相关命令") - fmt.Println(" version 显示版本信息") - fmt.Println(" help 显示帮助信息") - fmt.Println("\n使用 '%s help [命令]' 获取命令的详细信息", execName) + fmt.Printf("Usage: %s [command] [arguments]\n\n", execName) + fmt.Println("Commands:") + fmt.Println(" boot Start services (Redis → build → launch op-metal + heap-metal + VM)") + fmt.Println(" run Run a .dx file (requires prior boot)") + fmt.Println(" shutdown Stop all booted services") + fmt.Println(" tensor Tensor file operations (print)") 
+ fmt.Println(" version Show version") + fmt.Println(" help Show this help") + fmt.Println() + fmt.Println("Typical workflow:") + fmt.Printf(" %s boot # start services once\n", execName) + fmt.Printf(" %s run file.dx # execute .dx (repeatable)\n", execName) + fmt.Printf(" %s shutdown # stop services when done\n", execName) + fmt.Println() + fmt.Printf("Run '%s [command] --help' for per-command flags.\n", execName) } func main() { - flag.Usage = printUsage - if len(os.Args) < 2 { printUsage() os.Exit(0) } - // 获取子命令 - cmd := os.Args[1] + subcmd := os.Args[1] + + switch subcmd { + case "boot": + cmd.Boot(os.Args[2:]) + + case "run": + cmd.Run(os.Args[2:]) + + case "shutdown": + cmd.Shutdown() - // 根据子命令执行相应操作 - switch cmd { case "tensor": - // 移除子命令,让子命令处理剩余的参数 - os.Args = os.Args[2:] + os.Args = os.Args[1:] tensor.Execute() case "version": - fmt.Printf("deepxctl 版本 %s\n", version) + fmt.Printf("deepxctl version %s\n", version) - case "help": - if len(os.Args) > 2 { - helpCmd := os.Args[2] - switch helpCmd { - case "tensor": - tensor.PrintUsage() - default: - fmt.Printf("未知命令: %s\n", helpCmd) - printUsage() - } - } else { - printUsage() - } + case "help", "-h", "--help": + printUsage() default: - fmt.Printf("未知命令: %s\n", cmd) + fmt.Fprintf(os.Stderr, "Unknown command: %s\n\n", subcmd) printUsage() os.Exit(1) }