[Mlir-commits] [mlir] [mlir] Handle NativeCodeCallVoid in result patterns. (PR #65804)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 11 10:12:05 PDT 2023


llvmbot wrote:

@llvm/pr-subscribers-mlir-core

<details>
<summary>Changes</summary>

Currently, NativeCodeCallVoid is not supported in result patterns. For example, the code below fails to build with the error message "referencing unbound symbol":

def Foo: NativeCodeCallVoid<"foo()">;

def AddToAddV2 : Pattern<
  (TF_AddOp TF_NumberTensor:$arg0, TF_NumberTensor:$arg1),
  [(TF_AddV2Op $arg0, $arg1), (Foo)]>;

MLIR's TableGen-based pattern rewrites do not preserve the attributes of the source op. With this change, users can manually copy source attributes to the target op via NativeCodeCallVoid. This replaces D157032.
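The following is a minimal, self-contained C++ sketch of the kind of side-effecting helper such a NativeCodeCallVoid might invoke. It deliberately avoids MLIR: `FakeOp`, its `attrs` map, and `copyMissingAttrs` are hypothetical stand-ins for an operation's attribute dictionary and a user-written helper. A real helper would take `mlir::Operation *` arguments instead.

```cpp
#include <map>
#include <string>

// Hypothetical stand-in for an operation: only its name and a string-keyed
// attribute dictionary are modeled (not MLIR's real attribute types).
struct FakeOp {
  std::string name;
  std::map<std::string, std::string> attrs;
};

// Hypothetical helper of the kind a NativeCodeCallVoid could call. It returns
// nothing and only mutates `dst`, copying every attribute of `src` that `dst`
// does not already define.
void copyMissingAttrs(const FakeOp &src, FakeOp &dst) {
  for (const auto &[key, value] : src.attrs)
    dst.attrs.emplace(key, value);  // emplace leaves existing keys untouched
}
```

Using `emplace` keeps any attribute the target op already set. This is only one possible policy; a helper that should overwrite would use `insert_or_assign` instead.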
--
Full diff: https://github.com/llvm/llvm-project/pull/65804.diff

1 File Affected:

- (modified) mlir/tools/mlir-tblgen/RewriterGen.cpp (+16-11) 


<pre>
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 6bb79fb4b4cbe67..bc2731df1850838 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1184,17 +1184,22 @@ void PatternEmitter::emitRewriteLogic() {
       DagNode resultTree = pattern.getResultPattern(i);
       auto val = handleResultPattern(resultTree, offsets[i], 0);
       os << "\n";
-      // Resolve each symbol for all range use so that we can loop over them.
-      // We need an explicit cast to `SmallVector` to capture the cases where
-      // `{0}` resolves to an `Operation::result_range` as well as cases that
-      // are not iterable (e.g. vector that gets wrapped in additional braces by
-      // RewriterGen).
-      // TODO: Revisit the need for materializing a vector.
-      os << symbolInfoMap.getAllRangeUse(
-          val,
-          "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
-          "  tblgen_repl_values.push_back(v);\n}\n",
-          "\n");
+      if (resultTree.isNativeCodeCall() &&
+          resultTree.getNumReturnsOfNativeCode() == 0) {
+        os << val << ";\n";
+      } else {
+        // Resolve each symbol for all range use so that we can loop over them.
+        // We need an explicit cast to `SmallVector` to capture the cases where
+        // `{0}` resolves to an `Operation::result_range` as well as cases that
+        // are not iterable (e.g. vector that gets wrapped in additional braces by
+        // RewriterGen).
+        // TODO: Revisit the need for materializing a vector.
+        os << symbolInfoMap.getAllRangeUse(
+            val,
+            "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
+            "  tblgen_repl_values.push_back(v);\n}\n",
+            "\n");
+      }
     }
     os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
   }
</pre>

</details>
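The core of the change is a single branch. A native-code call that returns no values is emitted as a plain statement. Every other result is expanded into a loop that appends its values to `tblgen_repl_values`. The sketch below is a simplified, self-contained C++ model of that branch, not the actual RewriterGen code. `ResultNode` and `emitResult` are hypothetical, and the real emitter additionally resolves range symbols through `symbolInfoMap.getAllRangeUse`, which is omitted here.

```cpp
#include <sstream>
#include <string>

// Hypothetical, simplified view of a result DAG node: whether it is a native
// code call and how many values that call returns.
struct ResultNode {
  bool isNativeCodeCall;
  int numReturnsOfNativeCode;
};

// Appends the C++ text emitted for one result pattern whose generated
// expression is `val`, mirroring the branch added in this PR.
void emitResult(std::ostringstream &os, const ResultNode &node,
                const std::string &val) {
  if (node.isNativeCodeCall && node.numReturnsOfNativeCode == 0) {
    // NativeCodeCallVoid: invoke it for its side effects only; it
    // contributes no replacement values.
    os << val << ";\n";
  } else {
    // Any other result: collect its values as replacements for the root op.
    os << "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ " << val
       << " }) {\n"
       << "  tblgen_repl_values.push_back(v);\n}\n";
  }
}
```

For the `(Foo)` result in the example above, this model emits just `foo();`. Before the change, that result was routed through the value-collecting path, which is where the "referencing unbound symbol" error surfaced.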

https://github.com/llvm/llvm-project/pull/65804

