[Mlir-commits] [mlir] TableGen multi-result support in source patterns (PR #159656)
Martin Coll
llvmlistbot at llvm.org
Thu Sep 18 14:35:35 PDT 2025
https://github.com/colltoaction created https://github.com/llvm/llvm-project/pull/159656
The changes proposed in this pull request enhance the matching capabilities of TableGen rewrite patterns. They extend the `__N` suffix approach, which was previously available only in replacement patterns, to source patterns.
The new logic is exercised by the following TableGen pattern. Notice that the operand of `OneResultOp4` is produced by a two-result operation; previously, source patterns containing such multi-result operations never matched. With these changes, the `__1` suffix makes this pattern match only when the operand received is result `#1` of `TwoResultOp2`, as in the MLIR example below.
```td
def : Pat<
(OneResultOp4 (TwoResultOp2:$a__1)),
(replaceWithValue $a__0)>;
```
```mlir
%0:2 = "test.two_result2"() : () -> (f32, f32)
%1 = "test.one_result4"(%0#1) : (f32) -> (f32)
return %1 : f32
```
CC @jpienaar, following up on our Discord [conversation](https://discord.com/channels/636084430946959380/1375177152684883968). I hope you find this a good addition! Thanks in advance for reviewing my code.
Thank you!
>From 736f47d03eccbd5afcf76a72c9cf923ac0098452 Mon Sep 17 00:00:00 2001
From: Martin Coll <mcoll at dc.uba.ar>
Date: Thu, 18 Sep 2025 20:49:43 +0000
Subject: [PATCH] Add support to TableGen source patterns to match multi-result
values by index
---
mlir/include/mlir/TableGen/Pattern.h | 5 +++--
mlir/lib/TableGen/Pattern.cpp | 9 +++++++--
mlir/test/lib/Dialect/Test/TestOps.td | 16 ++++++++++++++++
mlir/test/mlir-tblgen/pattern.mlir | 19 +++++++++++++++++++
mlir/tools/mlir-tblgen/RewriterGen.cpp | 13 ++++++++++++-
5 files changed, 57 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 49b2dae62dc22..52069278c5eea 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -433,8 +433,9 @@ class SymbolInfoMap {
DagAndConstant(node.getAsOpaquePointer(), operandIndex,
variadicSubIndex));
}
- static SymbolInfo getResult(const Operator *op) {
- return SymbolInfo(op, Kind::Result, std::nullopt);
+ static SymbolInfo getResult(const Operator *op, int index) {
+ return SymbolInfo(op, Kind::Result,
+ DagAndConstant(nullptr, index, std::nullopt));
}
static SymbolInfo getValue() {
return SymbolInfo(nullptr, Kind::Value, std::nullopt);
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index 1a1a58ad271bb..38725050cefe8 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -370,6 +370,8 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
case Kind::Result: {
// If `index` is greater than zero, then we are referencing a specific
// result of a multi-result op. The result can still be variadic.
+ if (index < 0)
+ index = dagAndConstant->operandIndexOrNumValues;
if (index >= 0) {
std::string v =
std::string(formatv("{0}.getODSResults({1})", name, index));
@@ -442,6 +444,8 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
return std::string(repl);
}
case Kind::Result: {
+ if (index < 0)
+ index = dagAndConstant->operandIndexOrNumValues;
if (index >= 0) {
auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
@@ -522,8 +526,9 @@ bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
}
bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
- std::string name = getValuePackName(symbol).str();
- auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
+ int index = -1;
+ StringRef name = getValuePackName(symbol, &index);
+ auto inserted = symbolInfoMap.emplace(name.str(), SymbolInfo::getResult(&op, index));
return symbolInfoMap.count(inserted->first) == 1;
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5564264ed8b0b..4bd68a5801bac 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1660,6 +1660,16 @@ def OneResultOp3 : TEST_Op<"one_result3"> {
let results = (outs I32:$result1);
}
+def OneResultOp4 : TEST_Op<"one_result4"> {
+ let arguments = (ins F32);
+ let results = (outs F32);
+}
+
+def TwoResultOp2 : TEST_Op<"two_result2"> {
+ let arguments = (ins);
+ let results = (outs F32, F32);
+}
+
// Test using multi-result op as a whole
def : Pat<(ThreeResultOp MultiResultOpKind1:$kind),
(AnotherThreeResultOp $kind)>;
@@ -1696,6 +1706,12 @@ def : Pattern<
(AnotherTwoResultOp $kind)
]>;
+// Test matching a one-operand op whose operand
+// comes from the second result (index 1) of a two-result op.
+def : Pat<
+ (OneResultOp4 (TwoResultOp2:$a__1)),
+ (replaceWithValue $a__0)>;
+
//===----------------------------------------------------------------------===//
// Test Patterns (Variadic Ops)
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index bd55338618eec..cedf528fb8717 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -594,6 +594,25 @@ func.func @testMatchMixedVaradicOptional(%arg0: i32, %arg1: i32, %arg2: i32, %ar
return
}
+// CHECK-LABEL: @replaceOneResultWithNSuffixArgMatch
+func.func @replaceOneResultWithNSuffixArgMatch() -> (f32) {
+ // CHECK: %0:2 = "test.two_result2"() : () -> (f32, f32)
+ %0:2 = "test.two_result2"() : () -> (f32, f32)
+ %1 = "test.one_result4"(%0#1) : (f32) -> (f32)
+ // CHECK: return %0#0 : f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: @replaceOneResultWithNSuffixArgNoMatch
+func.func @replaceOneResultWithNSuffixArgNoMatch() -> (f32) {
+ // CHECK: %0:2 = "test.two_result2"() : () -> (f32, f32)
+ %0:2 = "test.two_result2"() : () -> (f32, f32)
+ // CHECK: %1 = "test.one_result4"(%0#0) : (f32) -> f32
+ %1 = "test.one_result4"(%0#0) : (f32) -> (f32)
+ // CHECK: return %1 : f32
+ return %1 : f32
+}
+
//===----------------------------------------------------------------------===//
// Test patterns that operate on properties
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 605033daa719f..312f5174b7b9e 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -615,10 +615,17 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
op.getQualCppClassName()));
// If the operand's name is set, set to that variable.
- auto name = tree.getSymbol();
+ int index = -1; // Set by getValuePackName when the symbol has a `__N` suffix.
+ auto name = SymbolInfoMap::getValuePackName(tree.getSymbol(), &index).str();
if (!name.empty())
os << formatv("{0} = {1};\n", name, castedName);
+ if (index != -1 && depth != 0) { // Root op (depth 0) has no resultNumber var.
+ emitMatchCheck(opName,
+ formatv("(resultNumber{0} == {1})", depth, index),
+ formatv("\"{0} does not come from result number {1}\"", castedName, index));
+ }
+
for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e;
++i, ++opArgIdx) {
auto opArg = op.getArg(opArgIdx);
@@ -662,6 +669,10 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
"auto *{0} = "
"(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
argName, castedName, nextOperand);
+ os.indent() << formatv(
+ "[[maybe_unused]] auto resultNumber{0} = "
+ "::llvm::dyn_cast<::mlir::OpResult>((*{1}.getODSOperands({2}).begin())).getResultNumber();\n",
+ depth + 1, castedName, nextOperand);
// Null check of operand's definingOp
emitMatchCheck(
castedName, /*matchStr=*/argName,
More information about the Mlir-commits
mailing list