[Mlir-commits] [mlir] Handle NativeCodeCallVoid in result patterns. (PR #65804)

Jian Cai llvmlistbot at llvm.org
Fri Sep 8 14:12:20 PDT 2023


https://github.com/jcai19 updated https://github.com/llvm/llvm-project/pull/65804:

>From 20edac5e7e8fe380d14896047ab2c2c77bb78f2e Mon Sep 17 00:00:00 2001
From: Jian Cai <caij2003 at gmail.com>
Date: Fri, 8 Sep 2023 10:41:18 -0700
Subject: [PATCH] [mlir] Handle NativeCodeCallVoid in result patterns.

Currently NativeCodeCallVoid is not supported in result patterns.
For example, the code below fails to build with the error message
"referencing unbound symbol":

def Foo: NativeCodeCallVoid<"foo()">;

def AddToAddV2 : Pattern<
  (TF_AddOp TF_NumberTensor:$arg0, TF_NumberTensor:$arg1),
  [(TF_AddV2Op $arg0, $arg1), (Foo)]>;

MLIR TableGen-based pattern rewrites do not preserve the attributes of the source
op. With this change, users can manually copy source attributes to the target
op via NativeCodeCallVoid. This replaces reviews.llvm.org/D157032.

Differential Revision: https://reviews.llvm.org/D159489
---
 mlir/tools/mlir-tblgen/RewriterGen.cpp | 27 +++++++++++++++-----------
 1 file changed, 16 insertions(+), 11 deletions(-)

diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 6bb79fb4b4cbe67..bc2731df1850838 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1184,17 +1184,22 @@ void PatternEmitter::emitRewriteLogic() {
       DagNode resultTree = pattern.getResultPattern(i);
       auto val = handleResultPattern(resultTree, offsets[i], 0);
       os << "\n";
-      // Resolve each symbol for all range use so that we can loop over them.
-      // We need an explicit cast to `SmallVector` to capture the cases where
-      // `{0}` resolves to an `Operation::result_range` as well as cases that
-      // are not iterable (e.g. vector that gets wrapped in additional braces by
-      // RewriterGen).
-      // TODO: Revisit the need for materializing a vector.
-      os << symbolInfoMap.getAllRangeUse(
-          val,
-          "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
-          "  tblgen_repl_values.push_back(v);\n}\n",
-          "\n");
+      if (resultTree.isNativeCodeCall() &&
+          resultTree.getNumReturnsOfNativeCode() == 0) {
+        os << val << ";\n";
+      } else {
+        // Resolve each symbol for all range use so that we can loop over them.
+        // We need an explicit cast to `SmallVector` to capture the cases where
+        // `{0}` resolves to an `Operation::result_range` as well as cases that
+        // are not iterable (e.g. vector that gets wrapped in additional braces by
+        // RewriterGen).
+        // TODO: Revisit the need for materializing a vector.
+        os << symbolInfoMap.getAllRangeUse(
+            val,
+            "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
+            "  tblgen_repl_values.push_back(v);\n}\n",
+            "\n");
+      }
     }
     os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
   }


