[Mlir-commits] [mlir] [MLIR][NVVM] Add Op to create tcgen05-mma smem descriptor (PR #141651)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 27 11:24:03 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Durgadoss R (durga4github)
<details>
<summary>Changes</summary>
This patch adds an Op, `nvvm.tcgen05.mma_smem_desc`, that constructs the
shared-memory descriptor for the operands of tcgen05 MMA.
---
Full diff: https://github.com/llvm/llvm-project/pull/141651.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+64)
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+44)
- (added) mlir/test/Target/LLVMIR/nvvm/tcgen05-smem-desc.mlir (+38)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 13f693872d890..408537be0a5e4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3373,6 +3373,70 @@ def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMa<[100, 101]>]> {
}];
}
+def NVVM_Tcgen05MmaSmemDescOp : NVVM_Op<"tcgen05.mma_smem_desc", [Pure]> {
+  let summary = "Constructs a Shared Memory descriptor for MMA Operands A or B";
+  let description = [{
+    The `nvvm.tcgen05.mma_smem_desc` op constructs a Shared Memory descriptor
+    for tcgen05.mma. This descriptor is a 64-bit value which describes the
+    properties of the multiplicand matrix in shared memory, including its
+    location in the shared memory of the current CTA.
+
+    The op only packs its operands into the bit-fields listed below: each
+    operand is masked to the width of its field (excess high bits are
+    discarded) and no further encoding or validation is performed. Callers
+    must therefore pass the address/offset fields already encoded as the PTX
+    ISA requires, and a valid swizzling mode. The op has no side effects.
+
+    +-----------+------+------------------------------------------------------+
+    | Bit-field | Size | Description                                          |
+    +-----------+------+------------------------------------------------------+
+    | 0-13      | 14   | Matrix start address                                 |
+    | 14-15     | 2    | Reserved                                             |
+    | 16-29     | 14   | Leading dim relative-offset (or) absolute-address    |
+    | 30-31     | 2    | Reserved                                             |
+    | 32-45     | 14   | Stride dimension byte offset                         |
+    | 46-48     | 3    | Fixed constant value of 0b001                        |
+    | 49-51     | 3    | Matrix base offset                                   |
+    | 52        | 1    | Leading dimension stride mode:                       |
+    |           |      |   0: byte offset relative                            |
+    |           |      |   1: byte address absolute                           |
+    | 53-60     | 8    | Fixed constant value of 0b00000000                   |
+    | 61-63     | 3    | Swizzling mode:                                      |
+    |           |      |   0: No swizzling                                    |
+    |           |      |   1: 128-Byte with 32B atomic swizzling              |
+    |           |      |   2: 128-Byte swizzling                              |
+    |           |      |   4: 64-Byte swizzling                               |
+    |           |      |   6: 32-Byte swizzling                               |
+    |           |      |   (Values 3, 5 and 7 are invalid)                    |
+    +-----------+------+------------------------------------------------------+
+
+    Example:
+    ```mlir
+      %desc = nvvm.tcgen05.mma_smem_desc (%startAddr, %leadingDimOffset, %strideDimOffset,
+                                          %baseOffset, %leadingDimMode, %swizzleMode) : (i32, i32, i32, i8, i1, i8) -> i64
+    ```
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-shared-memory-descriptor)
+  }];
+
+  let arguments = (ins
+    I32:$startAddr,         // Matrix A or B start address (bits 13-0)
+    I32:$leadingDimOffset,  // Matrix A or B leading dim byte offset (bits 29-16)
+    I32:$strideDimOffset,   // Matrix A or B stride dim byte offset (bits 45-32)
+    I8:$baseOffset,         // Matrix A or B base offset (bits 51-49)
+    I1:$leadingDimMode,     // Matrix A or B leading dim mode (bit 52)
+    I8:$swizzleMode         // Swizzle mode (bits 63-61)
+  );
+
+  let results = (outs I64:$res);
+
+  let assemblyFormat = [{
+    `(` operands `)` attr-dict `:` `(` type(operands) `)` `->` type($res)
+  }];
+
+  let extraClassDeclaration = [{
+    static void createSmemDescriptor(Operation &op, LLVM::ModuleTranslation &mt,
+                                     llvm::IRBuilderBase& builder);
+  }];
+
+  string llvmBuilder = [{
+    NVVM::Tcgen05MmaSmemDescOp::createSmemDescriptor(*op, moduleTranslation, builder);
+  }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM tcgen05 LdSt Shape Attr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 79d9d2f6255e7..8036ea27f524f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1212,6 +1212,50 @@ NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
llvm::Type::getInt32Ty(builder.getContext()));
}
+/// Packs the low `sizeInBits` bits of `field` into `result` at bit `start`.
+///
+/// `result` is the 64-bit accumulator and `field` an integer value that is
+/// 32 bits wide or narrower. The field is zero-extended to i32, masked to
+/// `sizeInBits` bits, widened to i64, shifted left by `start` and OR'ed into
+/// `result`. Because the merge is an OR, the destination bit range
+/// [start, start + sizeInBits) is expected to be zero in `result`.
+///
+/// Returns the updated 64-bit value.
+static llvm::Value *
+packValInto64Bits(llvm::IRBuilderBase &builder,
+                  llvm::Value *result, // the `result` (unset bits are zero)
+                  llvm::Value *field,  // `field` to pack into `result`
+                  unsigned sizeInBits, // Size of `field` in bits
+                  unsigned start) {    // Starting bit within `result`
+  // The field is normalized through i32, so it can be at most 32 bits wide,
+  // and it must land entirely inside the 64-bit result; a shift amount of 64
+  // or more would otherwise produce poison without any diagnostic.
+  assert(sizeInBits > 0 && sizeInBits <= 32 &&
+         "packed field must be between 1 and 32 bits wide");
+  assert(start + sizeInBits <= 64 && "packed field does not fit in 64 bits");
+  assert(field->getType()->getScalarSizeInBits() <= 32 &&
+         "value to pack must be 32 bits or narrower");
+
+  field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
+
+  // Drop any bits above the field width; a full 32-bit field needs no mask.
+  unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
+  if (mask != 0xffffffffu)
+    field = builder.CreateAnd(field, builder.getInt32(mask));
+
+  field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
+  field = builder.CreateShl(field, start);
+
+  return builder.CreateOr(result, field);
+}
+
+void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
+                                                LLVM::ModuleTranslation &mt,
+                                                llvm::IRBuilderBase &builder) {
+  // Lowers the op by OR-ing every descriptor field, in ascending bit order,
+  // into an all-zero i64. Bit ranges that are never written (14-15, 30-31
+  // and 53-60) therefore stay zero.
+  auto descOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
+
+  // One entry per bit-field: the value to pack, its width in bits, and the
+  // position of its lowest bit within the descriptor.
+  struct DescField {
+    llvm::Value *value;
+    unsigned width;
+    unsigned lsb;
+  };
+  const DescField fields[] = {
+      {mt.lookupValue(descOp.getStartAddr()), 14, 0},
+      {mt.lookupValue(descOp.getLeadingDimOffset()), 14, 16},
+      {mt.lookupValue(descOp.getStrideDimOffset()), 14, 32},
+      // Bits 46-48 hold the fixed constant 0b001.
+      {builder.getInt32(1), 3, 46},
+      {mt.lookupValue(descOp.getBaseOffset()), 3, 49},
+      {mt.lookupValue(descOp.getLeadingDimMode()), 1, 52},
+      {mt.lookupValue(descOp.getSwizzleMode()), 3, 61},
+  };
+
+  llvm::Value *desc = builder.getInt64(0);
+  for (const DescField &f : fields)
+    desc = packValInto64Bits(builder, desc, f.value, f.width, f.lsb);
+
+  mt.mapValue(descOp.getRes()) = desc;
+}
+
//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-smem-desc.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-smem-desc.mlir
new file mode 100644
index 0000000000000..5af79c6f1379b
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-smem-desc.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// Verifies the LLVM IR produced for nvvm.tcgen05.mma_smem_desc. Each operand
+// is zero-extended to i32, masked to its field width, widened to i64, shifted
+// to its bit position and OR'ed into a descriptor that starts out as 0.
+// Field layout (width @ lowest bit): startAddr 14@0, leadingDimOffset 14@16,
+// strideDimOffset 14@32, fixed 0b001 3@46, baseOffset 3@49,
+// leadingDimMode 1@52, swizzleMode 3@61.
+// The i32 operands need no zext before masking (16383 == 0x3FFF, 14 bits).
+// The fixed 0b001 field is constant-folded into one `or` with 1 << 46
+// (70368744177664).
+// CHECK-LABEL: define i64 @tcgen05_mma_smem_desc_test(i32 %0, i32 %1, i32 %2, i8 %3, i1 %4, i8 %5) {
+llvm.func @tcgen05_mma_smem_desc_test(%startAddr: i32, %leadingDimOffset: i32, %strideDimOffset: i32,
+ %baseOffset: i8, %leadingDimMode: i1, %swizzleMode: i8) -> i64 {
+ // CHECK-NEXT: %7 = and i32 %0, 16383
+ // CHECK-NEXT: %8 = zext i32 %7 to i64
+ // CHECK-NEXT: %9 = shl i64 %8, 0
+ // CHECK-NEXT: %10 = or i64 0, %9
+ // CHECK-NEXT: %11 = and i32 %1, 16383
+ // CHECK-NEXT: %12 = zext i32 %11 to i64
+ // CHECK-NEXT: %13 = shl i64 %12, 16
+ // CHECK-NEXT: %14 = or i64 %10, %13
+ // CHECK-NEXT: %15 = and i32 %2, 16383
+ // CHECK-NEXT: %16 = zext i32 %15 to i64
+ // CHECK-NEXT: %17 = shl i64 %16, 32
+ // CHECK-NEXT: %18 = or i64 %14, %17
+ // CHECK-NEXT: %19 = or i64 %18, 70368744177664
+ // CHECK-NEXT: %20 = zext i8 %3 to i32
+ // CHECK-NEXT: %21 = and i32 %20, 7
+ // CHECK-NEXT: %22 = zext i32 %21 to i64
+ // CHECK-NEXT: %23 = shl i64 %22, 49
+ // CHECK-NEXT: %24 = or i64 %19, %23
+ // CHECK-NEXT: %25 = zext i1 %4 to i32
+ // CHECK-NEXT: %26 = and i32 %25, 1
+ // CHECK-NEXT: %27 = zext i32 %26 to i64
+ // CHECK-NEXT: %28 = shl i64 %27, 52
+ // CHECK-NEXT: %29 = or i64 %24, %28
+ // CHECK-NEXT: %30 = zext i8 %5 to i32
+ // CHECK-NEXT: %31 = and i32 %30, 7
+ // CHECK-NEXT: %32 = zext i32 %31 to i64
+ // CHECK-NEXT: %33 = shl i64 %32, 61
+ // CHECK-NEXT: %34 = or i64 %29, %33
+ // CHECK-NEXT: ret i64 %34
+ // CHECK-NEXT: }
+ %desc = nvvm.tcgen05.mma_smem_desc (%startAddr, %leadingDimOffset, %strideDimOffset, %baseOffset, %leadingDimMode, %swizzleMode) : (i32, i32, i32, i8, i1, i8) -> i64
+ llvm.return %desc : i64
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/141651
More information about the Mlir-commits
mailing list