[Mlir-commits] [mlir] 50b8a3c - [mlir][linalg] Refactor `EraseIdentityGenericOp` to be reused by other `LinalgOp`s (#80466)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 12 10:53:21 PST 2024
Author: srcarroll
Date: 2024-02-12T12:53:17-06:00
New Revision: 50b8a3c01c6a20327dc3f65d2ee175ce73cdcc73
URL: https://github.com/llvm/llvm-project/commit/50b8a3c01c6a20327dc3f65d2ee175ce73cdcc73
DIFF: https://github.com/llvm/llvm-project/commit/50b8a3c01c6a20327dc3f65d2ee175ce73cdcc73.diff
LOG: [mlir][linalg] Refactor `EraseIdentityGenericOp` to be reused by other `LinalgOp`s (#80466)
This refactored pattern rewrite is intended to be reused by any
`LinalgOp`'s canonicalization pattern for removing identity ops.
Additionally, this canonicalization has been applied to `BroadCastOp`.
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 751edd02288301..272bc3116c5fdc 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -531,6 +531,7 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e86b9762d8581f..a0f02f6a7f259d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1087,24 +1087,25 @@ LogicalResult GenericOp::verify() { return success(); }
namespace {
-/// Remove generic operations (on tensors) that are just copying
+/// Remove any linalg operation (on tensors) that are just copying
/// the values from inputs to the results. Requirements are
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
/// the arguments corresponding to the operands.
-struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
- using OpRewritePattern<GenericOp>::OpRewritePattern;
+template <typename OpTy>
+struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(GenericOp genericOp,
+ LogicalResult matchAndRewrite(OpTy linalgOp,
PatternRewriter &rewriter) const override {
// Check all indexing maps are identity.
- if (llvm::any_of(genericOp.getIndexingMapsArray(),
+ if (llvm::any_of(linalgOp.getIndexingMapsArray(),
[](AffineMap map) { return !map.isIdentity(); }))
return failure();
// Check that the body of the linalg operation is just a linalg.yield
// operation.
- Block &body = genericOp.getRegion().front();
+ Block &body = linalgOp->getRegion(0).front();
if (!llvm::hasSingleElement(body))
return failure();
auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
@@ -1112,18 +1113,18 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
return failure();
// In the buffer case, we need to check exact buffer equality.
- if (genericOp.hasPureBufferSemantics()) {
- if (genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1 &&
- genericOp.getDpsInputOperand(0)->get() ==
- genericOp.getDpsInitOperand(0)->get()) {
- rewriter.eraseOp(genericOp);
+ if (linalgOp.hasPureBufferSemantics()) {
+ if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
+ linalgOp.getDpsInputOperand(0)->get() ==
+ linalgOp.getDpsInitOperand(0)->get()) {
+ rewriter.eraseOp(linalgOp);
return success();
}
return failure();
}
// Mixed semantics is not supported yet.
- if (!genericOp.hasPureTensorSemantics())
+ if (!linalgOp.hasPureTensorSemantics())
return failure();
// Get the argument number of the returned values. That is the operand
@@ -1134,8 +1135,8 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
if (!yieldArg || yieldArg.getOwner() != &body)
return failure();
unsigned argumentNumber = yieldArg.getArgNumber();
- Value returnedArg = genericOp->getOperand(argumentNumber);
- Type resultType = genericOp->getResult(yieldVal.index()).getType();
+ Value returnedArg = linalgOp->getOperand(argumentNumber);
+ Type resultType = linalgOp->getResult(yieldVal.index()).getType();
// The input can have a different type than the result, e.g. a dynamic
// input dimension can be turned into a static output dimension.
Type returnType = returnedArg.getType();
@@ -1145,21 +1146,21 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
if (sparse_tensor::getSparseTensorEncoding(returnType) ||
sparse_tensor::getSparseTensorEncoding(resultType))
returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
- genericOp.getLoc(), resultType, returnedArg);
+ linalgOp.getLoc(), resultType, returnedArg);
else {
if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
resultType))
return failure();
returnedArg = rewriter.create<tensor::CastOp>(
- genericOp.getLoc(), resultType, returnedArg);
+ linalgOp.getLoc(), resultType, returnedArg);
}
}
returnedArgs.push_back(returnedArg);
}
- if (returnedArgs.size() != genericOp->getNumResults())
+ if (returnedArgs.size() != linalgOp->getNumResults())
return failure();
- rewriter.replaceOp(genericOp, returnedArgs);
+ rewriter.replaceOp(linalgOp, returnedArgs);
return success();
}
};
@@ -1168,7 +1169,7 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<EraseIdentityGenericOp>(context);
+ results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}
LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
@@ -1907,6 +1908,11 @@ void BroadcastOp::getEffects(
getDpsInits());
}
+void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
+}
+
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 052dc367ca6779..721f35162ef867 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1017,3 +1017,15 @@ func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tenso
%copy = linalg.copy ins(%arg1 : tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
return %copy : tensor<?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @broadcast_same_shape(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<2x3xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x3xf32>)
+// CHECK-NOT: linalg.broadcast
+// CHECK: return %[[ARG0]] : tensor<2x3xf32>
+func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>) -> tensor<2x3xf32> {
+ %0 = linalg.broadcast ins(%input: tensor<2x3xf32>) outs(%init: tensor<2x3xf32>) dimensions = []
+ return %0 : tensor<2x3xf32>
+}
More information about the Mlir-commits mailing list