[Mlir-commits] [mlir] [mlir] Handle NativeCodeCallVoid in result patterns. (PR #65804)

Jian Cai llvmlistbot at llvm.org
Fri Sep 8 14:29:05 PDT 2023


https://github.com/jcai19 updated https://github.com/llvm/llvm-project/pull/65804:

>From b147d4df4daa64d99dbf46381c056e67a3aec41f Mon Sep 17 00:00:00 2001
From: Jian Cai <caij2003 at gmail.com>
Date: Fri, 8 Sep 2023 14:28:36 -0700
Subject: [PATCH] [mlir] Handle NativeCodeCallVoid in result patterns.

Currently, NativeCodeCallVoid is not supported in result patterns.
For example, the code below fails to build with the error message
`referencing unbound symbol`.

```
def Foo: NativeCodeCallVoid<"foo()">;

def AddToAddV2 : Pattern<
  (TF_AddOp TF_NumberTensor:$arg0, TF_NumberTensor:$arg1),
  [(TF_AddV2Op $arg0, $arg1), (Foo)]>;
```

MLIR TableGen-based pattern rewrites do not preserve the attributes of the source
op. With this change, users can manually copy source attributes to the target
op via NativeCodeCallVoid. This replaces reviews.llvm.org/D157032.
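To make the attribute-copying use case concrete, here is a minimal, self-contained sketch of the kind of void helper a NativeCodeCallVoid could call. It deliberately avoids MLIR: `FakeOp` and `copyMissingAttrs` are hypothetical names, and a `std::map` stands in for an op's attribute dictionary. The policy shown (copy only the attributes the target does not already set) is an assumption for illustration; the patch itself does not prescribe one.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for an operation and its attribute dictionary.
// Real MLIR code would use mlir::Operation::getAttrs()/setAttr() instead.
struct FakeOp {
  std::string name;
  std::map<std::string, std::string> attrs;
};

// A void helper of the kind a NativeCodeCallVoid could invoke after the
// result op has been built: copy every source attribute the target does
// not already define, leaving the target's own attributes untouched.
void copyMissingAttrs(const FakeOp &src, FakeOp &dst) {
  for (const auto &[key, value] : src.attrs)
    dst.attrs.emplace(key, value); // emplace does nothing if key exists
}
```

With this policy, a target op that already sets `T` keeps its own value, while attributes it lacks (such as `device`) are copied over from the source op. Because the helper returns nothing, it contributes no replacement values to the rewrite, which is exactly the role NativeCodeCallVoid plays in a result pattern.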
---
 mlir/tools/mlir-tblgen/RewriterGen.cpp | 27 +++++++++++++++-----------
 1 file changed, 16 insertions(+), 11 deletions(-)

diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 6bb79fb4b4cbe67..bc2731df1850838 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1184,17 +1184,22 @@ void PatternEmitter::emitRewriteLogic() {
       DagNode resultTree = pattern.getResultPattern(i);
       auto val = handleResultPattern(resultTree, offsets[i], 0);
       os << "\n";
-      // Resolve each symbol for all range use so that we can loop over them.
-      // We need an explicit cast to `SmallVector` to capture the cases where
-      // `{0}` resolves to an `Operation::result_range` as well as cases that
-      // are not iterable (e.g. vector that gets wrapped in additional braces by
-      // RewriterGen).
-      // TODO: Revisit the need for materializing a vector.
-      os << symbolInfoMap.getAllRangeUse(
-          val,
-          "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
-          "  tblgen_repl_values.push_back(v);\n}\n",
-          "\n");
+      if (resultTree.isNativeCodeCall() &&
+          resultTree.getNumReturnsOfNativeCode() == 0) {
+        os << val << ";\n";
+      } else {
+        // Resolve each symbol for all range use so that we can loop over them.
+        // We need an explicit cast to `SmallVector` to capture the cases where
+        // `{0}` resolves to an `Operation::result_range` as well as cases that
+        // are not iterable (e.g. vector that gets wrapped in additional braces by
+        // RewriterGen).
+        // TODO: Revisit the need for materializing a vector.
+        os << symbolInfoMap.getAllRangeUse(
+            val,
+            "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
+            "  tblgen_repl_values.push_back(v);\n}\n",
+            "\n");
+      }
     }
     os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
   }
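The branch added to `PatternEmitter::emitRewriteLogic()` can be modelled outside mlir-tblgen as a small string emitter. In this sketch, `ResultNode` and `emitResult` are hypothetical simplifications. The real code works on `DagNode` and expands the loop through `symbolInfoMap.getAllRangeUse()`, which may emit it more than once for a single symbol, whereas this model emits it exactly once.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical, simplified stand-in for mlir-tblgen's DagNode: only the two
// facts the patch queries are modelled.
struct ResultNode {
  bool isNativeCodeCall; // does this result come from a NativeCodeCall?
  int numReturns;        // how many values the native call produces
};

// Returns the C++ the generated rewriter would contain for one result
// pattern whose expansion is `val`.
std::string emitResult(const ResultNode &node, const std::string &val) {
  assert(!val.empty() && "model assumes a non-empty expansion");
  std::ostringstream os;
  if (node.isNativeCodeCall && node.numReturns == 0) {
    // NativeCodeCallVoid: emit the call as a plain statement, run only for
    // its side effects; nothing is added to the replacement values.
    os << val << ";\n";
  } else {
    // Value-producing result: append each value to tblgen_repl_values,
    // which later feeds rewriter.replaceOp(op0, tblgen_repl_values).
    os << "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ " << val
       << " }) {\n"
       << "  tblgen_repl_values.push_back(v);\n"
       << "}\n";
  }
  return os.str();
}
```

Before the patch, every result went through the second branch, so the expansion of a void call was passed to `getAllRangeUse()` as if it were a bound symbol name. Taking the early branch keeps the call's side effect without adding anything to `tblgen_repl_values`.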


