[Mlir-commits] [mlir] [MLIR][NVVM] Update MLIR mapa to reflect new address space (PR #146031)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 3 20:39:23 PDT 2025


https://github.com/modiking updated https://github.com/llvm/llvm-project/pull/146031

>From f6d951b96f35ae2bb0447ca791b4c5f01f1b975b Mon Sep 17 00:00:00 2001
From: Modi Mo <mmo at nvidia.com>
Date: Thu, 26 Jun 2025 23:25:05 -0700
Subject: [PATCH 1/3] update mapa

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 25 ++++++++++++++++++---
 mlir/test/Dialect/LLVMIR/invalid.mlir       |  4 ++--
 mlir/test/Dialect/LLVMIR/nvvm.mlir          |  2 +-
 mlir/test/Target/LLVMIR/nvvmir.mlir         |  4 ++--
 4 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6895e946b8a45..e55060dc04204 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3009,10 +3009,29 @@ def NVVM_GriddepcontrolLaunchDependentsOp
 // NVVM Mapa Op
 //===----------------------------------------------------------------------===//
 
+// Helper predicates for address space checking
+def IsGenericAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 0">;
+def IsSharedAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 3">;
+def IsSharedClusterAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 7">;
+
+class NVVM_AddressSpaceMapping<string inputArg, string resultArg> : 
+    PredOpTrait<"valid address space mapping for NVVM mapa operation",
+                Or<[
+                  // Generic -> Generic
+                  And<[
+                    SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsGenericAddressSpace>,
+                    SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsGenericAddressSpace>
+                  ]>,
+                  // Shared -> SharedCluster
+                  And<[
+                    SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsSharedAddressSpace>,
+                    SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsSharedClusterAddressSpace>
+                  ]>
+                ]>>;
+
 def NVVM_MapaOp: NVVM_Op<"mapa",
-    [TypesMatchWith<"`res` and `a` should have the same type",
-                    "a", "res", "$_self">, NVVMRequiresSM<90>]> {
-  let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res);
+    [NVVM_AddressSpaceMapping<"a", "res">, NVVMRequiresSM<90>]> {
+  let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res);
   let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
 
   string llvmBuilder = [{
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 251ca716c7a7a..7a85eea58c558 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1201,8 +1201,8 @@ func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
 // -----
 
 func.func @mapa(%a: !llvm.ptr, %b : i32) {
-  // expected-error @below {{`res` and `a` should have the same type}}
-  %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<3>
+  // expected-error @below {{'nvvm.mapa' op failed to verify that valid address space mapping for NVVM mapa operation}}
+  %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<7>
   return
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index c7fa41c98ac92..4349193aa1a45 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -552,7 +552,7 @@ func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
   // CHECK:   nvvm.mapa %{{.*}}
   %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
   // CHECK:   nvvm.mapa %{{.*}}
-  %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
+  %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7>
   return
 }
 
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index f86a04186f512..c119c1a0fd21f 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -760,8 +760,8 @@ llvm.func @nvvm_griddepcontrol_launch_dependents() {
 llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
   // CHECK-LLVM: call ptr @llvm.nvvm.mapa(ptr %{{.*}}, i32 %{{.*}})
   %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
-  // CHECK-LLVM: call ptr addrspace(3) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
-  %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
+  // CHECK-LLVM: call ptr addrspace(7) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
+  %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7>
   llvm.return
 }
 

>From 65552af49a860190a45ac4f742d50300839d6312 Mon Sep 17 00:00:00 2001
From: Modi Mo <mmo at nvidia.com>
Date: Tue, 1 Jul 2025 15:27:23 -0700
Subject: [PATCH 2/3] review feedback

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 6 +++---
 mlir/test/Dialect/LLVMIR/invalid.mlir       | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index e55060dc04204..431a33412c43e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3014,8 +3014,8 @@ def IsGenericAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).get
 def IsSharedAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 3">;
 def IsSharedClusterAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 7">;
 
-class NVVM_AddressSpaceMapping<string inputArg, string resultArg> : 
-    PredOpTrait<"valid address space mapping for NVVM mapa operation",
+class NVVM_MapaASCheck<string inputArg, string resultArg> : 
+    PredOpTrait<"Valid address-space check(or mapping) for mapa Op",
                 Or<[
                   // Generic -> Generic
                   And<[
@@ -3030,7 +3030,7 @@ class NVVM_AddressSpaceMapping<string inputArg, string resultArg> :
                 ]>>;
 
 def NVVM_MapaOp: NVVM_Op<"mapa",
-    [NVVM_AddressSpaceMapping<"a", "res">, NVVMRequiresSM<90>]> {
+    [NVVM_MapaASCheck<"a", "res">, NVVMRequiresSM<90>]> {
   let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res);
   let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 7a85eea58c558..2c1c3071c456e 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1201,7 +1201,7 @@ func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
 // -----
 
 func.func @mapa(%a: !llvm.ptr, %b : i32) {
-  // expected-error @below {{'nvvm.mapa' op failed to verify that valid address space mapping for NVVM mapa operation}}
+  // expected-error @below {{'nvvm.mapa' op failed to verify that Valid address-space check(or mapping) for mapa Op}}
   %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<7>
   return
 }

>From 088fddf0b88ee4113fb0fc848bb5e07e08cc29a5 Mon Sep 17 00:00:00 2001
From: Modi Mo <mmo at nvidia.com>
Date: Thu, 3 Jul 2025 20:37:44 -0700
Subject: [PATCH 3/3] make generic helper and update usage to it

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 27 ++++-----------------
 mlir/include/mlir/IR/OpBase.td              | 16 ++++++++++++
 2 files changed, 21 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 431a33412c43e..9ebaac6f0fd80 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3009,28 +3009,11 @@ def NVVM_GriddepcontrolLaunchDependentsOp
 // NVVM Mapa Op
 //===----------------------------------------------------------------------===//
 
-// Helper predicates for address space checking
-def IsGenericAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 0">;
-def IsSharedAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 3">;
-def IsSharedClusterAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 7">;
-
-class NVVM_MapaASCheck<string inputArg, string resultArg> : 
-    PredOpTrait<"Valid address-space check(or mapping) for mapa Op",
-                Or<[
-                  // Generic -> Generic
-                  And<[
-                    SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsGenericAddressSpace>,
-                    SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsGenericAddressSpace>
-                  ]>,
-                  // Shared -> SharedCluster
-                  And<[
-                    SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsSharedAddressSpace>,
-                    SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsSharedClusterAddressSpace>
-                  ]>
-                ]>>;
-
-def NVVM_MapaOp: NVVM_Op<"mapa",
-    [NVVM_MapaASCheck<"a", "res">, NVVMRequiresSM<90>]> {
+def NVVM_MapaASCheck : PredOpTrait<"Valid address-space check(or mapping) for mapa Op", Or<[
+    InputMatchesTypes<["a", "res"], [LLVM_PointerGeneric, LLVM_PointerGeneric]>.predicate,
+    InputMatchesTypes<["a", "res"], [LLVM_PointerShared, LLVM_PointerSharedCluster]>.predicate]>>;
+
+def NVVM_MapaOp: NVVM_Op<"mapa", [NVVM_MapaASCheck, NVVMRequiresSM<90>]> {
   let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res);
   let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
 
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 43ef28624fb19..b21603a410c0c 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -603,6 +603,22 @@ class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,
                            string transform>
   : TypesMatchWith<summary, lhsArg, rhsArg, transform, "llvm::equal">;
 
+// Checks that each operand or result named in `inputArgs` satisfies the type
+// constraint at the same index in `allowedTypes` (lists must be equal length).
+class InputMatchesTypes<list<string> inputArgs, list<Type> allowedTypes> :
+    PredOpTrait<"operands/results {" # !interleave(inputArgs, ", ") #
+                "} match expected types",
+                !foldl(TruePred, !range(!size(inputArgs)), acc, i,
+                       And<[acc,
+                            SubstLeaves<"$_self", "$" # inputArgs[i] # ".getType()",
+                                        allowedTypes[i].predicate>]>)> {
+    assert !eq(!size(inputArgs), !size(allowedTypes)),
+           "inputArgs and allowedTypes lists must have the same length";
+    // NOTE(review): copies of the template args; no reader is visible here.
+    list<string> inputArgList = inputArgs;
+    list<Type> allowedTypeList = allowedTypes;
+}
+
 // Type Constraint operand `idx`'s Element type is `type`.
 class TCopVTEtIs<int idx, Type type> : And<[
    CPred<"$_op.getNumOperands() > " # idx>,



More information about the Mlir-commits mailing list