[Mlir-commits] [mlir] [MLIR][NVVM] Migrate CpAsyncOp to intrinsics (PR #123789)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 21 09:41:37 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Durgadoss R (durga4github)

<details>
<summary>Changes</summary>

Intrinsics are also available for the 'cpSize' variants, so this
patch migrates the Op to lower to intrinsics in all cases (the
'cpSize' variants were previously lowered to inline PTX).

* Update the existing tests to check the lowering to intrinsics.
* Add new async_cp_zfill tests to verify the lowering of the 'cpSize' variants.
* Tidy up the CHECK lines in the cp_async() function in nvvmir.mlir (NFC).

PTX spec link:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async

---
Full diff: https://github.com/llvm/llvm-project/pull/123789.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+15-40) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+23) 
- (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+2-6) 
- (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+20-6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 797a0067081314..dc4295926a8ce5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -849,55 +849,30 @@ def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind",
 
 def LoadCacheModifierAttr : EnumAttr<NVVM_Dialect, LoadCacheModifierKind, "load_cache_modifier">;
 
-def NVVM_CpAsyncOp : NVVM_PTXBuilder_Op<"cp.async.shared.global">,
+def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
   Arguments<(ins LLVM_PointerShared:$dst,
                  LLVM_PointerGlobal:$src,
                  I32Attr:$size,
                  LoadCacheModifierAttr:$modifier,
                  Optional<LLVM_Type>:$cpSize)> {
-  string llvmBuilder = [{
-      llvm::Intrinsic::ID id;
-      switch ($size) {
-        case 4:
-          id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4;
-          break;
-        case 8:
-          id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_8;
-          break;
-        case 16:
-          if($modifier == NVVM::LoadCacheModifierKind::CG)
-            id = llvm::Intrinsic::nvvm_cp_async_cg_shared_global_16;
-          else if($modifier == NVVM::LoadCacheModifierKind::CA)
-            id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_16;
-          else 
-            llvm_unreachable("unsupported cache modifier");
-          break;
-        default:
-          llvm_unreachable("unsupported async copy size");
-      }
-      createIntrinsicCall(builder, id, {$dst, $src});
-  }];
   let assemblyFormat = "$dst `,` $src `,` $size `,` `cache` `=` $modifier (`,` $cpSize^)? attr-dict `:` type(operands)";
   let hasVerifier = 1;
   let extraClassDeclaration = [{
-    bool hasIntrinsic() { if(getCpSize()) return false; return true; }
-
-    void getAsmValues(RewriterBase &rewriter, 
-        llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues) {
-      asmValues.push_back({getDst(), PTXRegisterMod::Read});
-      asmValues.push_back({getSrc(), PTXRegisterMod::Read});
-      asmValues.push_back({makeConstantI32(rewriter, getSize()), PTXRegisterMod::Read});
-      asmValues.push_back({getCpSize(), PTXRegisterMod::Read});
-    }        
+    static llvm::Intrinsic::ID getIntrinsicID(int size,
+                                              NVVM::LoadCacheModifierKind kind,
+                                              bool hasCpSize);
   }];
-  let extraClassDefinition = [{        
-    std::string $cppClass::getPtx() { 
-      if(getModifier() == NVVM::LoadCacheModifierKind::CG)
-        return std::string("cp.async.cg.shared.global [%0], [%1], %2, %3;\n");
-      if(getModifier() == NVVM::LoadCacheModifierKind::CA)
-        return std::string("cp.async.ca.shared.global [%0], [%1], %2, %3;\n");
-      llvm_unreachable("unsupported cache modifier");      
-    }
+  string llvmBuilder = [{
+    bool hasCpSize = op.getCpSize() ? true : false;
+
+    llvm::SmallVector<llvm::Value *> translatedOperands;
+    translatedOperands.push_back($dst);
+    translatedOperands.push_back($src);
+    if (hasCpSize)
+      translatedOperands.push_back($cpSize);
+
+    auto id = NVVM::CpAsyncOp::getIntrinsicID($size, $modifier, hasCpSize);
+    createIntrinsicCall(builder, id, translatedOperands);
   }];
 }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index ccb5ad05f0bf72..2c45753d52da9c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1110,6 +1110,29 @@ LogicalResult NVVM::BarrierOp::verify() {
   return success();
 }
 
+#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
+  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
+
+#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
+  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
+
+llvm::Intrinsic::ID
+CpAsyncOp::getIntrinsicID(int size, NVVM::LoadCacheModifierKind cacheMod,
+                          bool hasCpSize) {
+  switch (size) {
+  case 4:
+    return GET_CP_ASYNC_ID(ca, 4, hasCpSize);
+  case 8:
+    return GET_CP_ASYNC_ID(ca, 8, hasCpSize);
+  case 16:
+    return (cacheMod == NVVM::LoadCacheModifierKind::CG)
+               ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
+               : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
+  default:
+    llvm_unreachable("Invalid copy size in CpAsyncOp.");
+  }
+}
+
 llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
                                                                 bool isIm2Col) {
   switch (tensorDims) {
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 84ea55ceb5acc2..c7a6eca1582768 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -74,13 +74,9 @@ func.func @async_cp(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>) {
 
 // CHECK-LABEL: @async_cp_zfill
 func.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", 
-  // CHECK-SAME: "r,l,n,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
+  // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache =  cg, %{{.*}} : !llvm.ptr<3>, !llvm.ptr<1>, i32
   nvvm.cp.async.shared.global %dst, %src, 16, cache =  cg, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A", 
-  // CHECK-SAME: "r,l,n,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
+  // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 4, cache =  ca, %{{.*}} : !llvm.ptr<3>, !llvm.ptr<1>, i32
   nvvm.cp.async.shared.global %dst, %src, 4, cache =  ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
   return
 }
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 09e98765413f0c..7dad9a403def0e 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -488,21 +488,35 @@ llvm.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 :
 
 // CHECK-LABEL: @cp_async
 llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
-// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
+  // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
   nvvm.cp.async.shared.global %arg0, %arg1, 4, cache =  ca : !llvm.ptr<3>, !llvm.ptr<1>
-// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
+  // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
   nvvm.cp.async.shared.global %arg0, %arg1, 8, cache =  ca : !llvm.ptr<3>, !llvm.ptr<1>
-// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
+  // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
   nvvm.cp.async.shared.global %arg0, %arg1, 16, cache =  ca : !llvm.ptr<3>, !llvm.ptr<1>
-// CHECK: call void @llvm.nvvm.cp.async.cg.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
+  // CHECK: call void @llvm.nvvm.cp.async.cg.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
   nvvm.cp.async.shared.global %arg0, %arg1, 16, cache =  cg : !llvm.ptr<3>, !llvm.ptr<1>
-// CHECK: call void @llvm.nvvm.cp.async.commit.group()
+
+  // CHECK: call void @llvm.nvvm.cp.async.commit.group()
   nvvm.cp.async.commit.group
-// CHECK: call void @llvm.nvvm.cp.async.wait.group(i32 0)
+  // CHECK: call void @llvm.nvvm.cp.async.wait.group(i32 0)
   nvvm.cp.async.wait.group 0
   llvm.return
 }
 
+// CHECK-LABEL: @async_cp_zfill
+llvm.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) {
+  // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4.s(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}, i32 %{{.*}})
+  nvvm.cp.async.shared.global %dst, %src, 4, cache =  ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
+  // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8.s(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}, i32 %{{.*}})
+  nvvm.cp.async.shared.global %dst, %src, 8, cache =  ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
+  // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16.s(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}, i32 %{{.*}})
+  nvvm.cp.async.shared.global %dst, %src, 16, cache =  ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
+  // CHECK: call void @llvm.nvvm.cp.async.cg.shared.global.16.s(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}, i32 %{{.*}})
+  nvvm.cp.async.shared.global %dst, %src, 16, cache =  cg, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
+  llvm.return
+}
+
 // CHECK-LABEL: @cp_async_mbarrier_arrive
 llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) {
   // CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive(ptr %{{.*}})

``````````

</details>


https://github.com/llvm/llvm-project/pull/123789


More information about the Mlir-commits mailing list