[Mlir-commits] [mlir] [MLIR] Fix the PTX generation bug for StMatrixOp (PR #148250)
llvmlistbot at llvm.org
Fri Jul 11 08:01:33 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-llvm
Author: Pecco (Pecco-314)
<details>
<summary>Changes</summary>
According to the [PTX documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/), the syntax of `stmatrix` is:
```
stmatrix.sync.aligned.shape.num{.trans}{.ss}.type [p], r;
.shape = {.m8n8, .m16n8};
.num = {.x1, .x2, .x4};
.ss = {.shared{::cta}};
.type = {.b16, .b8};
```
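However, the current code emits the `.num` qualifier before `.shape`, producing PTX such as `stmatrix.sync.aligned.x4.m8n8.shared.b16`, which does not match the documented order. This patch moves `.m8n8` directly after `.sync.aligned`, ahead of `.num` and the optional `.trans`.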
---
Full diff: https://github.com/llvm/llvm-project/pull/148250.diff
1 file affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+4-4)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6895e946b8a45..b27c03ec2c63f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2000,13 +2000,13 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
int d = getSources().size();
- std::string ptx = "stmatrix.sync.aligned";
+ std::string ptx = "stmatrix.sync.aligned.m8n8";
ptx += ".x" + std::to_string(d);
if (getLayout() == NVVM::MMALayout::col)
ptx += ".trans";
- if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1};";
- if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2};";
- if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};";
+ if(d == 1) ptx += ".shared.b16 [%0], {%1};";
+ if(d == 2) ptx += ".shared.b16 [%0], {%1, %2};";
+ if(d == 4) ptx += ".shared.b16 [%0], {%1, %2, %3, %4};";
return ptx;
}
}];
``````````
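
The qualifier order produced by the patched `getPtx()` above can be checked outside of MLIR with a small stand-alone sketch. The function name `buildStMatrixPtx` and its parameters are hypothetical and exist only for illustration; the sketch replaces the three `if (d == ...)` branches with a loop, which is an assumption about an equivalent formulation rather than what the patch does.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-alone helper (not part of MLIR). It mirrors the patched
// getPtx() logic: .shape (.m8n8) first, then .num (.xN), then optional .trans,
// then the state space and type.
std::string buildStMatrixPtx(int numSources, bool transposed) {
  // stmatrix accepts .x1, .x2, or .x4 only.
  assert((numSources == 1 || numSources == 2 || numSources == 4) &&
         "stmatrix takes 1, 2, or 4 source registers");

  std::string ptx = "stmatrix.sync.aligned.m8n8";
  ptx += ".x" + std::to_string(numSources);
  if (transposed)
    ptx += ".trans";
  ptx += ".shared.b16 [%0], {";

  // Operand placeholders %1..%N, comma-separated.
  for (int i = 1; i <= numSources; ++i) {
    if (i > 1)
      ptx += ", ";
    ptx += "%" + std::to_string(i);
  }
  ptx += "};";
  return ptx;
}
```

With four sources and a row-major layout, this sketch places `.m8n8` before `.x4`, in line with the documented `.shape.num{.trans}{.ss}.type` order.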
</details>
https://github.com/llvm/llvm-project/pull/148250