[Mlir-commits] [mlir] a75565a - [mlir] Execute same operand name constraints before user constraints. (#162526)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 9 16:19:51 PDT 2025


Author: Chenguang Wang
Date: 2025-10-09T16:19:47-07:00
New Revision: a75565a54401571b896e1a3c60939e4dcdc0b13a

URL: https://github.com/llvm/llvm-project/commit/a75565a54401571b896e1a3c60939e4dcdc0b13a
DIFF: https://github.com/llvm/llvm-project/commit/a75565a54401571b896e1a3c60939e4dcdc0b13a.diff

LOG: [mlir] Execute same operand name constraints before user constraints. (#162526)

For a pattern like this:

    Pat<(MyOp $x, $x),
        (...),
        [(MyCheck $x)]>;

The old implementation generated code equivalent to:

    Pat<(MyOp $x0, $x1),
        (...),
        [(MyCheck $x0),
         ($x0 == $x1)]>;

This is counterintuitive: because the $x name appears in the source
pattern, a reader naturally assumes that the equality check is performed
as part of source pattern matching, i.e. before any user constraint runs.

This commit emits the equality checks before the user-provided
constraints, i.e.:

    Pat<(MyOp $x0, $x1),
        (...),
        [($x0 == $x1),
         (MyCheck $x0)]>;

Added: 
    

Modified: 
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestPatterns.cpp
    mlir/test/mlir-tblgen/pattern.mlir
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 6ea27187655ee..6329d61ba691b 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1169,6 +1169,11 @@ def OpP : TEST_Op<"op_p"> {
   let results = (outs I32);
 }
 
+def OpQ : TEST_Op<"op_q"> { // Binary op; operands may be of any type.
+  let arguments = (ins AnyType, AnyType);
+  let results = (outs AnyType);
+}
+
 // Test constant-folding a pattern that maps `(F32) -> SI32`.
 def SignOp : TEST_Op<"sign", [SameOperandsAndResultShape]> {
   let arguments = (ins RankedTensorOf<[F32]>:$operand);
@@ -1207,6 +1212,14 @@ def TestNestedSameOpAndSameArgEqualityPattern :
 def TestMultipleEqualArgsPattern :
   Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
 
+// Test that same-name operand equality checks run before user constraints.
+def AssertBinOpEqualArgsAndReturnTrue : Constraint<
+  CPred<"assertBinOpEqualArgsAndReturnTrue($0)">>;
+def TestEqualArgsCheckBeforeUserConstraintsPattern :
+  Pat<(OpQ:$op $x, $x),
+      (replaceWithValue $x),
+      [(AssertBinOpEqualArgsAndReturnTrue $op)]>;
+
 // Test for memrefs normalization of an op with normalizable memrefs.
 def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
   let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index f8b5144e3acb2..ee4fa39158721 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -70,6 +70,16 @@ static Attribute opMTest(PatternRewriter &rewriter, Value val) {
   return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
 }
 
+static bool assertBinOpEqualArgsAndReturnTrue(Value v) {
+  Operation *operation = v.getDefiningOp();
+  if (operation->getOperand(0) != operation->getOperand(1)) {
+    // Used as a pattern constraint; same-name operand equality is checked
+    // before user constraints, so this branch must never be reached.
+    llvm::report_fatal_error("Arguments are not equal");
+  }
+  return true;
+}
+
 namespace {
 #include "TestPatterns.inc"
 } // namespace

diff  --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index bd55338618eec..ffb78c28412ce 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -156,16 +156,19 @@ func.func @verifyNestedOpEqualArgs(
   // def TestNestedOpEqualArgsPattern :
   //   Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
 
-  // CHECK: %arg1
+  // CHECK: "test.op_o"(%arg1)
   %0 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
     : (i32, i32, i32, i32, i32, i32) -> (i32)
   %1 = "test.op_n"(%arg1, %0) : (i32, i32) -> (i32)
+  %2 = "test.op_o"(%1) : (i32) -> (i32)
 
-  // CHECK: test.op_p
-  // CHECK: test.op_n
-  %2 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
+  // CHECK-NEXT: %[[P:.*]] = "test.op_p"
+  // CHECK-NEXT: %[[N:.*]] = "test.op_n"(%arg0, %[[P]])
+  // CHECK-NEXT: "test.op_o"(%[[N]])
+  %3 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
     : (i32, i32, i32, i32, i32, i32) -> (i32)
-  %3 = "test.op_n"(%arg0, %2) : (i32, i32) -> (i32)
+  %4 = "test.op_n"(%arg0, %3) : (i32, i32) -> (i32)
+  %5 = "test.op_o"(%4) : (i32) -> (i32)
 
   return
 }
@@ -206,6 +209,21 @@ func.func @verifyMultipleEqualArgs(
   return
 }
 
+func.func @verifyEqualArgsCheckBeforeUserConstraints(%arg0: i32, %arg1: f32) {
+  // def TestEqualArgsCheckBeforeUserConstraintsPattern :
+  //   Pat<(OpQ:$op $x, $x),
+  //       (replaceWithValue $x),
+  //       [(AssertBinOpEqualArgsAndReturnTrue $op)]>;
+  // Operands differ: neither op is rewritten and the constraint never runs.
+  // CHECK: "test.op_q"(%arg0, %arg1)
+  %0 = "test.op_q"(%arg0, %arg1) : (i32, f32) -> (i32)
+
+  // CHECK: "test.op_q"(%arg1, %arg0)
+  %1 = "test.op_q"(%arg1, %arg0) : (f32, i32) -> (i32)
+
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // Test Symbol Binding
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 605033daa719f..40bc1a9c3868c 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1024,6 +1024,32 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
   int depth = 0;
   emitMatch(tree, opName, depth);
 
+  // Some of the operands could be bound to the same symbol name, we need
+  // to enforce equality constraint on those.
+  // This has to happen before user provided constraints, which may assume the
+  // same name checks are already performed, since in the pattern source code
+  // the user provided constraints appear later.
+  // TODO: we should be able to emit equality checks early
+  // and short circuit unnecessary work if vars are not equal.
+  for (auto symbolInfoIt = symbolInfoMap.begin();
+       symbolInfoIt != symbolInfoMap.end();) {
+    auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
+    auto startRange = range.first;
+    auto endRange = range.second;
+
+    auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
+    for (++startRange; startRange != endRange; ++startRange) {
+      auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
+      emitMatchCheck(
+          opName,
+          formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
+          formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
+                  secondOperand));
+    }
+
+    symbolInfoIt = endRange;
+  }
+
   for (auto &appliedConstraint : pattern.getConstraints()) {
     auto &constraint = appliedConstraint.constraint;
     auto &entities = appliedConstraint.entities;
@@ -1068,29 +1094,6 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
     }
   }
 
-  // Some of the operands could be bound to the same symbol name, we need
-  // to enforce equality constraint on those.
-  // TODO: we should be able to emit equality checks early
-  // and short circuit unnecessary work if vars are not equal.
-  for (auto symbolInfoIt = symbolInfoMap.begin();
-       symbolInfoIt != symbolInfoMap.end();) {
-    auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
-    auto startRange = range.first;
-    auto endRange = range.second;
-
-    auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
-    for (++startRange; startRange != endRange; ++startRange) {
-      auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
-      emitMatchCheck(
-          opName,
-          formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
-          formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
-                  secondOperand));
-    }
-
-    symbolInfoIt = endRange;
-  }
-
   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
 }
 


        


More information about the Mlir-commits mailing list