[Mlir-commits] [mlir] [MLIR][NVVM] Add Op to create tcgen05-mma smem descriptor (PR #141651)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 27 11:24:03 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Durgadoss R (durga4github)

<details>
<summary>Changes</summary>

This patch adds an Op to create the shared-memory
descriptor for Tcgen05 MMA.

---
Full diff: https://github.com/llvm/llvm-project/pull/141651.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+64) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+44) 
- (added) mlir/test/Target/LLVMIR/nvvm/tcgen05-smem-desc.mlir (+38) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 13f693872d890..408537be0a5e4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3373,6 +3373,70 @@ def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMa<[100, 101]>]> {
   }];
 }
 
+def NVVM_Tcgen05MmaSmemDescOp : NVVM_Op<"tcgen05.mma_smem_desc", []> {
+  let summary = "Constructs a Shared Memory descriptor for MMA Operands A or B";
+  let description = [{
+    The `nvvm.tcgen05.mma_smem_desc` constructs a Shared Memory descriptor
+    for tcgen05.mma. This descriptor is a 64-bit value which describes the
+    properties of the multiplicand matrix in shared memory, including its
+    location in the shared memory of the current CTA.
+
+    +-----------+------+------------------------------------------------------+
+    | Bit-field | Size | Description                                          |
+    +-----------+------+------------------------------------------------------+
+    | 0-13      | 14   | Matrix start address                                 |
+    | 14-15     | 2    | Reserved                                             |
+    | 16-29     | 14   | Leading dim relative-offset (or) absolute-address    |
+    | 30-31     | 2    | Reserved                                             |
+    | 32-45     | 14   | Stride dimension byte offset                         |
+    | 46-48     | 3    | Fixed constant value of 0b001                        |
+    | 49-51     | 3    | Matrix base offset                                   |
+    | 52        | 1    | Leading dimension stride mode:                       |
+    |           |      |   0: byte offset relative                            |
+    |           |      |   1: byte address absolute                           |
+    | 53-60     | 8    | Fixed constant value of 0b00000000                   |
+    | 61-63     | 3    | Swizzling mode:                                      |
+    |           |      |   0: No swizzling                                    |
+    |           |      |   1: 128-Byte with 32B atomic swizzling              |
+    |           |      |   2: 128-Byte swizzling                              |
+    |           |      |   4: 64-Byte swizzling                               |
+    |           |      |   6: 32-Byte swizzling                               |
+    |           |      |   (Values 3, 5 and 7 are invalid)                    |
+    +-----------+------+------------------------------------------------------+    
+
+    Example:
+    ```mlir
+      %desc = nvvm.tcgen05.mma_smem_desc (%startAddr, %leadingDimOffset, %strideDimOffset,
+                                          %baseOffset, %leadingDimMode, %swizzleMode) : (i32, i32, i32, i8, i1, i8) -> i64
+    ```
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-shared-memory-descriptor)
+  }];
+
+  let arguments = (ins
+      I32:$startAddr,        // Matrix A or B start address (bits 13-0)
+      I32:$leadingDimOffset, // Matrix A or B leading dim byte offset (bits 29-16)
+      I32:$strideDimOffset,  // Matrix A or B stride  dim byte offset (bits 45-32)
+      I8:$baseOffset,        // Matrix A or B base offset (bits 51-49)
+      I1:$leadingDimMode,    // Matrix A or B leading dim mode (bit 52)
+      I8:$swizzleMode        // Swizzle mode (bits 63-61)
+  );
+
+  let results = (outs I64:$res);
+
+  let assemblyFormat = [{
+    `(` operands `)` attr-dict `:` `(` type(operands) `)` `->` type($res)
+  }];
+
+  let extraClassDeclaration = [{
+    static void createSmemDescriptor(Operation &op, LLVM::ModuleTranslation &mt,
+                                     llvm::IRBuilderBase& builder);
+  }];
+
+  string llvmBuilder = [{
+    NVVM::Tcgen05MmaSmemDescOp::createSmemDescriptor(*op, moduleTranslation, builder);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM tcgen05 LdSt Shape Attr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 79d9d2f6255e7..8036ea27f524f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1212,6 +1212,50 @@ NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
                                llvm::Type::getInt32Ty(builder.getContext()));
 }
 
+/// Packs the given `field` into the `result`.
+/// The `result` is 64 bits wide and each `field` can be 32 bits or narrower.
+static llvm::Value *
+packValInto64Bits(llvm::IRBuilderBase &builder,
+                  llvm::Value *result, // the `result` (unset bits are zero)
+                  llvm::Value *field,  // `field` to pack into `result`
+                  unsigned sizeInBits, // Size of `field` in bits
+                  unsigned start) {    // Starting bit within `result`
+  field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
+
+  unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
+  if (mask != 0xffffffffu)
+    field = builder.CreateAnd(field, builder.getInt32(mask));
+
+  field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
+  field = builder.CreateShl(field, start);
+
+  return builder.CreateOr(result, field);
+}
+
+void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
+                                                LLVM::ModuleTranslation &mt,
+                                                llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
+  llvm::Value *smemDesc = builder.getInt64(0);
+
+  smemDesc = packValInto64Bits(builder, smemDesc,
+                               mt.lookupValue(thisOp.getStartAddr()), 14, 0);
+  smemDesc = packValInto64Bits(
+      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
+  smemDesc = packValInto64Bits(
+      builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
+
+  smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
+  smemDesc = packValInto64Bits(builder, smemDesc,
+                               mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
+  smemDesc = packValInto64Bits(
+      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
+  smemDesc = packValInto64Bits(builder, smemDesc,
+                               mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);
+
+  mt.mapValue(thisOp.getRes()) = smemDesc;
+}
+
 //===----------------------------------------------------------------------===//
 // getIntrinsicID/getIntrinsicIDAndArgs methods
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-smem-desc.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-smem-desc.mlir
new file mode 100644
index 0000000000000..5af79c6f1379b
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-smem-desc.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define i64 @tcgen05_mma_smem_desc_test(i32 %0, i32 %1, i32 %2, i8 %3, i1 %4, i8 %5) {
+llvm.func @tcgen05_mma_smem_desc_test(%startAddr: i32, %leadingDimOffset: i32, %strideDimOffset: i32,
+                                      %baseOffset: i8, %leadingDimMode: i1, %swizzleMode: i8) -> i64 {
+  // CHECK-NEXT: %7 = and i32 %0, 16383
+  // CHECK-NEXT: %8 = zext i32 %7 to i64
+  // CHECK-NEXT: %9 = shl i64 %8, 0
+  // CHECK-NEXT: %10 = or i64 0, %9
+  // CHECK-NEXT: %11 = and i32 %1, 16383
+  // CHECK-NEXT: %12 = zext i32 %11 to i64
+  // CHECK-NEXT: %13 = shl i64 %12, 16
+  // CHECK-NEXT: %14 = or i64 %10, %13
+  // CHECK-NEXT: %15 = and i32 %2, 16383
+  // CHECK-NEXT: %16 = zext i32 %15 to i64
+  // CHECK-NEXT: %17 = shl i64 %16, 32
+  // CHECK-NEXT: %18 = or i64 %14, %17
+  // CHECK-NEXT: %19 = or i64 %18, 70368744177664
+  // CHECK-NEXT: %20 = zext i8 %3 to i32
+  // CHECK-NEXT: %21 = and i32 %20, 7
+  // CHECK-NEXT: %22 = zext i32 %21 to i64
+  // CHECK-NEXT: %23 = shl i64 %22, 49
+  // CHECK-NEXT: %24 = or i64 %19, %23
+  // CHECK-NEXT: %25 = zext i1 %4 to i32
+  // CHECK-NEXT: %26 = and i32 %25, 1
+  // CHECK-NEXT: %27 = zext i32 %26 to i64
+  // CHECK-NEXT: %28 = shl i64 %27, 52
+  // CHECK-NEXT: %29 = or i64 %24, %28
+  // CHECK-NEXT: %30 = zext i8 %5 to i32
+  // CHECK-NEXT: %31 = and i32 %30, 7
+  // CHECK-NEXT: %32 = zext i32 %31 to i64
+  // CHECK-NEXT: %33 = shl i64 %32, 61
+  // CHECK-NEXT: %34 = or i64 %29, %33
+  // CHECK-NEXT: ret i64 %34
+  // CHECK-NEXT: }
+  %desc = nvvm.tcgen05.mma_smem_desc (%startAddr, %leadingDimOffset, %strideDimOffset, %baseOffset, %leadingDimMode, %swizzleMode) : (i32, i32, i32, i8, i1, i8) -> i64
+  llvm.return %desc : i64
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/141651


More information about the Mlir-commits mailing list