[Mlir-commits] [mlir] [MLIR][NVVM] Update MLIR mapa to reflect new address space (PR #146031)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 3 20:39:23 PDT 2025
https://github.com/modiking updated https://github.com/llvm/llvm-project/pull/146031
>From f6d951b96f35ae2bb0447ca791b4c5f01f1b975b Mon Sep 17 00:00:00 2001
From: Modi Mo <mmo at nvidia.com>
Date: Thu, 26 Jun 2025 23:25:05 -0700
Subject: [PATCH 1/3] update mapa

Make nvvm.mapa on a shared-memory pointer (!llvm.ptr<3>) return a
shared_cluster pointer (!llvm.ptr<7>) instead of a pointer of the same
type as its input; generic (!llvm.ptr) -> generic is unchanged.

* NVVMOps.td: replace the TypesMatchWith<"`res` and `a` should have the
  same type"> trait with NVVM_AddressSpaceMapping, a PredOpTrait that
  accepts exactly two (input, result) address-space pairs: 0 -> 0 and
  3 -> 7 (IsGenericAddressSpace / IsSharedAddressSpace /
  IsSharedClusterAddressSpace). The allowed result types become
  LLVM_PointerGeneric or LLVM_PointerSharedCluster.
* invalid.mlir: a generic input with an addrspace(7) result is now the
  rejected case, reported with the new trait's summary.
* nvvm.mlir / nvvmir.mlir: the shared case is written as
  !llvm.ptr<3> -> !llvm.ptr<7>, and the translated
  @llvm.nvvm.mapa.shared.cluster call is expected to return
  ptr addrspace(7) while still taking ptr addrspace(3).
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 25 ++++++++++++++++++---
mlir/test/Dialect/LLVMIR/invalid.mlir | 4 ++--
mlir/test/Dialect/LLVMIR/nvvm.mlir | 2 +-
mlir/test/Target/LLVMIR/nvvmir.mlir | 4 ++--
4 files changed, 27 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6895e946b8a45..e55060dc04204 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3009,10 +3009,29 @@ def NVVM_GriddepcontrolLaunchDependentsOp
// NVVM Mapa Op
//===----------------------------------------------------------------------===//
+// Helper predicates for address space checking
+def IsGenericAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 0">;
+def IsSharedAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 3">;
+def IsSharedClusterAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 7">;
+
+class NVVM_AddressSpaceMapping<string inputArg, string resultArg> :
+ PredOpTrait<"valid address space mapping for NVVM mapa operation",
+ Or<[
+ // Generic -> Generic
+ And<[
+ SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsGenericAddressSpace>,
+ SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsGenericAddressSpace>
+ ]>,
+ // Shared -> SharedCluster
+ And<[
+ SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsSharedAddressSpace>,
+ SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsSharedClusterAddressSpace>
+ ]>
+ ]>>;
+
def NVVM_MapaOp: NVVM_Op<"mapa",
- [TypesMatchWith<"`res` and `a` should have the same type",
- "a", "res", "$_self">, NVVMRequiresSM<90>]> {
- let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res);
+ [NVVM_AddressSpaceMapping<"a", "res">, NVVMRequiresSM<90>]> {
+ let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res);
let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
string llvmBuilder = [{
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 251ca716c7a7a..7a85eea58c558 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1201,8 +1201,8 @@ func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
// -----
func.func @mapa(%a: !llvm.ptr, %b : i32) {
- // expected-error @below {{`res` and `a` should have the same type}}
- %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<3>
+ // expected-error @below {{'nvvm.mapa' op failed to verify that valid address space mapping for NVVM mapa operation}}
+ %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<7>
return
}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index c7fa41c98ac92..4349193aa1a45 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -552,7 +552,7 @@ func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
// CHECK: nvvm.mapa %{{.*}}
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
// CHECK: nvvm.mapa %{{.*}}
- %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
+ %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7>
return
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index f86a04186f512..c119c1a0fd21f 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -760,8 +760,8 @@ llvm.func @nvvm_griddepcontrol_launch_dependents() {
llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
// CHECK-LLVM: call ptr @llvm.nvvm.mapa(ptr %{{.*}}, i32 %{{.*}})
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
- // CHECK-LLVM: call ptr addrspace(3) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
- %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
+ // CHECK-LLVM: call ptr addrspace(7) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
+ %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7>
llvm.return
}
>From 65552af49a860190a45ac4f742d50300839d6312 Mon Sep 17 00:00:00 2001
From: Modi Mo <mmo at nvidia.com>
Date: Tue, 1 Jul 2025 15:27:23 -0700
Subject: [PATCH 2/3] review feedback

Rename the trait class NVVM_AddressSpaceMapping to NVVM_MapaASCheck and
change its summary to "Valid address-space check(or mapping) for mapa
Op"; the expected-error in invalid.mlir is updated to match. The
accepted (input, result) address-space pairs are unchanged.

NOTE(review): the summary is spliced into the verifier diagnostic,
which now reads "'nvvm.mapa' op failed to verify that Valid
address-space check(or mapping) for mapa Op". A lowercase phrase with
a space before "(" would read better; invalid.mlir would need the same
edit.
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 6 +++---
mlir/test/Dialect/LLVMIR/invalid.mlir | 2 +-
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index e55060dc04204..431a33412c43e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3014,8 +3014,8 @@ def IsGenericAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).get
def IsSharedAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 3">;
def IsSharedClusterAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 7">;
-class NVVM_AddressSpaceMapping<string inputArg, string resultArg> :
- PredOpTrait<"valid address space mapping for NVVM mapa operation",
+class NVVM_MapaASCheck<string inputArg, string resultArg> :
+ PredOpTrait<"Valid address-space check(or mapping) for mapa Op",
Or<[
// Generic -> Generic
And<[
@@ -3030,7 +3030,7 @@ class NVVM_AddressSpaceMapping<string inputArg, string resultArg> :
]>>;
def NVVM_MapaOp: NVVM_Op<"mapa",
- [NVVM_AddressSpaceMapping<"a", "res">, NVVMRequiresSM<90>]> {
+ [NVVM_MapaASCheck<"a", "res">, NVVMRequiresSM<90>]> {
let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res);
let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 7a85eea58c558..2c1c3071c456e 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1201,7 +1201,7 @@ func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
// -----
func.func @mapa(%a: !llvm.ptr, %b : i32) {
- // expected-error @below {{'nvvm.mapa' op failed to verify that valid address space mapping for NVVM mapa operation}}
+ // expected-error @below {{'nvvm.mapa' op failed to verify that Valid address-space check(or mapping) for mapa Op}}
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<7>
return
}
>From 088fddf0b88ee4113fb0fc848bb5e07e08cc29a5 Mon Sep 17 00:00:00 2001
From: Modi Mo <mmo at nvidia.com>
Date: Thu, 3 Jul 2025 20:37:44 -0700
Subject: [PATCH 3/3] make generic helper and update usage to it

Add a reusable InputMatchesTypes<inputArgs, allowedTypes> class to
OpBase.td. It is a PredOpTrait whose predicate ANDs, for every index i,
allowedTypes[i].predicate applied to the type of $inputArgs[i], and it
asserts that both lists have the same length.

NVVM_MapaASCheck becomes a plain def: an Or of the .predicate of two
InputMatchesTypes instances over ["a", "res"], namely
(LLVM_PointerShared, LLVM_PointerSharedCluster) and
(LLVM_PointerGeneric, LLVM_PointerGeneric). This replaces the
hand-written Is*AddressSpace CPreds, and NVVM_MapaOp now lists
NVVM_MapaASCheck without template arguments.

NOTE(review):
* Despite the name, InputMatchesTypes is applied here to a result
  ("res") as well as an operand, and its summary says "operands {...}".
* The inputArgList / allowedTypeList fields are not read anywhere in
  this series.
* Presumably a length mismatch is reported via the assert; confirm it
  is not preempted by an out-of-range allowedTypes[i] inside !foldl.
* The NVVM_MapaASCheck summary still starts with a capital letter and
  lacks a space before "(", so the diagnostic reads "failed to verify
  that Valid address-space check(or mapping) for mapa Op".
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 27 ++++-----------------
mlir/include/mlir/IR/OpBase.td | 16 ++++++++++++
2 files changed, 21 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 431a33412c43e..9ebaac6f0fd80 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3009,28 +3009,11 @@ def NVVM_GriddepcontrolLaunchDependentsOp
// NVVM Mapa Op
//===----------------------------------------------------------------------===//
-// Helper predicates for address space checking
-def IsGenericAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 0">;
-def IsSharedAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 3">;
-def IsSharedClusterAddressSpace : CPred<"llvm::cast<LLVM::LLVMPointerType>($_self).getAddressSpace() == 7">;
-
-class NVVM_MapaASCheck<string inputArg, string resultArg> :
- PredOpTrait<"Valid address-space check(or mapping) for mapa Op",
- Or<[
- // Generic -> Generic
- And<[
- SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsGenericAddressSpace>,
- SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsGenericAddressSpace>
- ]>,
- // Shared -> SharedCluster
- And<[
- SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsSharedAddressSpace>,
- SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsSharedClusterAddressSpace>
- ]>
- ]>>;
-
-def NVVM_MapaOp: NVVM_Op<"mapa",
- [NVVM_MapaASCheck<"a", "res">, NVVMRequiresSM<90>]> {
+def NVVM_MapaASCheck : PredOpTrait<"Valid address-space check(or mapping) for mapa Op",
+ Or<[InputMatchesTypes<["a", "res"], [LLVM_PointerShared, LLVM_PointerSharedCluster]>.predicate,
+ InputMatchesTypes<["a", "res"], [LLVM_PointerGeneric, LLVM_PointerGeneric]>.predicate]>>;
+
+def NVVM_MapaOp: NVVM_Op<"mapa", [NVVM_MapaASCheck, NVVMRequiresSM<90>]> {
let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res);
let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 43ef28624fb19..b21603a410c0c 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -603,6 +603,22 @@ class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,
string transform>
: TypesMatchWith<summary, lhsArg, rhsArg, transform, "llvm::equal">;
+// Checks that the type of each value named in inputArgs (operand or result)
+// satisfies the predicate of the entry at the same position in allowedTypes
+class InputMatchesTypes<list<string> inputArgs, list<Type> allowedTypes> :
+ PredOpTrait<"operands {" # !interleave(inputArgs, ", ") # "} match expected types",
+ !foldl(TruePred, !range(!size(inputArgs)), acc, i,
+ And<[acc,
+ SubstLeaves<"$_self", "$" # inputArgs[i] # ".getType()",
+ allowedTypes[i].predicate>
+ ]>)> {
+ assert !eq(!size(inputArgs), !size(allowedTypes)),
+ "inputArgs and allowedTypes lists must have the same length";
+
+ list<string> inputArgList = inputArgs;
+ list<Type> allowedTypeList = allowedTypes;
+}
+
// Type Constraint operand `idx`'s Element type is `type`.
class TCopVTEtIs<int idx, Type type> : And<[
CPred<"$_op.getNumOperands() > " # idx>,
More information about the Mlir-commits mailing list