[Mlir-commits] [mlir] TableGen multi-result support in source patterns (PR #159656)

Martin Coll llvmlistbot at llvm.org
Thu Sep 18 14:35:35 PDT 2025


https://github.com/colltoaction created https://github.com/llvm/llvm-project/pull/159656

This pull request enhances the matching capabilities of TableGen rewrite patterns. It extends the `__N` suffix approach, previously only available in replacement patterns, to source patterns.

The new logic is exercised by the TableGen pattern below. Note that the operand of `OneResultOp4` is produced by a two-result operation; previously, source patterns like this never matched multi-result operations. With these changes, the `__1` suffix makes the pattern match only when the operand is result `#1` of `TwoResultOp2`, as in the MLIR example below.

```td
def : Pat<
  (OneResultOp4 (TwoResultOp2:$a__1)),
  (replaceWithValue $a__0)>;
```

```mlir
%0:2 = "test.two_result2"() : () -> (f32, f32)
%1 = "test.one_result4"(%0#1) : (f32) -> (f32)
return %1 : f32
```

CC @jpienaar, following up on our Discord [conversation](https://discord.com/channels/636084430946959380/1375177152684883968). I hope you find this a good addition! Thanks in advance for reviewing my code.

Thank you!

>From 736f47d03eccbd5afcf76a72c9cf923ac0098452 Mon Sep 17 00:00:00 2001
From: Martin Coll <mcoll at dc.uba.ar>
Date: Thu, 18 Sep 2025 20:49:43 +0000
Subject: [PATCH] Add support to TableGen source patterns to match multi-result
 values by index

---
 mlir/include/mlir/TableGen/Pattern.h   |  5 +++--
 mlir/lib/TableGen/Pattern.cpp          |  9 +++++++--
 mlir/test/lib/Dialect/Test/TestOps.td  | 16 ++++++++++++++++
 mlir/test/mlir-tblgen/pattern.mlir     | 19 +++++++++++++++++++
 mlir/tools/mlir-tblgen/RewriterGen.cpp | 13 ++++++++++++-
 5 files changed, 57 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 49b2dae62dc22..52069278c5eea 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -433,8 +433,9 @@ class SymbolInfoMap {
                         DagAndConstant(node.getAsOpaquePointer(), operandIndex,
                                        variadicSubIndex));
     }
-    static SymbolInfo getResult(const Operator *op) {
-      return SymbolInfo(op, Kind::Result, std::nullopt);
+    static SymbolInfo getResult(const Operator *op, int index) {
+      return SymbolInfo(op, Kind::Result,
+                        DagAndConstant(nullptr, index, std::nullopt));
     }
     static SymbolInfo getValue() {
       return SymbolInfo(nullptr, Kind::Value, std::nullopt);
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index 1a1a58ad271bb..38725050cefe8 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -370,6 +370,8 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
   case Kind::Result: {
     // If `index` is greater than zero, then we are referencing a specific
     // result of a multi-result op. The result can still be variadic.
+    if (index < 0)
+      index = dagAndConstant->operandIndexOrNumValues;
     if (index >= 0) {
       std::string v =
           std::string(formatv("{0}.getODSResults({1})", name, index));
@@ -442,6 +444,8 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
     return std::string(repl);
   }
   case Kind::Result: {
+    if (index < 0)
+      index = dagAndConstant->operandIndexOrNumValues;
     if (index >= 0) {
       auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
       LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
@@ -522,8 +526,9 @@ bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
 }
 
 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
-  std::string name = getValuePackName(symbol).str();
-  auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
+  int index = -1;
+  StringRef name = getValuePackName(symbol, &index);
+  auto inserted = symbolInfoMap.emplace(name.str(), SymbolInfo::getResult(&op, index));
 
   return symbolInfoMap.count(inserted->first) == 1;
 }
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5564264ed8b0b..4bd68a5801bac 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1660,6 +1660,16 @@ def OneResultOp3 : TEST_Op<"one_result3"> {
   let results = (outs I32:$result1);
 }
 
+def OneResultOp4 : TEST_Op<"one_result4"> {
+  let arguments = (ins F32);
+  let results = (outs F32);
+}
+
+def TwoResultOp2 : TEST_Op<"two_result2"> {
+  let arguments = (ins);
+  let results = (outs F32, F32);
+}
+
 // Test using multi-result op as a whole
 def : Pat<(ThreeResultOp MultiResultOpKind1:$kind),
           (AnotherThreeResultOp $kind)>;
@@ -1696,6 +1706,12 @@ def : Pattern<
         (AnotherTwoResultOp $kind)
     ]>;
 
+// Test matching a one-operand op whose operand is result #1 (the second
+// result) of a two-result op; the match is replaced with result #0.
+def : Pat<
+  (OneResultOp4 (TwoResultOp2:$a__1)),
+  (replaceWithValue $a__0)>;
+
 //===----------------------------------------------------------------------===//
 // Test Patterns (Variadic Ops)
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index bd55338618eec..cedf528fb8717 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -594,6 +594,25 @@ func.func @testMatchMixedVaradicOptional(%arg0: i32, %arg1: i32, %arg2: i32, %ar
   return
 }
 
+// CHECK-LABEL: @replaceOneResultWithNSuffixArgMatch
+func.func @replaceOneResultWithNSuffixArgMatch() -> (f32) {
+  // CHECK: %0:2 = "test.two_result2"() : () -> (f32, f32)
+  %0:2 = "test.two_result2"() : () -> (f32, f32)
+  %1 = "test.one_result4"(%0#1) : (f32) -> (f32)
+  // CHECK: return %0#0 : f32
+  return %1 : f32
+}
+
+// CHECK-LABEL: @replaceOneResultWithNSuffixArgNoMatch
+func.func @replaceOneResultWithNSuffixArgNoMatch() -> (f32) {
+  // CHECK: %0:2 = "test.two_result2"() : () -> (f32, f32)
+  %0:2 = "test.two_result2"() : () -> (f32, f32)
+  // CHECK: %1 = "test.one_result4"(%0#0) : (f32) -> f32
+  %1 = "test.one_result4"(%0#0) : (f32) -> (f32)
+  // CHECK: return %1 : f32
+  return %1 : f32
+}
+
 //===----------------------------------------------------------------------===//
 // Test patterns that operate on properties
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 605033daa719f..312f5174b7b9e 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -615,10 +615,17 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
                            op.getQualCppClassName()));
 
   // If the operand's name is set, set to that variable.
-  auto name = tree.getSymbol();
+  int index = -1;
+  auto name = SymbolInfoMap::getValuePackName(tree.getSymbol(), &index).str();
   if (!name.empty())
     os << formatv("{0} = {1};\n", name, castedName);
 
+  if (index != -1) { // `$sym__N`: the consumed value must be result #N of this op.
+    emitMatchCheck(opName, // compare against the parsed index, not a constant
+      formatv("(resultNumber{0} == {1})", depth, index),
+      formatv("\"value used from {0} is not result number {1}\"", castedName, index));
+  }
+
   for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e;
        ++i, ++opArgIdx) {
     auto opArg = op.getArg(opArgIdx);
@@ -662,6 +669,10 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
           "auto *{0} = "
           "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
           argName, castedName, nextOperand);
+        os.indent() << formatv( // -1 if the operand is a block argument (no OpResult).
+          "[[maybe_unused]] int resultNumber{0} = ::llvm::isa<::mlir::OpResult>(*{1}.getODSOperands({2}).begin()) ? "
+          "static_cast<int>(::llvm::cast<::mlir::OpResult>(*{1}.getODSOperands({2}).begin()).getResultNumber()) : -1;\n",
+          depth + 1, castedName, nextOperand);
       // Null check of operand's definingOp
       emitMatchCheck(
           castedName, /*matchStr=*/argName,



More information about the Mlir-commits mailing list