[Mlir-commits] [mlir] [mlir] Handle NativeCodeCallVoid in result patterns. (PR #65804)
Jian Cai
llvmlistbot at llvm.org
Fri Sep 8 14:36:53 PDT 2023
https://github.com/jcai19 updated https://github.com/llvm/llvm-project/pull/65804:
From b147d4df4daa64d99dbf46381c056e67a3aec41f Mon Sep 17 00:00:00 2001
From: Jian Cai <caij2003 at gmail.com>
Date: Fri, 8 Sep 2023 14:28:36 -0700
Subject: [PATCH] [mlir] Handle NativeCodeCallVoid in result patterns.
Currently, NativeCodeCallVoid is not supported in result patterns.
For example, the code below fails to build with the error message
`referencing unbound symbol`:
```
def Foo: NativeCodeCallVoid<"foo()">;
def AddToAddV2 : Pattern<
  (TF_AddOp TF_NumberTensor:$arg0, TF_NumberTensor:$arg1),
  [(TF_AddV2Op $arg0, $arg1), (Foo)]>;
```
MLIR tablegen-based pattern rewrites do not preserve the attributes of the source
op. With this change, users can manually copy source attributes to the target
op via NativeCodeCallVoid. This replaces reviews.llvm.org/D157032.
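To make the motivation concrete, here is a toy C++ model. It is not MLIR code: `ToyOp`, `buildTarget`, and `copyAttrs` are hypothetical names. It assumes the generated rewrite builds the target op from the source operands only, so the source attributes are dropped, and it shows how a helper that returns nothing (the kind of call a NativeCodeCallVoid would make) can copy them afterward as a pure side effect.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for an operation: a name plus a bag of attributes.
// This is NOT the MLIR Operation API; it only models the idea.
struct ToyOp {
  std::string name;
  std::map<std::string, std::string> attrs;
};

// Models a result pattern that builds the target op from operands only,
// so none of the source op's attributes are carried over.
ToyOp buildTarget(const ToyOp &src) {
  (void)src;  // operands would be used here; attributes are ignored
  return ToyOp{"tf.AddV2", {}};
}

// Hypothetical void helper: returns nothing and only mutates the target.
// std::map::insert keeps any attribute the target already has.
void copyAttrs(const ToyOp &src, ToyOp &dst) {
  for (const auto &kv : src.attrs)
    dst.attrs.insert(kv);
}
```

Because the helper produces no value, it cannot stand in for a replacement result; it only runs for its effect, which is exactly the shape of call that result patterns previously rejected.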
---
mlir/tools/mlir-tblgen/RewriterGen.cpp | 27 +++++++++++++++-----------
1 file changed, 16 insertions(+), 11 deletions(-)
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 6bb79fb4b4cbe67..bc2731df1850838 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1184,17 +1184,22 @@ void PatternEmitter::emitRewriteLogic() {
DagNode resultTree = pattern.getResultPattern(i);
auto val = handleResultPattern(resultTree, offsets[i], 0);
os << "\n";
- // Resolve each symbol for all range use so that we can loop over them.
- // We need an explicit cast to `SmallVector` to capture the cases where
- // `{0}` resolves to an `Operation::result_range` as well as cases that
- // are not iterable (e.g. vector that gets wrapped in additional braces by
- // RewriterGen).
- // TODO: Revisit the need for materializing a vector.
- os << symbolInfoMap.getAllRangeUse(
- val,
- "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
- " tblgen_repl_values.push_back(v);\n}\n",
- "\n");
+ if (resultTree.isNativeCodeCall() &&
+ resultTree.getNumReturnsOfNativeCode() == 0) {
+ os << val << ";\n";
+ } else {
+ // Resolve each symbol for all range use so that we can loop over them.
+ // We need an explicit cast to `SmallVector` to capture the cases where
+ // `{0}` resolves to an `Operation::result_range` as well as cases that
+ // are not iterable (e.g. vector that gets wrapped in additional braces by
+ // RewriterGen).
+ // TODO: Revisit the need for materializing a vector.
+ os << symbolInfoMap.getAllRangeUse(
+ val,
+ "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
+ " tblgen_repl_values.push_back(v);\n}\n",
+ "\n");
+ }
}
os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
}
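The new branch in `emitRewriteLogic()` can be sketched in isolation as follows. This is a simplified model: `ResultNode` and `emitResult` are hypothetical, while the real code works on a `DagNode` and resolves symbols through `symbolInfoMap.getAllRangeUse()`, a step this sketch skips by substituting the expression directly. Presumably the `referencing unbound symbol` error came from that lookup being applied to a void call such as `foo()`, which is not a bound symbol.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical, simplified stand-in for a result DAG node. The real
// DagNode exposes isNativeCodeCall() and getNumReturnsOfNativeCode();
// here they are plain fields.
struct ResultNode {
  bool isNativeCodeCall;
  int numReturns;    // 0 for a NativeCodeCallVoid
  std::string expr;  // expanded C++ expression, e.g. "foo()"
};

// Returns the code emitted for one result pattern, mirroring the patch:
// a void native call becomes a bare statement, while any other result
// emits a loop that appends its values to tblgen_repl_values.
std::string emitResult(const ResultNode &node) {
  std::ostringstream os;
  if (node.isNativeCodeCall && node.numReturns == 0) {
    os << node.expr << ";\n";
  } else {
    os << "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ "
       << node.expr << " }) {\n"
       << "  tblgen_repl_values.push_back(v);\n}\n";
  }
  return os.str();
}
```

A void call contributes no entries to `tblgen_repl_values`, so in the example pattern only `TF_AddV2Op` supplies the values passed to `rewriter.replaceOp(op0, tblgen_repl_values)`.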