[Mlir-commits] [mlir] 498f5ae - [mlir][linalg] Remove duplicate tensor.pad lowering pattern
Matthias Springer
llvmlistbot at llvm.org
Thu Jun 22 02:37:54 PDT 2023
Author: Matthias Springer
Date: 2023-06-22T11:35:27+02:00
New Revision: 498f5ae94d3b32724f6f10fb2b47616c2e04540e
URL: https://github.com/llvm/llvm-project/commit/498f5ae94d3b32724f6f10fb2b47616c2e04540e
DIFF: https://github.com/llvm/llvm-project/commit/498f5ae94d3b32724f6f10fb2b47616c2e04540e.diff
LOG: [mlir][linalg] Remove duplicate tensor.pad lowering pattern
There is already a transform that lowers tensor.pad to tensor.empty + linalg.fill + tensor.insert_slice: `transform.structured.rewrite_in_destination_passing_style`. This change deletes the duplicated lowering, `PadOpTransformationPattern`, which rewrote tensor.pad to tensor.empty + linalg.fill + linalg.generic.
Differential Revision: https://reviews.llvm.org/D153429
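For reference (not part of this commit), the surviving lowering can be driven from the transform dialect roughly as sketched below; the sequence is only an illustrative sketch, and the exact handle types and op spellings are an assumption based on the transform dialect as of this revision:

  transform.sequence failures(propagate) {
  ^bb0(%arg0: !transform.any_op):
    // Match all tensor.pad ops in the payload IR.
    %pad = transform.structured.match ops{["tensor.pad"]} in %arg0
        : (!transform.any_op) -> !transform.any_op
    // Rewrite each matched pad in destination-passing style; per the log
    // above, this produces tensor.empty + linalg.fill + tensor.insert_slice.
    %dps = transform.structured.rewrite_in_destination_passing_style %pad
        : (!transform.any_op) -> !transform.any_op
  }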
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
mlir/test/Dialect/Linalg/lower-pad-tensor.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 22137ed44beea..9591f0b9b3ef2 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1151,15 +1151,6 @@ struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
PatternRewriter &rewriter) const override;
};
-/// tensor::PadOp is not canonicalized away yet, so we provide a
-/// transformation to `linalg.generic`.
-struct PadOpTransformationPattern : public OpRewritePattern<tensor::PadOp> {
- using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tensor::PadOp padOp,
- PatternRewriter &rewriter) const override;
-};
-
using OptimizeCopyFn =
std::function<LogicalResult(RewriterBase &, tensor::PadOp, Value)>;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 5fd9228233532..9044fea4509ac 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -139,12 +139,6 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
opOperand->get(), paddingValue, nofold);
}
-static SmallVector<utils::IteratorType>
-getNParallelLoopsAttrs(unsigned nParallelLoops) {
- return SmallVector<utils::IteratorType>(nParallelLoops,
- utils::IteratorType::parallel);
-}
-
//===----------------------------------------------------------------------===//
// Transformations exposed as functional-style API calls.
//===----------------------------------------------------------------------===//
@@ -1028,71 +1022,6 @@ LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
return vectorizeCopy(rewriter, copyOp);
}
-///
-/// Pattern to rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp (to
-/// initialize with pad_val) and GenericOp (to copy contents).
-///
-LogicalResult
-PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
- PatternRewriter &rewriter) const {
-
- auto inputShapedType = cast<ShapedType>(padOp.getSource().getType());
- auto resultShapedType = cast<ShapedType>(padOp.getResult().getType());
-
- // Bail on non-static shapes.
- if (!inputShapedType.hasStaticShape())
- return failure();
- if (!resultShapedType.hasStaticShape())
- return failure();
-
- // Only support padding with a constant for now, i.e. either:
- // 1. A BBarg from a different block.
- // 2. A value defined outside of the current block.
- Block &block = padOp.getRegion().front();
- auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
- Value padValue = yieldOp.getValue();
- Operation *definingOp = padValue.getDefiningOp();
- if (definingOp && definingOp->getBlock() == &block)
- return failure();
- if (!definingOp && cast<BlockArgument>(padValue).getOwner() == &block)
- return failure();
-
- // Create tensor with the padded shape
- Location loc = padOp.getLoc();
- SmallVector<Value> indices(resultShapedType.getRank(),
- rewriter.create<arith::ConstantIndexOp>(loc, 0));
- Value emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, resultShapedType.getShape(), resultShapedType.getElementType());
-
- // Initialize tensor with the pad value
- Value tmpTensor = rewriter
- .create<linalg::FillOp>(loc, ValueRange{padValue},
- ValueRange{emptyTensor})
- .result();
-
- // Copy original contents into new tensor
- // Uses linalg.generic, but could be done with tensor.insert_slice
- SmallVector<AffineExpr, 4> outputExprs;
- for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
- outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
- padOp.getStaticLow()[i]);
- }
-
- SmallVector<AffineMap, 2> transferMaps = {
- rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
- AffineMap::get(resultShapedType.getRank(),
- /*symbolCount=*/0, outputExprs, rewriter.getContext())};
-
- rewriter.replaceOpWithNewOp<linalg::GenericOp>(
- padOp, resultShapedType, padOp.getSource(), tmpTensor, transferMaps,
- getNParallelLoopsAttrs(resultShapedType.getRank()),
- [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
- });
-
- return success();
-}
-
/// Filling `dest` using FillOp constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value GeneralizePadOpPattern::createFillOrGenerateOp(
diff --git a/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir b/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir
deleted file mode 100644
index 6df26b9afaa20..0000000000000
--- a/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir
+++ /dev/null
@@ -1,63 +0,0 @@
-// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-transform-pad-tensor" %s | FileCheck --check-prefix=CHECK %s
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 + 1, d1 + 1, d2 + 1, d3 + 2)>
-// CHECK-LABEL: func @pad_tensor_with_memrefs
-func.func @pad_tensor_with_memrefs(%arg0: memref<1x28x28x1xf32>) -> memref<2x31x31x3xf32> {
- %cst = arith.constant 0.000000e+00 : f32
- %0 = bufferization.to_tensor %arg0 : memref<1x28x28x1xf32>
- %1 = tensor.pad %0 low[1, 1, 1, 2] high[0, 2, 2, 0] {
- ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
- tensor.yield %cst : f32
- } : tensor<1x28x28x1xf32> to tensor<2x31x31x3xf32>
- %2 = bufferization.to_memref %1 : memref<2x31x31x3xf32>
- return %2 : memref<2x31x31x3xf32>
-}
-
-// CHECK: linalg.fill
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-
-// -----
-
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0 + 1, d1 + 2, d2 + 2)>
-// CHECK-LABEL: func @pad_tensor_no_memrefs
-func.func @pad_tensor_no_memrefs(%arg0: tensor<1x28x28xf32>) -> tensor<2x32x32xf32> {
- %cst = arith.constant 0.000000e+00 : f32
- %0 = tensor.pad %arg0 low[1, 2, 2] high[0, 2, 2] {
- ^bb0(%arg1: index, %arg2: index, %arg3: index):
- tensor.yield %cst : f32
- } : tensor<1x28x28xf32> to tensor<2x32x32xf32>
- return %0 : tensor<2x32x32xf32>
-}
-
-// CHECK: linalg.fill
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP2]], #[[$MAP3]]]
-
-// -----
-
-// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 + 2, d2 + 2, d3)>
-// CHECK-LABEL: func @pad_tensor_detailed
-func.func @pad_tensor_detailed(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
- %cst = arith.constant 0.000000e+00 : f32
- %0 = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
- ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
- tensor.yield %cst : f32
- } : tensor<1x28x28x1xf32> to tensor<1x32x32x1xf32>
- return %0 : tensor<1x32x32x1xf32>
-}
-
-// CHECK: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32>
-// CHECK: %[[CTE:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[TMP:.+]] = tensor.empty() : tensor<1x32x32x1xf32>
-// CHECK: %[[R1c:.+]] = linalg.fill
-// CHECK: %[[R2c:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP4]], #[[$MAP5]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
-// CHECK: ins(%{{.*}} : tensor<1x28x28x1xf32>) outs(%{{.*}} : tensor<1x32x32x1xf32>)
-// CHECK: ^bb0(%[[VAL:.+]]: f32, %{{.*}}: f32)
-// CHECK: linalg.yield %[[VAL]] : f32
-// CHECK: return %[[R2c:.+]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index c1d01acf3ad02..4892fa2f99a7c 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -70,10 +70,6 @@ struct TestLinalgTransforms
llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
"in vector.contract form"),
llvm::cl::init(false)};
- Option<bool> testTransformPadTensor{
- *this, "test-transform-pad-tensor",
- llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
- llvm::cl::init(false)};
Option<bool> testGeneralizePadTensor{
*this, "test-generalize-pad-tensor",
llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
@@ -163,12 +159,6 @@ static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
-static void applyPadTensorToGenericPatterns(func::FuncOp funcOp) {
- RewritePatternSet patterns(funcOp.getContext());
- patterns.add<PadOpTransformationPattern>(funcOp.getContext());
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-}
-
static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
@@ -225,8 +215,6 @@ void TestLinalgTransforms::runOnOperation() {
return applyVectorTransferForwardingPatterns(getOperation());
if (testGenericToVectorPattern)
return applyLinalgToVectorPatterns(getOperation());
- if (testTransformPadTensor)
- return applyPadTensorToGenericPatterns(getOperation());
if (testGeneralizePadTensor)
return applyGeneralizePadTensorPatterns(getOperation());
if (testGeneralizeTensorPackOp)