[Mlir-commits] [mlir] [MLIR][NVVM] Add movmatrix Op (PR #193995)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 24 08:45:24 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Varad Rahul Kamthe (varadk27)

<details>
<summary>Changes</summary>

[MLIR][NVVM] Add movmatrix Op

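Add an MLIR NVVM dialect op for the PTX `movmatrix` instruction, which moves
a row-major matrix across all threads in a warp and writes the transposed
elements to the destination. Only the `movmatrix.sync.aligned.m8n8.trans.b16`
form is supported, and the op requires sm_75 or newer.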

---
Full diff: https://github.com/llvm/llvm-project/pull/193995.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+31) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+18) 
- (modified) mlir/test/Dialect/LLVMIR/nvvm.mlir (+8) 
- (modified) mlir/test/Target/LLVMIR/nvvmir-invalid.mlir (+28) 
- (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+8) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 9e94477c3a60a..b2b24f49ea2cc 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3297,6 +3297,37 @@ def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix", [InferTypeOpAdaptor]>,
   let hasVerifier = 1;
 }
 
+def NVVM_MovMatrixOp :
+  NVVM_SingleResultIntrinsicOp<"movmatrix", [NVVMRequiresSM<75>], "$dst"> {
+  let summary = "Warp-level matrix transpose";
+  let description = [{
+    Moves a row-major matrix across all threads in a warp, reading elements
+    from source `$src` and writing the transposed elements to destination
+    `$dst`. Lowers to `llvm.nvvm.movmatrix.sync.aligned.m8n8.trans.b16`.
+
+    The `shape` attribute gives the matrix dimensions and must be 8x8. Each
+    element holds 16-bit data, so `eltType` must be `b16`. The optional
+    `layout` attribute defaults to `col`, which is also its only legal value.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-movmatrix-instruction)
+
+    Example:
+    ```mlir
+    %dst = nvvm.movmatrix %src {shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
+                                eltType = #nvvm.ld_st_matrix_elt_type<b16>} : i32
+    ```
+  }];
+
+  let results = (outs I32:$dst);
+  let arguments = (ins I32:$src,
+                       LdStMatrixShapeAttr:$shape,
+                       DefaultValuedAttr<MMALayoutAttr, "MMALayout::col">:$layout,
+                       LdStMatrixEltTypeAttr:$eltType);
+
+  let assemblyFormat = "$src attr-dict `:` type($src)";
+  let hasVerifier = 1;
+}
+
 def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
 
   let summary = "cooperative matrix-multiply and accumulate";
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index b0cebd45624a9..35cba2d7bc2fe 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2607,6 +2607,17 @@ LogicalResult NVVM::StMatrixOp::verify() {
   return success();
 }
 
+LogicalResult NVVM::MovMatrixOp::verify() {
+  // The only supported form is movmatrix.sync.aligned.m8n8.trans.b16.
+  if (auto shape = getShape(); shape.getM() != 8 || shape.getN() != 8)
+    return emitOpError("expected shape to be 8x8");
+  if (getLayout() != NVVM::MMALayout::col)
+    return emitOpError("expected layout to be col");
+  if (getEltType() != NVVM::LdStMatrixEltType::B16)
+    return emitOpError("expected element type to be b16");
+  return success();
+}
+
 static FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
   if (typeA == NVVM::WGMMATypes::tf32)
     return 8;
@@ -3864,6 +3875,13 @@ mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
   return {id, {mt.lookupValue(thisOp.getAddr())}};
 }
 
+mlir::NVVM::IDArgPair MovMatrixOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  // The verifier admits a single form, so there is exactly one intrinsic.
+  llvm::Value *src = mt.lookupValue(cast<NVVM::MovMatrixOp>(op).getSrc());
+  return {llvm::Intrinsic::nvvm_movmatrix_sync_aligned_m8n8_trans_b16, {src}};
+}
+
 #define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
   llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
 
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index c039edc6b5de5..bbb3cd4b38a41 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -133,6 +133,14 @@ func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
   llvm.return %0 : i32
 }
 
+// CHECK-LABEL: @nvvm_movmatrix
+func.func @nvvm_movmatrix(%arg0 : i32) -> i32 {
+  // CHECK: nvvm.movmatrix %{{.*}} {eltType = #nvvm.ld_st_matrix_elt_type<b16>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : i32
+  %0 = nvvm.movmatrix %arg0 {eltType = #nvvm.ld_st_matrix_elt_type<b16>,
+                             shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : i32
+  llvm.return %0 : i32
+}
+
 // CHECK-LABEL: @llvm_nvvm_bar_warp_sync
 func.func @llvm_nvvm_bar_warp_sync(%mask : i32) {
   // CHECK: nvvm.bar.warp.sync %{{.*}}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 2726fc7a40ef0..a36984590b89b 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -541,6 +541,34 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
 
 // -----
 
+llvm.func @movmatrix_invalid_shape(%src : i32) -> i32 {
+  // expected-error@+1 {{'nvvm.movmatrix' op expected shape to be 8x8}}
+  %dst = nvvm.movmatrix %src {shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>,
+                              eltType = #nvvm.ld_st_matrix_elt_type<b16>} : i32
+  llvm.return %dst : i32
+}
+
+// -----
+
+llvm.func @movmatrix_invalid_layout(%src : i32) -> i32 {
+  // expected-error@+1 {{'nvvm.movmatrix' op expected layout to be col}}
+  %dst = nvvm.movmatrix %src {shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
+                              layout = #nvvm.mma_layout<row>,
+                              eltType = #nvvm.ld_st_matrix_elt_type<b16>} : i32
+  llvm.return %dst : i32
+}
+
+// -----
+
+llvm.func @movmatrix_invalid_elt_type(%src : i32) -> i32 {
+  // expected-error@+1 {{'nvvm.movmatrix' op expected element type to be b16}}
+  %dst = nvvm.movmatrix %src {shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
+                              eltType = #nvvm.ld_st_matrix_elt_type<b8>} : i32
+  llvm.return %dst : i32
+}
+
+// -----
+
 llvm.func @clusterlaunchcontrol_query_cancel_is_canceled_invalid_return_type(%try_cancel_response: i128) {
   // expected-error at +1 {{'nvvm.clusterlaunchcontrol.query.cancel' op is_canceled query type returns an i1}}
   %res = nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %try_cancel_response : i32
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 64c44be3e8182..f2888025d8a08 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -617,6 +617,14 @@ llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32
   llvm.return
 }
 
+// CHECK-LABEL: @nvvm_movmatrix
+llvm.func @nvvm_movmatrix(%arg0 : i32) -> i32 {
+  // CHECK: call i32 @llvm.nvvm.movmatrix.sync.aligned.m8n8.trans.b16(i32 %{{.*}})
+  %0 = nvvm.movmatrix %arg0 {eltType = #nvvm.ld_st_matrix_elt_type<b16>,
+                             shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : i32
+  llvm.return %0 : i32
+}
+
 // This function has the "kernel" attribute attached and should appear in the
 // NVVM annotations after conversion.
 llvm.func @kernel_func() attributes {nvvm.kernel} {

``````````

</details>


https://github.com/llvm/llvm-project/pull/193995


More information about the Mlir-commits mailing list