[Mlir-commits] [mlir] 38d854c - [MLIR][NVVM] Update MLIR mapa to reflect new address space (#146031)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Aug 12 09:13:54 PDT 2025
Author: modiking
Date: 2025-08-12T21:43:51+05:30
New Revision: 38d854c6e8ae51f3b8bfdb51cb37fee7544bb7b1
URL: https://github.com/llvm/llvm-project/commit/38d854c6e8ae51f3b8bfdb51cb37fee7544bb7b1
DIFF: https://github.com/llvm/llvm-project/commit/38d854c6e8ae51f3b8bfdb51cb37fee7544bb7b1.diff
LOG: [MLIR][NVVM] Update MLIR mapa to reflect new address space (#146031)
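The mapa.shared.cluster variant, which takes a pointer in address space 3 (shared),
now returns a pointer in address space 7 (shared::cluster) instead of address space 3.
This patch updates the NVVMOps.td file to reflect this. It also adds the
InputMatchesTypes and InputAddressIsCombinationOf constraint classes to OpBase.td;
the latter is used to verify mapa's allowed operand/result address-space pairings.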
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/include/mlir/IR/OpBase.td
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Dialect/LLVMIR/nvvm.mlir
mlir/test/Target/LLVMIR/nvvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 3eaaa0539df80..f5a77af028abd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3068,9 +3068,10 @@ def NVVM_GriddepcontrolLaunchDependentsOp
//===----------------------------------------------------------------------===//
def NVVM_MapaOp: NVVM_Op<"mapa",
- [TypesMatchWith<"`res` and `a` should have the same type",
- "a", "res", "$_self">, NVVMRequiresSM<90>]> {
- let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res);
+ [InputAddressIsCombinationOf<["a", "res"],
+ [[LLVM_PointerShared, LLVM_PointerSharedCluster], [LLVM_PointerGeneric, LLVM_PointerGeneric]],
+ "Valid address-space check(or mapping) for mapa Op">, NVVMRequiresSM<90>]> {
+ let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res);
let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
string llvmBuilder = [{
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 9e5fb5659a22b..af8c072a7a364 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -603,6 +603,51 @@ class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,
string transform>
: TypesMatchWith<summary, lhsArg, rhsArg, transform, "llvm::equal">;
+// Checks that the type of each operand/result named in inputArgs satisfies the
+// type constraint at the same index in allowedTypes (all pairs must match).
+class InputMatchesTypes<list<string> inputArgs, list<Type> allowedTypes> :
+ PredOpTrait<"operands {" # !interleave(inputArgs, ", ") # "} match expected types",
+ !foldl(TruePred, !range(!size(inputArgs)), acc, i,
+ And<[acc,
+ SubstLeaves<"$_self", "$" # inputArgs[i] # ".getType()",
+ allowedTypes[i].predicate>
+ ]>)> {
+ assert !eq(!size(inputArgs), !size(allowedTypes)),
+ "inputArgs and allowedTypes lists must have the same length";
+
+ list<string> inputArgList = inputArgs;
+ list<Type> allowedTypeList = allowedTypes;
+}
+
+// Passes when, for at least one list in allowedCombinations, the type of every
+// operand/result named in inputArgs satisfies the constraint at the same index.
+// A non-empty `description` replaces the generated trait summary.
+class InputAddressIsCombinationOf<list<string> inputArgs,
+ list<list<Type>> allowedCombinations,
+ string description = ""> :
+ PredOpTrait<!if(!empty(description),
+ "operands {" # !interleave(inputArgs, ", ") # "} match one of the allowed type combinations",
+ description),
+ Or<!foreach(combination, allowedCombinations,
+ !foldl(TruePred, !range(!size(inputArgs)), acc, i,
+ And<[acc,
+ SubstLeaves<"$_self", "$" # inputArgs[i] # ".getType()",
+ combination[i].predicate>
+ ]>))>> {
+ assert !gt(!size(allowedCombinations), 0),
+ "allowedCombinations must not be empty";
+
+ // Each combination must list exactly one type per entry of inputArgs.
+ defvar inputArgSize = !size(inputArgs);
+ defvar validSizes = !foldl(1, allowedCombinations, acc, combination,
+ !and(acc, !eq(inputArgSize, !size(combination))));
+ assert validSizes,
+ "each combination in allowedCombinations must have the same length as inputArgs";
+
+ list<string> inputArgList = inputArgs;
+ list<list<Type>> allowedCombinationList = allowedCombinations;
+}
+
// Type Constraint operand `idx`'s Element type is `type`.
class TCopVTEtIs<int idx, Type type> : And<[
CPred<"$_op.getNumOperands() > " # idx>,
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index c88ff0f9be5d1..4394786db5a5d 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1275,8 +1275,8 @@ func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
// -----
func.func @mapa(%a: !llvm.ptr, %b : i32) {
- // expected-error @below {{`res` and `a` should have the same type}}
- %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<3>
+ // expected-error @below {{'nvvm.mapa' op failed to verify that Valid address-space check(or mapping) for mapa Op}}
+ %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<7>
return
}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index 6a4edd0d22a08..e99f27c7f10a3 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -541,7 +541,7 @@ func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
// CHECK: nvvm.mapa %{{.*}}
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
// CHECK: nvvm.mapa %{{.*}}
- %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
+ %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7>
return
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 63e286cdfe07c..0996a8c7eb361 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -813,8 +813,8 @@ llvm.func @nvvm_griddepcontrol_launch_dependents() {
llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
// CHECK-LLVM: call ptr @llvm.nvvm.mapa(ptr %{{.*}}, i32 %{{.*}})
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
- // CHECK-LLVM: call ptr addrspace(3) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
- %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
+ // CHECK-LLVM: call ptr addrspace(7) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
+ %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7>
llvm.return
}
More information about the Mlir-commits
mailing list