[Mlir-commits] [mlir] 38d854c - [MLIR][NVVM] Update MLIR mapa to reflect new address space (#146031)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Aug 12 09:13:54 PDT 2025


Author: modiking
Date: 2025-08-12T21:43:51+05:30
New Revision: 38d854c6e8ae51f3b8bfdb51cb37fee7544bb7b1

URL: https://github.com/llvm/llvm-project/commit/38d854c6e8ae51f3b8bfdb51cb37fee7544bb7b1
DIFF: https://github.com/llvm/llvm-project/commit/38d854c6e8ae51f3b8bfdb51cb37fee7544bb7b1.diff

LOG: [MLIR][NVVM] Update MLIR mapa to reflect new address space (#146031)

The mapa.shared.cluster variant, which takes a pointer in address space 3,
should now return a pointer in address space 7 (shared cluster). This patch
updates the NVVMOps.td file to reflect this.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/include/mlir/IR/OpBase.td
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Dialect/LLVMIR/nvvm.mlir
    mlir/test/Target/LLVMIR/nvvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 3eaaa0539df80..f5a77af028abd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3068,9 +3068,10 @@ def NVVM_GriddepcontrolLaunchDependentsOp
 //===----------------------------------------------------------------------===//
 
 def NVVM_MapaOp: NVVM_Op<"mapa",
-    [TypesMatchWith<"`res` and `a` should have the same type",
-                    "a", "res", "$_self">, NVVMRequiresSM<90>]> {
-  let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res);
+  [InputAddressIsCombinationOf<["a", "res"],
+    [[LLVM_PointerShared, LLVM_PointerSharedCluster], [LLVM_PointerGeneric, LLVM_PointerGeneric]],
+    "Valid address-space check(or mapping) for mapa Op">, NVVMRequiresSM<90>]> {
+  let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res);
   let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
 
   string llvmBuilder = [{

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 9e5fb5659a22b..af8c072a7a364 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -603,6 +603,51 @@ class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,
                            string transform>
   : TypesMatchWith<summary, lhsArg, rhsArg, transform, "llvm::equal">;
 
+// Checks that the type of each named argument/result in inputArgs satisfies
+// the type constraint at the same index in allowedTypes (equal-length lists).
+class InputMatchesTypes<list<string> inputArgs, list<Type> allowedTypes> :
+    PredOpTrait<"operands {" # !interleave(inputArgs, ", ") # "} match expected types",
+                !foldl(TruePred, !range(!size(inputArgs)), acc, i,  // AND over all indices
+                       And<[acc,
+                           SubstLeaves<"$_self", "$" # inputArgs[i] # ".getType()",  // $_self -> arg type
+                                      allowedTypes[i].predicate>
+                       ]>)> {
+    assert !eq(!size(inputArgs), !size(allowedTypes)),
+           "inputArgs and allowedTypes lists must have the same length";
+
+    list<string> inputArgList = inputArgs;
+    list<Type> allowedTypeList = allowedTypes;
+}
+
+// Checks that the types of the named arguments/results in inputArgs jointly match
+// at least one tuple in allowedCombinations (an Or over per-tuple Ands). Each tuple
+// lists one type constraint per inputArg; a non-empty `description` overrides the summary.
+class InputAddressIsCombinationOf<list<string> inputArgs, 
+                                  list<list<Type>> allowedCombinations,
+                                  string description = ""> :
+    PredOpTrait<!if(!empty(description), 
+                    "operands {" # !interleave(inputArgs, ", ") # "} match one of the allowed type combinations",
+                    description),
+                Or<!foreach(combination, allowedCombinations,  // any one combination may match
+                           !foldl(TruePred, !range(!size(inputArgs)), acc, i,  // every entry must match
+                                  And<[acc,
+                                      SubstLeaves<"$_self", "$" # inputArgs[i] # ".getType()",
+                                                 combination[i].predicate>
+                                  ]>))>> {
+    assert !gt(!size(allowedCombinations), 0),
+           "allowedCombinations must not be empty";
+    
+    // Validate (at TableGen time) that every combination has exactly one type per inputArg.
+    defvar inputArgSize = !size(inputArgs);
+    defvar validSizes = !foldl(1, allowedCombinations, acc, combination,
+                              !and(acc, !eq(inputArgSize, !size(combination))));
+    assert validSizes,
+           "each combination in allowedCombinations must have the same length as inputArgs";
+    
+    list<string> inputArgList = inputArgs;
+    list<list<Type>> allowedCombinationList = allowedCombinations;
+}
+
 // Type Constraint operand `idx`'s Element type is `type`.
 class TCopVTEtIs<int idx, Type type> : And<[
    CPred<"$_op.getNumOperands() > " # idx>,

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index c88ff0f9be5d1..4394786db5a5d 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1275,8 +1275,8 @@ func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
 // -----
 
 func.func @mapa(%a: !llvm.ptr, %b : i32) {
-  // expected-error @below {{`res` and `a` should have the same type}}
-  %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<3>
+  // expected-error @below {{'nvvm.mapa' op failed to verify that Valid address-space check(or mapping) for mapa Op}}
+  %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<7>
   return
 }
 

diff  --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 6a4edd0d22a08..e99f27c7f10a3 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -541,7 +541,7 @@ func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
   // CHECK:   nvvm.mapa %{{.*}}
   %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
   // CHECK:   nvvm.mapa %{{.*}}
-  %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
+  %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7>
   return
 }
 

diff  --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 63e286cdfe07c..0996a8c7eb361 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -813,8 +813,8 @@ llvm.func @nvvm_griddepcontrol_launch_dependents() {
 llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
   // CHECK-LLVM: call ptr @llvm.nvvm.mapa(ptr %{{.*}}, i32 %{{.*}})
   %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
-  // CHECK-LLVM: call ptr addrspace(3) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
-  %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
+  // CHECK-LLVM: call ptr addrspace(7) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
+  %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7>
   llvm.return
 }
 


        


More information about the Mlir-commits mailing list