[Mlir-commits] [mlir] [mlir]: fix an issue and refine some code (#67977) (PR #68129)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 3 10:11:05 PDT 2023
https://github.com/lipracer created https://github.com/llvm/llvm-project/pull/68129
1) Fix a crash in the empty-tensor-elimination pass.
2) Improve the linalg.copy op's canonicalization pattern.
3) Add indentation when emitting the regionBuilder function.
>From 870c023c43aa59c34db2a06ebe165489c8975eff Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Wed, 4 Oct 2023 00:09:25 +0800
Subject: [PATCH] [mlir]: fix an issue and refine some code (#67977)
1) fix empty-tensor-elimination pass crash
2) improve linalg.copy op's canonicalization pattern
3) add indentation when emitting regionBuilder func
---
.../Transforms/EmptyTensorElimination.cpp | 2 ++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 11 ++++++-----
...ne-shot-bufferize-empty-tensor-elimination.mlir | 11 +++++++++++
.../mlir-linalg-ods-yaml-gen.cpp | 14 +++++++-------
4 files changed, 26 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 77ad13dacaa9838..4c5789306ad7583 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -149,6 +149,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
if (!replacement)
continue;
+ if (emptyTensorOp == replacement.getDefiningOp())
+ continue;
if (replacement.getType() != v.getType()) {
rewriter.setInsertionPointAfterValue(replacement);
replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 491f4a66574616e..aad7e509c4185d1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -549,12 +549,13 @@ struct EraseSelfCopyOnBuffers : OpRewritePattern<CopyOp> {
using OpRewritePattern<CopyOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CopyOp copyOp,
PatternRewriter &rewriter) const override {
- if (!copyOp.hasBufferSemantics())
- return rewriter.notifyMatchFailure(copyOp,
- "does not have buffer semantics");
- if (copyOp.getInputs().front() != copyOp.getOutputs().front())
+ if (copyOp.getInputs() != copyOp.getOutputs())
return rewriter.notifyMatchFailure(copyOp, "not a self copy");
- rewriter.eraseOp(copyOp);
+ if (copyOp.hasBufferSemantics())
+ rewriter.eraseOp(copyOp);
+ else
+ rewriter.replaceOp(copyOp, copyOp.getInputs());
+
return success();
}
};
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 41e43047657daff..9f187f6f416d8b1 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -317,3 +317,14 @@ func.func @linalg_copy(%t: tensor<5xf32>, %f: f32) -> tensor<5xf32> {
%1 = linalg.copy ins(%filled : tensor<5xf32>) outs(%t : tensor<5xf32>) -> tensor<5xf32>
return %1 : tensor<5xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @linalg_copy_empty(
+// CHECK: %[[ret:.*]] = tensor.empty() : tensor<26xi32>
+// CHECK-NEXT: return %[[ret]]
+func.func @linalg_copy_empty() -> tensor<26xi32> {
+ %0 = tensor.empty() : tensor<26xi32>
+ %1 = linalg.copy ins(%0 : tensor<26xi32>) outs(%0 : tensor<26xi32>) -> tensor<26xi32>
+ return %1 : tensor<26xi32>
+}
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 664167e4f6c3471..5898b0f7d69e832 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -1029,13 +1029,13 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
// {1}: attribute name
// {2}: default type function name
static const char attrDef[] = R"FMT(
-{0} {1}Val = {0}::{2};
-auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
- return attr.getName() == "{1}"; });
-if ({1}Iter != attrs.end()) {{
- if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue()))
- {1}Val = attr.getValue();
-}
+ {0} {1}Val = {0}::{2};
+ auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
+ return attr.getName() == "{1}"; });
+ if ({1}Iter != attrs.end()) {{
+ if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue()))
+ {1}Val = attr.getValue();
+ }
)FMT";
std::string enumName = convertOperandKindToEnumName(arg.kind);
attrs.push_back(
More information about the Mlir-commits
mailing list