[Mlir-commits] [mlir] [MLIR][NVVM] Update prefetch.tensormap Op (PR #153134)

Srinivasa Ravi llvmlistbot at llvm.org
Sun Aug 17 21:42:04 PDT 2025


https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/153134

From 80b066599399eded3fb393f74fb9bec88b649702 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Mon, 11 Aug 2025 13:24:37 +0530
Subject: [PATCH] [MLIR][NVVM] Update prefetch.tensormap Op

This change updates the `prefetch.tensormap` NVVM Op to lower through the
`llvm.nvvm.prefetch.tensormap` intrinsic when no predicate is present; the
predicated form continues to lower through inline PTX.
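
For illustration, a minimal before/after sketch (mirroring the tests added in
this patch; SSA value names in the LLVM IR are approximate):

    llvm.func @prefetch_tensormap(%gen_ptr: !llvm.ptr) {
      nvvm.prefetch.tensormap %gen_ptr : !llvm.ptr
      nvvm.prefetch.tensormap param %gen_ptr : !llvm.ptr
      llvm.return
    }

now translates to LLVM IR along the lines of:

    call void @llvm.nvvm.prefetch.tensormap.p0(ptr %0)
    %2 = addrspacecast ptr %0 to ptr addrspace(101)
    call void @llvm.nvvm.prefetch.tensormap.p101(ptr addrspace(101) %2)
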
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 43 ++++++++++++++++---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 14 ++++++
 .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir   |  2 +-
 mlir/test/Dialect/LLVMIR/nvvm.mlir            | 11 +++++
 mlir/test/Target/LLVMIR/nvvm/prefetch.mlir    | 14 ++++++
 5 files changed, 78 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8d507268a3a15..bcefc5b199b8b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -25,9 +25,12 @@ include "mlir/Dialect/LLVMIR/LLVMTypes.td"
 def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
 def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
 def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>;
+def LLVM_PointerConst : LLVM_PointerInAddressSpace<4>;
 def LLVM_PointerLocal : LLVM_PointerInAddressSpace<5>;
 def LLVM_PointerTensor : LLVM_PointerInAddressSpace<6>;
 def LLVM_PointerSharedCluster : LLVM_PointerInAddressSpace<7>;
+// Parameter state space: https://docs.nvidia.com/cuda/parallel-thread-execution/#parameter-state-space
+def LLVM_PointerParam : LLVM_PointerInAddressSpace<101>;
 
 //===----------------------------------------------------------------------===//
 // NVVM dialect definitions
@@ -2464,15 +2467,45 @@ def NVVM_PrefetchOp : NVVM_Op<"prefetch"> {
   }];
 }
 
-def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
-  Arguments<(ins LLVM_AnyPointer:$tmaDescriptor, PtxPredicate:$predicate)> {
-  let assemblyFormat = "$tmaDescriptor (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
+def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
+    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, NVVMRequiresSM<90>]> {
+  let summary = "Brings the cache line containing the given address in the constant (`.const`) or parameter (`.param`) state space for subsequent use by the `cp.async.bulk.tensor` instruction";
+  let description = [{
+    The `tmaDescriptor` operand may be a pointer in the `const` address space
+    or a generic address pointer. A generic address pointer must map to a memory
+    location in the [const](https://docs.nvidia.com/cuda/parallel-thread-execution/#constant-state-space)
+    or [param](https://docs.nvidia.com/cuda/parallel-thread-execution/#parameter-state-space) state space.
+
+    If the `in_param_space` attribute is present, `tmaDescriptor` must be a
+    generic address pointer and is treated as pointing to a memory location in
+    the `param` state space.
+
+    [For more information, see the PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu)
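+
+    Example:
+    ```mlir
+    nvvm.prefetch.tensormap %gen_ptr : !llvm.ptr
+    nvvm.prefetch.tensormap %const_ptr : !llvm.ptr<4>
+    nvvm.prefetch.tensormap param %gen_ptr : !llvm.ptr
+    ```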
+  }];
+  let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric,
+                                  LLVM_PointerConst]>:$tmaDescriptor,
+                       PtxPredicate:$predicate,
+                       UnitAttr:$in_param_space);
+  let assemblyFormat = "(`param` $in_param_space^)? $tmaDescriptor (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = "bool hasIntrinsic() { return !getPredicate(); }";
   let extraClassDefinition = [{
-    std::string $cppClass::getPtx() { 
+    std::string $cppClass::getPtx() {
       return std::string("prefetch.tensormap [%0];");
     }
   }];
+  let llvmBuilder = [{
+    llvm::Value *tmaDesc;
+    if (op.getInParamSpace())
+      tmaDesc = builder.CreateAddrSpaceCast(
+          $tmaDescriptor, llvm::PointerType::get(builder.getContext(), 101));
+    else
+      tmaDesc = $tmaDescriptor;
+
+    createIntrinsicCall(builder, llvm::Intrinsic::nvvm_prefetch_tensormap,
+                        {tmaDesc}, {tmaDesc->getType()});
+  }];
 }
 
 def NVVM_CpAsyncBulkPrefetchOp : NVVM_Op<"cp.async.bulk.prefetch"> {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7ad429efc9fad..9bd6748a0d58e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1265,6 +1265,20 @@ LogicalResult NVVM::PrefetchOp::verify() {
   return success();
 }
 
+LogicalResult NVVM::PrefetchTensorMapOp::verify() {
+  using MemSpace = NVVM::NVVMMemorySpace;
+  unsigned addressSpace =
+      llvm::cast<LLVM::LLVMPointerType>(getTmaDescriptor().getType())
+          .getAddressSpace();
+
+  if (getInParamSpace()) {
+    if (addressSpace != MemSpace::kGenericMemorySpace)
+      return emitOpError(
+          "in_param_space can only be specified for a generic pointer");
+  }
+  return success();
+}
+
 /// Packs the given `field` into the `result`.
 /// The `result` is 64-bits and each `field` can be 32-bits or narrower.
 static llvm::Value *
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index e50576722e38c..956ae113ba020 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -582,7 +582,7 @@ func.func @elect_one_leader_sync() {
 
 // CHECK-LABEL: @init_mbarrier_arrive_expect_tx
 llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"
+  //CHECK: nvvm.prefetch.tensormap %{{.*}}
   nvvm.prefetch.tensormap %desc : !llvm.ptr
   //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$1 prefetch.tensormap [$0];", "l,b"
   nvvm.prefetch.tensormap %desc, predicate = %pred : !llvm.ptr, i1
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index c7fa41c98ac92..61d2bd29a64e3 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -619,6 +619,17 @@ func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr:
   return
 }
 
+// CHECK-LABEL: @prefetch_tensormap
+func.func @prefetch_tensormap(%gen_ptr: !llvm.ptr, %const_ptr: !llvm.ptr<4>) {
+  // CHECK:   nvvm.prefetch.tensormap %{{.*}}
+  nvvm.prefetch.tensormap %gen_ptr : !llvm.ptr
+  // CHECK:   nvvm.prefetch.tensormap %{{.*}}
+  nvvm.prefetch.tensormap %const_ptr : !llvm.ptr<4>
+  // CHECK:   nvvm.prefetch.tensormap param %{{.*}}
+  nvvm.prefetch.tensormap param %gen_ptr : !llvm.ptr
+  return
+}
+
 // -----
 
 // Just check these don't emit errors.
diff --git a/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir b/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir
index f38b7529a7233..bd6943c041e8c 100644
--- a/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir
@@ -45,3 +45,17 @@ llvm.func @prefetch_L1_uniform(%gen_ptr: !llvm.ptr) {
   nvvm.prefetch level = L1 uniform, %gen_ptr : !llvm.ptr
   llvm.return
 }
+
+llvm.func @prefetch_tensormap(%gen_ptr: !llvm.ptr, %const_ptr: !llvm.ptr<4>) {
+  // CHECK-LABEL: define void @prefetch_tensormap(ptr %0, ptr addrspace(4) %1) {
+  // CHECK-NEXT: call void @llvm.nvvm.prefetch.tensormap.p0(ptr %0)
+  // CHECK-NEXT: call void @llvm.nvvm.prefetch.tensormap.p4(ptr addrspace(4) %1)
+  // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(101)
+  // CHECK-NEXT: call void @llvm.nvvm.prefetch.tensormap.p101(ptr addrspace(101) %3)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  nvvm.prefetch.tensormap %gen_ptr : !llvm.ptr
+  nvvm.prefetch.tensormap %const_ptr : !llvm.ptr<4>
+  nvvm.prefetch.tensormap param %gen_ptr : !llvm.ptr
+  llvm.return
+}


