[Mlir-commits] [mlir] [mlir] Execute same operand name constraints before user constraints. (PR #162526)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 8 11:34:03 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Chenguang Wang (wecing)
<details>
<summary>Changes</summary>
For a pattern like this:
Pat<(MyOp $x, $x),
(...),
[(MyCheck $x)]>;
The old implementation generates matching code equivalent to:
Pat<(MyOp $x0, $x1),
(...),
[(MyCheck $x0),
($x0 == $x1)]>;
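This is counterintuitive: because `$x` appears twice in the source pattern, users naturally expect the equality check to be part of source-pattern matching, i.e. to run before any user-provided constraint such as `MyCheck`. Instead, `MyCheck` could run on operands that would later fail the equality check.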
This commit emits the equality checks before the user-provided constraints, i.e.:
Pat<(MyOp $x0, $x1),
(...),
[($x0 == $x1),
(MyCheck $x0)]>;
---
Full diff: https://github.com/llvm/llvm-project/pull/162526.diff
4 Files Affected:
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+13)
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+5)
- (modified) mlir/test/mlir-tblgen/pattern.mlir (+23-5)
- (modified) mlir/tools/mlir-tblgen/RewriterGen.cpp (+26-23)
``````````diff
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 6ea27187655ee..ed62bee3bc152 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1169,6 +1169,11 @@ def OpP : TEST_Op<"op_p"> {
let results = (outs I32);
}
+def OpQ : TEST_Op<"op_q"> {
+ let arguments = (ins AnyType, AnyType); // Operand types may differ.
+ let results = (outs AnyType);
+}
+
// Test constant-folding a pattern that maps `(F32) -> SI32`.
def SignOp : TEST_Op<"sign", [SameOperandsAndResultShape]> {
let arguments = (ins RankedTensorOf<[F32]>:$operand);
@@ -1207,6 +1212,14 @@ def TestNestedSameOpAndSameArgEqualityPattern :
def TestMultipleEqualArgsPattern :
Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
+// Test that equal-argument checks run before user-provided constraints:
+// CheckIntIs32Bits crashes if $x binds a non-integer value (e.g. f32).
+def CheckIntIs32Bits : Constraint<CPred<"intIs32Bits($0)">>;
+def TestEqualArgsCheckBeforeUserConstraintsPattern :
+ Pat<(OpQ $x, $x),
+ (replaceWithValue $x),
+ [(CheckIntIs32Bits $x)]>;
+
// Test for memrefs normalization of an op with normalizable memrefs.
def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index f8b5144e3acb2..d764deb023873 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -70,6 +70,11 @@ static Attribute opMTest(PatternRewriter &rewriter, Value val) {
return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
}
+// Requires an integer-typed value; other types dereference a null dyn_cast.
+static bool intIs32Bits(Value v) {
+ return mlir::dyn_cast<IntegerType>(v.getType()).getWidth() == 32;
+}
+
namespace {
#include "TestPatterns.inc"
} // namespace
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index bd55338618eec..a67830373e701 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -156,16 +156,19 @@ func.func @verifyNestedOpEqualArgs(
// def TestNestedOpEqualArgsPattern :
// Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
- // CHECK: %arg1
+ // CHECK: "test.op_o"(%arg1)
%0 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
: (i32, i32, i32, i32, i32, i32) -> (i32)
%1 = "test.op_n"(%arg1, %0) : (i32, i32) -> (i32)
+ %2 = "test.op_o"(%1) : (i32) -> (i32)
- // CHECK: test.op_p
- // CHECK: test.op_n
- %2 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
+ // CHECK-NEXT: %[[P:.*]] = "test.op_p"
+ // CHECK-NEXT: %[[N:.*]] = "test.op_n"(%arg0, %[[P]])
+ // CHECK-NEXT: "test.op_o"(%[[N]])
+ %3 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
: (i32, i32, i32, i32, i32, i32) -> (i32)
- %3 = "test.op_n"(%arg0, %2) : (i32, i32) -> (i32)
+ %4 = "test.op_n"(%arg0, %3) : (i32, i32) -> (i32)
+ %5 = "test.op_o"(%4) : (i32) -> (i32)
return
}
@@ -206,6 +209,21 @@ func.func @verifyMultipleEqualArgs(
return
}
+func.func @verifyEqualArgsCheckBeforeUserConstraints(%arg0: i32, %arg1: f32) {
+ // def TestEqualArgsCheckBeforeUserConstraintsPattern :
+ // Pat<(OpQ $x, $x),
+ // (replaceWithValue $x),
+ // [(CheckIntIs32Bits $x)]>;
+
+ // CHECK: "test.op_q"(%arg0, %arg1)
+ %0 = "test.op_q"(%arg0, %arg1) : (i32, f32) -> (i32)
+
+ // CHECK: "test.op_q"(%arg1, %arg0)
+ %1 = "test.op_q"(%arg1, %arg0) : (f32, i32) -> (i32)
+
+ return
+}
+
//===----------------------------------------------------------------------===//
// Test Symbol Binding
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 605033daa719f..40bc1a9c3868c 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1024,6 +1024,32 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
int depth = 0;
emitMatch(tree, opName, depth);
+ // Some of the operands could be bound to the same symbol name, we need
+ // to enforce equality constraint on those.
+ // This has to happen before user provided constraints, which may assume the
+ // same name checks are already performed, since in the pattern source code
+ // the user provided constraints appear later.
+ // TODO: we should be able to emit equality checks early
+ // and short circuit unnecessary work if vars are not equal.
+ for (auto symbolInfoIt = symbolInfoMap.begin();
+ symbolInfoIt != symbolInfoMap.end();) {
+ auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
+ auto startRange = range.first;
+ auto endRange = range.second;
+
+ auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
+ for (++startRange; startRange != endRange; ++startRange) {
+ auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
+ emitMatchCheck(
+ opName,
+ formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
+ formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
+ secondOperand));
+ }
+
+ symbolInfoIt = endRange;
+ }
+
for (auto &appliedConstraint : pattern.getConstraints()) {
auto &constraint = appliedConstraint.constraint;
auto &entities = appliedConstraint.entities;
@@ -1068,29 +1094,6 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
}
}
- // Some of the operands could be bound to the same symbol name, we need
- // to enforce equality constraint on those.
- // TODO: we should be able to emit equality checks early
- // and short circuit unnecessary work if vars are not equal.
- for (auto symbolInfoIt = symbolInfoMap.begin();
- symbolInfoIt != symbolInfoMap.end();) {
- auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
- auto startRange = range.first;
- auto endRange = range.second;
-
- auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
- for (++startRange; startRange != endRange; ++startRange) {
- auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
- emitMatchCheck(
- opName,
- formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
- formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
- secondOperand));
- }
-
- symbolInfoIt = endRange;
- }
-
LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/162526
More information about the Mlir-commits
mailing list