[Mlir-commits] [mlir] f1032f0 - [MLIR][NVVM][NVGPU] Combine prefetch and prefetch.tensormap (#153134)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 1 03:26:35 PDT 2025


Author: Srinivasa Ravi
Date: 2025-09-01T15:56:31+05:30
New Revision: f1032f06e8bd6571a929e45ca8d0afc9f17957cc

URL: https://github.com/llvm/llvm-project/commit/f1032f06e8bd6571a929e45ca8d0afc9f17957cc
DIFF: https://github.com/llvm/llvm-project/commit/f1032f06e8bd6571a929e45ca8d0afc9f17957cc.diff

LOG: [MLIR][NVVM][NVGPU] Combine prefetch and prefetch.tensormap (#153134)

This PR combines the `prefetch` and `prefetch.tensormap` NVVM Ops
into a single `prefetch` Op. The unpredicated `tensormap` variant is lowered
through the newly added `llvm.nvvm.prefetch.tensormap` intrinsics; the predicated form still lowers to inline PTX.

The NVGPU `tma.prefetch.descriptor` Op now lowers to `nvvm.prefetch`
with the `tensormap` attribute instead of the removed `nvvm.prefetch.tensormap` Op.

PTX Spec Reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
    mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
    mlir/test/Dialect/LLVMIR/nvvm.mlir
    mlir/test/Target/LLVMIR/nvvm/prefetch.mlir
    mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 9528da05c9fd6..8537c7030aa8f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -25,6 +25,7 @@ include "mlir/Dialect/LLVMIR/LLVMTypes.td"
 def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
 def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
 def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>;
+def LLVM_PointerConst : LLVM_PointerInAddressSpace<4>;
 def LLVM_PointerLocal : LLVM_PointerInAddressSpace<5>;
 def LLVM_PointerTensor : LLVM_PointerInAddressSpace<6>;
 def LLVM_PointerSharedCluster : LLVM_PointerInAddressSpace<7>;
@@ -2570,15 +2571,25 @@ def PrefetchCacheLevelAttr : EnumAttr<NVVM_Dialect, PrefetchCacheLevel, "prefetc
   let assemblyFormat = "$value";
 }
 
-def NVVM_PrefetchOp : NVVM_Op<"prefetch"> {
+def NVVM_PrefetchOp : NVVM_Op<"prefetch",
+    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
   let summary = "Brings the cache line containing an address into the specified cache level";
   let description = [{
-    Operand `addr` can be a global, local or generic address pointer. No 
-    operation is performed if `addr` maps to a `shared` memory location.
+    Prefetches the cache line containing the address given by `addr`. The
+    operand may be a global, local, or generic pointer. When `tensormap` is
+    specified, the operand may instead be a constant or generic pointer. If the
+    address maps to shared memory, the operation has no effect.
+
+    Exactly one of `cacheLevel` or `tensormap` must be present. `cacheLevel`
+    selects the target cache level; `evictPriority` is only valid with cache
+    level L2 on a global pointer, and `predicate` is only valid together with
+    `tensormap`.
+
+    When `tensormap` is used, the line containing `addr` is brought from the
+    constant or parameter state space for later use by `cp.async.bulk.tensor`.
+    If `in_param_space` is specified, the generic pointer is interpreted as
+    referring to the parameter state space.
 
-    The `cacheLevel` attribute specifies the cache level to which the cache line
-    containing the specified address is brought.
-    
     `uniform` can be specified after the `cacheLevel` to indicate that the 
     prefetch is performed to the specified uniform cache level. If `uniform` is 
     specified, `addr` must be a generic address pointer and no operation is 
@@ -2589,33 +2600,41 @@ def NVVM_PrefetchOp : NVVM_Op<"prefetch"> {
 
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu)
   }];
-  let arguments = (ins PrefetchCacheLevelAttr:$cacheLevel,
-                       UnitAttr:$uniform,
+  let arguments = (ins OptionalAttr<PrefetchCacheLevelAttr>:$cacheLevel,
+                       OptionalAttr<CacheEvictionPriorityAttr>:$evictPriority,
                        AnyTypeOf<[LLVM_PointerGlobal,
                                   LLVM_PointerLocal,
-                                  LLVM_PointerGeneric]>:$addr,
-                       OptionalAttr<CacheEvictionPriorityAttr>:$evictPriority);
-  let assemblyFormat = "`level` `=` $cacheLevel (`uniform` $uniform^)? `,` $addr (`,` `evict_priority` `=` $evictPriority^)? attr-dict `:` type($addr)";
+                                  LLVM_PointerGeneric,
+                                  LLVM_PointerConst]>:$addr,
+                       PtxPredicate:$predicate,
+                       UnitAttr:$tensormap,
+                       UnitAttr:$uniform,
+                       UnitAttr:$in_param_space);
+  let assemblyFormat = "(`level` `=` $cacheLevel^ (`uniform` $uniform^)? `,`)? (`tensormap` $tensormap^ (`in_param_space` $in_param_space^)? `,`)? (`evict_priority` `=` $evictPriority^ `,`)? $addr (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
-    static llvm::Intrinsic::ID getIntrinsicID(NVVM::PrefetchOp &op);
-  }];
-  let llvmBuilder = [{
-    auto intId = NVVM::PrefetchOp::getIntrinsicID(op);
-    createIntrinsicCall(builder, intId, $addr);
+    static NVVM::IDArgPair
+    getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase &builder);
+    bool hasIntrinsic() { return !getPredicate() || !getTensormap(); }
   }];
-}
-
-def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
-                    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
-  Arguments<(ins LLVM_AnyPointer:$tmaDescriptor, PtxPredicate:$predicate)> {
-  let assemblyFormat = "$tmaDescriptor (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
   let extraClassDefinition = [{
-    std::string $cppClass::getPtx() { 
+    std::string $cppClass::getPtx() {
+      // Inline PTX is only supported for prefetch tensormap
       return std::string("prefetch.tensormap [%0];");
     }
   }];
+  let llvmBuilder = [{
+    auto [id, args] = NVVM::PrefetchOp::getIntrinsicIDAndArgs(op,
+                                          moduleTranslation, builder);
+
+    if(op.getTensormap())
+      // Overloaded intrinsic
+      createIntrinsicCall(builder, id, args, {args[0]->getType()});
+    else
+      createIntrinsicCall(builder, id, args);
+  }];
 }
 
 def NVVM_CpAsyncBulkPrefetchOp : NVVM_Op<"cp.async.bulk.prefetch"> {

diff  --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index ab1666a0e8e75..37d12bad298df 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1695,8 +1695,10 @@ struct NVGPUTmaPrefetchOpLowering
   LogicalResult
   matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
-        op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
+    rewriter.replaceOpWithNewOp<NVVM::PrefetchOp>(
+        op, /* CacheLevel */ nullptr, /* Cache Eviction Priority */ nullptr,
+        adaptor.getTensorMapDescriptor(), adaptor.getPredicate(),
+        /* Tensormap UnitAttr */ mlir::UnitAttr::get(op.getContext()));
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index ff6ccbaac2b35..77ec1ebde3109 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -33,6 +33,7 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/NVPTXAddrSpace.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cassert>
 #include <optional>
@@ -1332,30 +1333,70 @@ LogicalResult NVVM::PrefetchOp::verify() {
   unsigned addressSpace =
       llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
   std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
+  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
 
-  if (getUniform()) {
-    if (getCacheLevel() != CacheLevel::L1)
-      return emitOpError("unsupported cache level, the only supported uniform "
-                         "cache level is L1");
+  if (getTensormap() && cacheLevel)
+    return emitOpError("cannot specify both tensormap and cache level");
 
-    if (addressSpace != MemSpace::kGenericMemorySpace)
+  if (getTensormap()) {
+    if (addressSpace != MemSpace::kGenericMemorySpace &&
+        addressSpace != MemSpace::kConstantMemorySpace) {
       return emitOpError(
-          "prefetch to uniform cache requires a generic pointer");
-  }
+          "prefetch tensormap requires a generic or constant pointer");
+    }
 
-  if (evictPriority) {
-    if (getCacheLevel() != CacheLevel::L2)
+    if (evictPriority) {
       return emitOpError(
-          "cache eviction priority supported only for cache level L2");
-
-    if (addressSpace != MemSpace::kGlobalMemorySpace)
-      return emitOpError("cache eviction priority requires a global pointer");
+          "prefetch tensormap does not support eviction priority");
+    }
 
-    if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
-        *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
+    if (getInParamSpace() && addressSpace != MemSpace::kGenericMemorySpace) {
       return emitOpError(
-          "unsupported cache eviction priority, only evict_last and "
-          "evict_normal are supported");
+          "in_param_space can only be specified for a generic pointer");
+    }
+
+  } else if (cacheLevel) {
+    if (addressSpace != MemSpace::kGenericMemorySpace &&
+        addressSpace != MemSpace::kGlobalMemorySpace &&
+        addressSpace != MemSpace::kLocalMemorySpace) {
+      return emitOpError("prefetch to cache level requires a generic, global, "
+                         "or local pointer");
+    }
+
+    if (getUniform()) {
+      if (*cacheLevel != CacheLevel::L1) {
+        return emitOpError(
+            "unsupported cache level, the only supported uniform "
+            "cache level is L1");
+      }
+
+      if (addressSpace != MemSpace::kGenericMemorySpace) {
+        return emitOpError(
+            "prefetch to uniform cache requires a generic pointer");
+      }
+    }
+
+    if (evictPriority) {
+      if (*cacheLevel != CacheLevel::L2)
+        return emitOpError(
+            "cache eviction priority supported only for cache level L2");
+
+      if (addressSpace != MemSpace::kGlobalMemorySpace)
+        return emitOpError("cache eviction priority requires a global pointer");
+
+      if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
+          *evictPriority != NVVM::CacheEvictionPriority::EvictLast)
+        return emitOpError(
+            "unsupported cache eviction priority, only evict_last and "
+            "evict_normal are supported");
+    }
+
+    if (getPredicate())
+      return emitOpError("predicate supported only on prefetch tensormap");
+
+  } else {
+    return emitOpError(
+        "requires specification of either cache level or tensormap");
   }
 
   return success();
@@ -1964,26 +2005,47 @@ NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
   return {ids[type], args};
 }
 
-llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
+static llvm::Value *getParamCastedAddr(llvm::Value *addr,
+                                       llvm::IRBuilderBase &builder) {
+  return builder.CreateAddrSpaceCast(
+      addr,
+      llvm::PointerType::get(builder.getContext(),
+                             llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
+}
+
+NVVM::IDArgPair
+PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
+                                  LLVM::ModuleTranslation &mt,
+                                  llvm::IRBuilderBase &builder) {
   using MemSpace = NVVM::NVVMMemorySpace;
   using CacheLevel = NVVM::PrefetchCacheLevel;
 
-  NVVM::PrefetchCacheLevel cacheLevel = op.getCacheLevel();
+  std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
   std::optional<NVVM::CacheEvictionPriority> evictPriority =
       op.getEvictPriority();
   unsigned addressSpace =
       llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
           .getAddressSpace();
 
-  if (op.getUniform() && cacheLevel == CacheLevel::L1)
-    return llvm::Intrinsic::nvvm_prefetchu_L1;
+  llvm::SmallVector<llvm::Value *> args;
+  llvm::Value *addr = mt.lookupValue(op.getAddr());
+  args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
+                                      : addr);
+
+  if (op.getTensormap())
+    return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
+
+  assert(cacheLevel && "expected cache level for non-tensormap prefetch");
+
+  if (op.getUniform() && *cacheLevel == CacheLevel::L1)
+    return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
 
-  if (evictPriority && cacheLevel == CacheLevel::L2) {
+  if (evictPriority && *cacheLevel == CacheLevel::L2) {
     switch (*evictPriority) {
     case NVVM::CacheEvictionPriority::EvictLast:
-      return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
+      return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
     case NVVM::CacheEvictionPriority::EvictNormal:
-      return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal;
+      return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
     default:
       llvm_unreachable("Invalid cache eviction priority");
     }
@@ -1991,16 +2053,21 @@ llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
 
   switch (addressSpace) {
   case MemSpace::kGenericMemorySpace:
-    return cacheLevel == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_L1
-                                        : llvm::Intrinsic::nvvm_prefetch_L2;
+    return *cacheLevel == CacheLevel::L1
+               ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
+               : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
   case MemSpace::kGlobalMemorySpace:
-    return cacheLevel == CacheLevel::L1
-               ? llvm::Intrinsic::nvvm_prefetch_global_L1
-               : llvm::Intrinsic::nvvm_prefetch_global_L2;
+    return *cacheLevel == CacheLevel::L1
+               ? NVVM::IDArgPair(
+                     {llvm::Intrinsic::nvvm_prefetch_global_L1, args})
+               : NVVM::IDArgPair(
+                     {llvm::Intrinsic::nvvm_prefetch_global_L2, args});
   case MemSpace::kLocalMemorySpace:
-    return cacheLevel == CacheLevel::L1
-               ? llvm::Intrinsic::nvvm_prefetch_local_L1
-               : llvm::Intrinsic::nvvm_prefetch_local_L2;
+    return *cacheLevel == CacheLevel::L1
+               ? NVVM::IDArgPair(
+                     {llvm::Intrinsic::nvvm_prefetch_local_L1, args})
+               : NVVM::IDArgPair(
+                     {llvm::Intrinsic::nvvm_prefetch_local_L2, args});
   default:
     llvm_unreachable("Invalid pointer address space");
   }

diff  --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index c4cf4f7337d81..0c500e10bc810 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -817,9 +817,9 @@ func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : m
 // CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.tensormap.descriptor<tensor = memref<128xf32, 3>, swizzle = none, l2promo = none, oob = nan, interleave = none>, %[[arg1:[a-zA-Z0-9_]+]]: i1
 func.func @tma_prefetch(%tensorMap1d: !tensorMap1d, %p : i1) {
   // CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.tensormap.descriptor<tensor = memref<128xf32, 3>, swizzle = none, l2promo = none, oob = nan, interleave = none> to !llvm.ptr
-  // CHECK: nvvm.prefetch.tensormap %[[S0]] : !llvm.ptr
+  // CHECK: nvvm.prefetch tensormap, %[[S0]] : !llvm.ptr
   nvgpu.tma.prefetch.descriptor %tensorMap1d: !tensorMap1d
-  // CHECK: nvvm.prefetch.tensormap %[[S0]], predicate = %[[arg1]] : !llvm.ptr, i1
+  // CHECK: nvvm.prefetch tensormap, %[[S0]], predicate = %[[arg1]] : !llvm.ptr, i1
   nvgpu.tma.prefetch.descriptor %tensorMap1d, predicate = %p: !tensorMap1d
   func.return
 }

diff  --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 89075120d16ea..92930f9cbaa49 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -572,10 +572,10 @@ func.func @elect_one_leader_sync() {
 
 // CHECK-LABEL: @init_mbarrier_arrive_expect_tx
 llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"
-  nvvm.prefetch.tensormap %desc : !llvm.ptr
+  //CHECK: nvvm.prefetch tensormap, %{{.*}}
+  nvvm.prefetch tensormap, %desc : !llvm.ptr
   //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$1 prefetch.tensormap [$0];", "l,b"
-  nvvm.prefetch.tensormap %desc, predicate = %pred : !llvm.ptr, i1
+  nvvm.prefetch tensormap, %desc, predicate = %pred : !llvm.ptr, i1
   llvm.return
 }
 

diff  --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 5821c2eac99dd..5209b3c1d7906 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -586,7 +586,7 @@ func.func @dot_accumulate_2way(%a_vec: vector<2xi16>, %b_vec: vector<4xi8>, %c:
 }
 
 // CHECK-LABEL: @prefetch
-func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
+func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>, %const_ptr: !llvm.ptr<4>) {
   // CHECK:   nvvm.prefetch level = L1, %{{.*}}
   nvvm.prefetch level = L1, %gen_ptr : !llvm.ptr<0>
   // CHECK:   nvvm.prefetch level = L1, %{{.*}}
@@ -599,12 +599,24 @@ func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr:
   nvvm.prefetch level = L2, %local_ptr : !llvm.ptr<5>
   // CHECK:   nvvm.prefetch level = L2, %{{.*}}
   nvvm.prefetch level = L2, %global_ptr : !llvm.ptr<1>
-  // CHECK:   nvvm.prefetch level = L2, %{{.*}}
-  nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
-  // CHECK:   nvvm.prefetch level = L2, %{{.*}}
-  nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_normal : !llvm.ptr<1>
+  // CHECK:   nvvm.prefetch level = L2, evict_priority = evict_last, %{{.*}}
+  nvvm.prefetch level = L2, evict_priority = evict_last, %global_ptr :
+  !llvm.ptr<1>
+  // CHECK:   nvvm.prefetch level = L2, evict_priority = evict_normal, %{{.*}}
+  nvvm.prefetch level = L2, evict_priority = evict_normal, %global_ptr : !llvm.ptr<1>
   // CHECK:   nvvm.prefetch level = L1 uniform, %{{.*}}
   nvvm.prefetch level = L1 uniform, %gen_ptr : !llvm.ptr
+  // CHECK:   nvvm.prefetch tensormap, %{{.*}}
+  nvvm.prefetch tensormap, %gen_ptr : !llvm.ptr
+  // CHECK:   nvvm.prefetch tensormap, %{{.*}}
+  nvvm.prefetch tensormap, %const_ptr : !llvm.ptr<4>
+  // CHECK:   nvvm.prefetch tensormap in_param_space, %{{.*}}
+  nvvm.prefetch tensormap in_param_space, %gen_ptr : !llvm.ptr
+  return
+}
+
+// CHECK-LABEL: @prefetch_tensormap
+func.func @prefetch_tensormap(%gen_ptr: !llvm.ptr, %const_ptr: !llvm.ptr<4>) {
   return
 }
 

diff  --git a/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir b/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir
index f38b7529a7233..5f8e8d06e1c2d 100644
--- a/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/prefetch.mlir
@@ -32,8 +32,8 @@ llvm.func @prefetch_L2_eviction_priority(%global_ptr: !llvm.ptr<1>) {
   // CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2.evict.normal(ptr addrspace(1) %0)
   // CHECK-NEXT: ret void
   // CHECK-NEXT: }
-  nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
-  nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_normal : !llvm.ptr<1>
+  nvvm.prefetch level = L2, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>
+  nvvm.prefetch level = L2, evict_priority = evict_normal, %global_ptr : !llvm.ptr<1>
   llvm.return
 }
 
@@ -45,3 +45,17 @@ llvm.func @prefetch_L1_uniform(%gen_ptr: !llvm.ptr) {
   nvvm.prefetch level = L1 uniform, %gen_ptr : !llvm.ptr
   llvm.return
 }
+
+llvm.func @prefetch_tensormap(%gen_ptr: !llvm.ptr, %const_ptr: !llvm.ptr<4>) {
+  // CHECK-LABEL: define void @prefetch_tensormap(ptr %0, ptr addrspace(4) %1) {
+  // CHECK-NEXT: call void @llvm.nvvm.prefetch.tensormap.p0(ptr %0)
+  // CHECK-NEXT: call void @llvm.nvvm.prefetch.tensormap.p4(ptr addrspace(4) %1)
+  // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(101)
+  // CHECK-NEXT: call void @llvm.nvvm.prefetch.tensormap.p101(ptr addrspace(101) %3)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  nvvm.prefetch tensormap, %gen_ptr : !llvm.ptr
+  nvvm.prefetch tensormap, %const_ptr: !llvm.ptr<4>
+  nvvm.prefetch tensormap in_param_space, %gen_ptr : !llvm.ptr
+  llvm.return
+}

diff  --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 863118cd8dd71..b35a6dbcca286 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -248,7 +248,7 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
 
 llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
   // expected-error @below {{cache eviction priority supported only for cache level L2}}
-  nvvm.prefetch level = L1, %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
+  nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>
   llvm.return
 }
 
@@ -256,7 +256,7 @@ llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
 
 llvm.func @nvvm_prefetch_L2_with_evict_last_invalid_addr_space(%local_ptr: !llvm.ptr<5>) {
   // expected-error @below {{cache eviction priority requires a global pointer}}
-  nvvm.prefetch level = L2, %local_ptr, evict_priority = evict_last : !llvm.ptr<5>
+  nvvm.prefetch level = L2, evict_priority = evict_last, %local_ptr : !llvm.ptr<5>
   llvm.return
 }
 
@@ -264,7 +264,7 @@ llvm.func @nvvm_prefetch_L2_with_evict_last_invalid_addr_space(%local_ptr: !llvm
 
 llvm.func @nvvm_prefetch_L2_with_evict_normal_invalid_addr_space(%local_ptr: !llvm.ptr<5>) {
   // expected-error @below {{cache eviction priority requires a global pointer}}
-  nvvm.prefetch level = L2, %local_ptr, evict_priority = evict_normal : !llvm.ptr<5>
+  nvvm.prefetch level = L2, evict_priority = evict_normal, %local_ptr : !llvm.ptr<5>
   llvm.return
 }
 
@@ -272,7 +272,7 @@ llvm.func @nvvm_prefetch_L2_with_evict_normal_invalid_addr_space(%local_ptr: !ll
 
 llvm.func @nvvm_prefetch_L2_with_invalid_evict_first(%global_ptr: !llvm.ptr<1>) {
   // expected-error @below {{unsupported cache eviction priority, only evict_last and evict_normal are supported}}
-  nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_first : !llvm.ptr<1>
+  nvvm.prefetch level = L2, evict_priority = evict_first, %global_ptr : !llvm.ptr<1>
   llvm.return
 }
 
@@ -280,7 +280,7 @@ llvm.func @nvvm_prefetch_L2_with_invalid_evict_first(%global_ptr: !llvm.ptr<1>)
 
 llvm.func @nvvm_prefetch_L2_with_invalid_evict_unchanged(%global_ptr: !llvm.ptr<1>) {
   // expected-error @below {{unsupported cache eviction priority, only evict_last and evict_normal are supported}}
-  nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_unchanged : !llvm.ptr<1>
+  nvvm.prefetch level = L2, evict_priority = evict_unchanged, %global_ptr : !llvm.ptr<1>
   llvm.return
 }
 
@@ -288,7 +288,7 @@ llvm.func @nvvm_prefetch_L2_with_invalid_evict_unchanged(%global_ptr: !llvm.ptr<
 
 llvm.func @nvvm_prefetch_L2_with_invalid_no_allocate(%global_ptr: !llvm.ptr<1>) {
   // expected-error @below {{unsupported cache eviction priority, only evict_last and evict_normal are supported}}
-  nvvm.prefetch level = L2, %global_ptr, evict_priority = no_allocate : !llvm.ptr<1>
+  nvvm.prefetch level = L2, evict_priority = no_allocate, %global_ptr : !llvm.ptr<1>
   llvm.return
 }
 
@@ -310,6 +310,62 @@ llvm.func @nvvm_prefetch_uniform_with_invalid_addr_space(%global_ptr: !llvm.ptr<
 
 // -----
 
+llvm.func @nvvm_prefetch_both_tensormap_and_cache_level(%gen_ptr: !llvm.ptr) {
+  // expected-error @below {{cannot specify both tensormap and cache level}}
+  nvvm.prefetch level = L1, tensormap, %gen_ptr : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_prefetch_tensormap_invalid_addr_space(%global_ptr: !llvm.ptr<1>) {
+  // expected-error @below {{prefetch tensormap requires a generic or constant pointer}}
+  nvvm.prefetch tensormap, %global_ptr : !llvm.ptr<1>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_prefetch_tensormap_with_evict_priority(%gen_ptr: !llvm.ptr) {
+  // expected-error @below {{prefetch tensormap does not support eviction priority}}
+  nvvm.prefetch tensormap, evict_priority = evict_last, %gen_ptr : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_prefetch_tensormap_in_param_space_non_generic(%const_ptr: !llvm.ptr<4>) {
+  // expected-error @below {{in_param_space can only be specified for a generic pointer}}
+  nvvm.prefetch tensormap in_param_space, %const_ptr : !llvm.ptr<4>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_prefetch_cache_level_invalid_addr_space(%const_ptr: !llvm.ptr<4>) {
+  // expected-error @below {{prefetch to cache level requires a generic, global, or local pointer}}
+  nvvm.prefetch level = L1, %const_ptr : !llvm.ptr<4>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_prefetch_predicate_without_tensormap(%gen_ptr: !llvm.ptr, %pred: i1) {
+  // expected-error @below {{predicate supported only on prefetch tensormap}}
+  nvvm.prefetch level = L1, %gen_ptr, predicate = %pred : !llvm.ptr, i1
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_prefetch_no_level_or_tensormap(%gen_ptr: !llvm.ptr) {
+  // expected-error @below {{requires specification of either cache level or tensormap}}
+  nvvm.prefetch %gen_ptr : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
 llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
   // expected-error at +1 {{'nvvm.stmatrix' op expected num attribute to be 1, 2 or 4}}
   nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32


        


More information about the Mlir-commits mailing list