[Mlir-commits] [mlir] [mlir]: fix an issue and refine some code (#67977) (PR #68129)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 3 10:12:11 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-bufferization

<details>
<summary>Changes</summary>

1) fix a crash in the empty-tensor-elimination pass
2) improve the linalg.copy self-copy canonicalization pattern so that it also folds copies with tensor semantics
3) add indentation when emitting the regionBuilder function

---
Full diff: https://github.com/llvm/llvm-project/pull/68129.diff


4 Files Affected:

- (modified) mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp (+2) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+6-5) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir (+11) 
- (modified) mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp (+7-7) 


``````````diff
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 77ad13dacaa9838..4c5789306ad7583 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -149,6 +149,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
           op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
       if (!replacement)
         continue;
+      if (emptyTensorOp == replacement.getDefiningOp())
+        continue;
       if (replacement.getType() != v.getType()) {
         rewriter.setInsertionPointAfterValue(replacement);
         replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 491f4a66574616e..aad7e509c4185d1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -549,12 +549,13 @@ struct EraseSelfCopyOnBuffers : OpRewritePattern<CopyOp> {
   using OpRewritePattern<CopyOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(CopyOp copyOp,
                                 PatternRewriter &rewriter) const override {
-    if (!copyOp.hasBufferSemantics())
-      return rewriter.notifyMatchFailure(copyOp,
-                                         "does not have buffer semantics");
-    if (copyOp.getInputs().front() != copyOp.getOutputs().front())
+    if (copyOp.getInputs() != copyOp.getOutputs())
       return rewriter.notifyMatchFailure(copyOp, "not a self copy");
-    rewriter.eraseOp(copyOp);
+    if (copyOp.hasBufferSemantics())
+      rewriter.eraseOp(copyOp);
+    else
+      rewriter.replaceOp(copyOp, copyOp.getInputs());
+
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 41e43047657daff..9f187f6f416d8b1 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -317,3 +317,14 @@ func.func @linalg_copy(%t: tensor<5xf32>, %f: f32) -> tensor<5xf32> {
   %1 = linalg.copy ins(%filled : tensor<5xf32>) outs(%t : tensor<5xf32>) -> tensor<5xf32>
   return %1 : tensor<5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @linalg_copy_empty(
+// CHECK: %[[ret:.*]] = tensor.empty() : tensor<26xi32>
+// CHECK-NEXT: return %[[ret]]
+func.func @linalg_copy_empty() -> tensor<26xi32> {
+  %0 = tensor.empty() : tensor<26xi32>
+  %1 = linalg.copy ins(%0 : tensor<26xi32>) outs(%0 : tensor<26xi32>) -> tensor<26xi32>
+  return %1 : tensor<26xi32>
+}
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 664167e4f6c3471..5898b0f7d69e832 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -1029,13 +1029,13 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
       // {1}: attribute name
       // {2}: default type function name
       static const char attrDef[] = R"FMT(
-{0} {1}Val = {0}::{2};
-auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
-                              return attr.getName() == "{1}"; });
-if ({1}Iter != attrs.end()) {{
-  if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue()))
-    {1}Val = attr.getValue();
-}
+  {0} {1}Val = {0}::{2};
+  auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
+                                return attr.getName() == "{1}"; });
+  if ({1}Iter != attrs.end()) {{
+    if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue()))
+      {1}Val = attr.getValue();
+  }
 )FMT";
       std::string enumName = convertOperandKindToEnumName(arg.kind);
       attrs.push_back(

``````````

</details>


https://github.com/llvm/llvm-project/pull/68129


More information about the Mlir-commits mailing list