[Mlir-commits] [mlir] [MLIR][NVVM] Add movmatrix Op (PR #193995)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 24 08:45:24 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Varad Rahul Kamthe (varadk27)
<details>
<summary>Changes</summary>
[MLIR][NVVM] Add movmatrix Op
Add MLIR NVVM dialect Op for the movmatrix PTX instruction, which moves
a row-major matrix across all threads in a warp and writes the transposed
elements to the destination.
---
Full diff: https://github.com/llvm/llvm-project/pull/193995.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+31)
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+18)
- (modified) mlir/test/Dialect/LLVMIR/nvvm.mlir (+8)
- (modified) mlir/test/Target/LLVMIR/nvvmir-invalid.mlir (+28)
- (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+8)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 9e94477c3a60a..b2b24f49ea2cc 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3297,6 +3297,37 @@ def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix", [InferTypeOpAdaptor]>,
let hasVerifier = 1;
}
+def NVVM_MovMatrixOp :
+  NVVM_SingleResultIntrinsicOp<"movmatrix", [NVVMRequiresSM<75>], "$dst"> {
+  let summary = "Warp-level matrix transpose";
+  let description = [{
+    Moves a row-major matrix across all threads in a warp, reading elements
+    from source `$src` and writing the transposed elements to destination
+    `$dst`. Each thread supplies one 32-bit value in `$src` and receives one
+    32-bit value in `$dst`.
+
+    The `shape` attribute indicates the dimensions of the matrix being
+    transposed; only 8x8 is accepted. Each matrix element holds 16-bit data
+    as indicated by the `eltType` attribute, which must be `b16`. The
+    `layout` attribute is optional and defaults to `col`, which is also the
+    only value the verifier accepts.
+
+    This Op requires sm_75 or newer.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-movmatrix-instruction)
+
+    Example:
+    ```mlir
+    %dst = nvvm.movmatrix %src {shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
+                                eltType = #nvvm.ld_st_matrix_elt_type<b16>} : i32
+    ```
+  }];
+
+  let results = (outs I32:$dst);
+  // `shape`, `layout` and `eltType` are each restricted to a single value by
+  // the C++ verifier (hasVerifier below).
+  let arguments = (ins I32:$src,
+                   LdStMatrixShapeAttr:$shape,
+                   DefaultValuedAttr<MMALayoutAttr, "MMALayout::col">:$layout,
+                   LdStMatrixEltTypeAttr:$eltType);
+
+  let assemblyFormat = "$src attr-dict `:` type($src)";
+  let hasVerifier = 1;
+}
+
def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
let summary = "cooperative matrix-multiply and accumulate";
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index b0cebd45624a9..35cba2d7bc2fe 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2607,6 +2607,17 @@ LogicalResult NVVM::StMatrixOp::verify() {
return success();
}
+/// Verifies that every attribute of nvvm.movmatrix matches the single
+/// instruction form it is lowered to (m8n8, transposed, b16 elements).
+/// Checks run in the order shape, layout, element type; the first mismatch
+/// produces the diagnostic.
+LogicalResult NVVM::MovMatrixOp::verify() {
+  auto shape = getShape();
+  const bool isEightByEight = shape.getM() == 8 && shape.getN() == 8;
+  if (!isEightByEight)
+    return emitOpError("expected shape to be 8x8");
+
+  if (getLayout() != NVVM::MMALayout::col)
+    return emitOpError("expected layout to be col");
+
+  if (getEltType() != NVVM::LdStMatrixEltType::B16)
+    return emitOpError("expected element type to be b16");
+
+  return success();
+}
+
static FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
if (typeA == NVVM::WGMMATypes::tf32)
return 8;
@@ -3864,6 +3875,13 @@ mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
return {id, {mt.lookupValue(thisOp.getAddr())}};
}
+/// Lowers nvvm.movmatrix to its LLVM intrinsic. Only one intrinsic variant is
+/// used, and the Op verifier restricts shape/layout/eltType to the values it
+/// encodes, so no attribute is consulted here; the single operand is `src`.
+mlir::NVVM::IDArgPair MovMatrixOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto movOp = cast<NVVM::MovMatrixOp>(op);
+  llvm::Value *srcVal = mt.lookupValue(movOp.getSrc());
+  return {llvm::Intrinsic::nvvm_movmatrix_sync_aligned_m8n8_trans_b16,
+          {srcVal}};
+}
+
#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index c039edc6b5de5..bbb3cd4b38a41 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -133,6 +133,14 @@ func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
llvm.return %0 : i32
}
+// CHECK-LABEL: @nvvm_movmatrix
+func.func @nvvm_movmatrix(%arg0 : i32) -> i32 {
+  // Round-trip: the defaulted `layout` attribute is not printed back.
+  // CHECK: nvvm.movmatrix %{{.*}} {eltType = #nvvm.ld_st_matrix_elt_type<b16>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : i32
+  %0 = nvvm.movmatrix %arg0 {eltType = #nvvm.ld_st_matrix_elt_type<b16>,
+                             shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : i32
+  llvm.return %0 : i32
+}
+
// CHECK-LABEL: @llvm_nvvm_bar_warp_sync
func.func @llvm_nvvm_bar_warp_sync(%mask : i32) {
// CHECK: nvvm.bar.warp.sync %{{.*}}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 2726fc7a40ef0..a36984590b89b 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -541,6 +541,34 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
// -----
+// The shape check rejects a mismatch in either dimension; cover both.
+llvm.func @mov_matrix_invalid_shape_m(%src : i32) -> i32 {
+  // expected-error at +1 {{'nvvm.movmatrix' op expected shape to be 8x8}}
+  %dst = nvvm.movmatrix %src {shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>,
+                              eltType = #nvvm.ld_st_matrix_elt_type<b16>} : i32
+  llvm.return %dst : i32
+}
+
+// -----
+
+llvm.func @mov_matrix_invalid_shape_n(%src : i32) -> i32 {
+  // expected-error at +1 {{'nvvm.movmatrix' op expected shape to be 8x8}}
+  %dst = nvvm.movmatrix %src {shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>,
+                              eltType = #nvvm.ld_st_matrix_elt_type<b16>} : i32
+  llvm.return %dst : i32
+}
+
+// -----
+
+llvm.func @mov_matrix_invalid_layout(%src : i32) -> i32 {
+  // expected-error at +1 {{'nvvm.movmatrix' op expected layout to be col}}
+  %dst = nvvm.movmatrix %src {shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
+                              layout = #nvvm.mma_layout<row>,
+                              eltType = #nvvm.ld_st_matrix_elt_type<b16>} : i32
+  llvm.return %dst : i32
+}
+
+// -----
+
+llvm.func @mov_matrix_invalid_elt_type(%src : i32) -> i32 {
+  // expected-error at +1 {{'nvvm.movmatrix' op expected element type to be b16}}
+  %dst = nvvm.movmatrix %src {shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
+                              eltType = #nvvm.ld_st_matrix_elt_type<b8>} : i32
+  llvm.return %dst : i32
+}
+
+// -----
+
llvm.func @clusterlaunchcontrol_query_cancel_is_canceled_invalid_return_type(%try_cancel_response: i128) {
// expected-error at +1 {{'nvvm.clusterlaunchcontrol.query.cancel' op is_canceled query type returns an i1}}
%res = nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %try_cancel_response : i32
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 64c44be3e8182..f2888025d8a08 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -617,6 +617,14 @@ llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32
llvm.return
}
+// CHECK-LABEL: @nvvm_movmatrix
+llvm.func @nvvm_movmatrix(%arg0 : i32) -> i32 {
+  // The only supported form lowers to the m8n8.trans.b16 intrinsic.
+  // CHECK: call i32 @llvm.nvvm.movmatrix.sync.aligned.m8n8.trans.b16(i32 %{{.*}})
+  %0 = nvvm.movmatrix %arg0 {eltType = #nvvm.ld_st_matrix_elt_type<b16>,
+                             shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : i32
+  llvm.return %0 : i32
+}
+
// This function has the "kernel" attribute attached and should appear in the
// NVVM annotations after conversion.
llvm.func @kernel_func() attributes {nvvm.kernel} {
``````````
</details>
https://github.com/llvm/llvm-project/pull/193995
More information about the Mlir-commits
mailing list