[Mlir-commits] [mlir] 0804a88 - [mlir][linalg] Transform PadTensorOp into InitOp, FillOp, GenericOp
Matthias Springer
llvmlistbot at llvm.org
Thu Jun 3 06:13:26 PDT 2021
Author: Nicolas Agostini
Date: 2021-06-03T22:09:09+09:00
New Revision: 0804a88e48ac23bcf73d6b985ef755559419ee11
URL: https://github.com/llvm/llvm-project/commit/0804a88e48ac23bcf73d6b985ef755559419ee11
DIFF: https://github.com/llvm/llvm-project/commit/0804a88e48ac23bcf73d6b985ef755559419ee11.diff
LOG: [mlir][linalg] Transform PadTensorOp into InitOp, FillOp, GenericOp
Introduces a rewrite pattern, exercised through a new option of the Linalg test-transform pass, that rewrites PadTensorOps with static shapes as a sequence of:
```
linalg.init_tensor // to create output
linalg.fill // to initialize with padding value
linalg.generic // to copy the original contents to the padded tensor
```
The pass can be triggered with:
- `--test-linalg-transform-patterns="test-transform-pad-tensor"`
Differential Revision: https://reviews.llvm.org/D102804
Added:
mlir/test/Dialect/Linalg/lower-pad-tensor.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d6cb5cb0e39b..25f4b3627d6e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -871,6 +871,15 @@ void populateLinalgDistributeTiledLoopPattern(
// Op-specific patterns.
//===----------------------------------------------------------------------===//
+/// Lowers a `linalg.pad_tensor` op into `linalg.init_tensor` (to create the
+/// padded result), `linalg.fill` (to write the padding value) and
+/// `linalg.generic` (to copy the source contents into place). This exists
+/// because PadTensorOp is not canonicalized away yet. The rewrite only fires
+/// for static shapes and a padding value defined outside the pad region.
+struct PadTensorOpTransformationPattern : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
/// PadTensorOp does not implement the LinalgStructuredOpInterface `LinalgOp`,
/// it needs a specific pattern to vectorize.
struct PadTensorOpVectorizationPattern : public OpRewritePattern<PadTensorOp> {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 15420cc302da..5b42aa394163 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -637,3 +637,68 @@ LogicalResult AffineMinRangeCanonicalizationPattern::matchAndRewrite(
return failure();
}
+
+/// Builds the iterator-type list for a fully parallel linalg op: one copy of
+/// the parallel iterator type name per loop.
+static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
+  SmallVector<StringRef> iteratorTypes;
+  iteratorTypes.append(nParallelLoops, getParallelIteratorTypeName());
+  return iteratorTypes;
+}
+
+/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize
+/// with the pad value) and GenericOp (to copy the source contents into the
+/// padded tensor).
+///
+/// The pattern only matches when:
+///   * the source and result tensor shapes are static,
+///   * every low padding amount is static (it becomes a constant offset in the
+///     GenericOp output indexing map), and
+///   * the padding value is defined outside the pad region, i.e. it does not
+///     depend on the region's index arguments.
+LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
+    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
+  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
+  auto resultShapedType = padOp.result().getType().cast<ShapedType>();
+
+  // Bail on non-static shapes.
+  if (!inputShapedType.hasStaticShape())
+    return failure();
+  if (!resultShapedType.hasStaticShape())
+    return failure();
+
+  // Bail on dynamic low padding. Static source/result shapes do not by
+  // themselves guarantee static padding amounts, and for a dynamic entry
+  // `static_low` holds the dynamic-size sentinel rather than the real offset,
+  // which must not be baked into the output indexing map built below.
+  if (!padOp.low().empty())
+    return failure();
+
+  // Only support padding with a constant for now, i.e. either:
+  //   1. A BBarg from a different block.
+  //   2. A value defined outside of the current block.
+  Block &block = padOp.region().front();
+  auto yieldOp = cast<YieldOp>(block.getTerminator());
+  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
+  Value padValue = yieldOp.values().front();
+  Operation *definingOp = padValue.getDefiningOp();
+  if (definingOp && definingOp->getBlock() == &block)
+    return failure();
+  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
+    return failure();
+
+  // Create tensor with the padded shape.
+  Location loc = padOp.getLoc();
+  Value initTensor = rewriter.create<InitTensorOp>(
+      loc, resultShapedType.getShape(), resultShapedType.getElementType());
+
+  // Initialize tensor with the pad value.
+  Value tmpTensor =
+      rewriter.create<linalg::FillOp>(loc, initTensor, padValue).result();
+
+  // Copy original contents into new tensor: the source element at
+  // (d0, ..., dn) is written to (d0 + low[0], ..., dn + low[n]).
+  // Uses linalg.generic, but could be done with std.subtensor_insert.
+  SmallVector<AffineExpr, 4> outputExprs;
+  for (unsigned i = 0, e = resultShapedType.getRank(); i < e; ++i) {
+    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
+                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
+  }
+
+  SmallVector<AffineMap, 2> transferMaps = {
+      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
+      AffineMap::get(resultShapedType.getRank(),
+                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};
+
+  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
+      getNParallelLoopsAttrs(resultShapedType.getRank()),
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        // Forward the source element; the output map places it.
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+      });
+
+  return success();
+}
diff --git a/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir b/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir
new file mode 100644
index 000000000000..98eb27fbe00d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-transform-pad-tensor" %s | FileCheck --check-prefix=CHECK %s
+
+// Static rank-4 pad of a tensor obtained from a memref (tensor_load /
+// buffer_cast round trip). The low padding [1, 1, 1, 2] must show up as the
+// constant offsets of the generic op's output indexing map.
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 + 1, d1 + 1, d2 + 1, d3 + 2)>
+// CHECK-LABEL: func @pad_tensor_with_memrefs
+func @pad_tensor_with_memrefs(%arg0: memref<1x28x28x1xf32>) -> memref<2x31x31x3xf32> {
+ %cst = constant 0.000000e+00 : f32
+ %0 = memref.tensor_load %arg0 : memref<1x28x28x1xf32>
+ %1 = linalg.pad_tensor %0 low[1, 1, 1, 2] high[0, 2, 2, 0] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index): // no predecessors
+ linalg.yield %cst : f32
+ } : tensor<1x28x28x1xf32> to tensor<2x31x31x3xf32>
+ %2 = memref.buffer_cast %1 : memref<2x31x31x3xf32>
+ return %2 : memref<2x31x31x3xf32>
+}
+
+// CHECK: linalg.fill
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+
+// -----
+
+// Static rank-3 pad of a function-argument tensor. The low padding [1, 2, 2]
+// must show up as the constant offsets of the generic op's output map.
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0 + 1, d1 + 2, d2 + 2)>
+// CHECK-LABEL: func @pad_tensor_no_memrefs
+func @pad_tensor_no_memrefs(%arg0: tensor<1x28x28xf32>) -> tensor<2x32x32xf32> {
+ %cst = constant 0.000000e+00 : f32
+ %0 = linalg.pad_tensor %arg0 low[1, 2, 2] high[0, 2, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index): // no predecessors
+ linalg.yield %cst : f32
+ } : tensor<1x28x28xf32> to tensor<2x32x32xf32>
+ return %0 : tensor<2x32x32xf32>
+}
+
+// CHECK: linalg.fill
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP2]], #[[$MAP3]]]
+
+// -----
+
+// Full lowering of a rank-4 pad whose outermost and innermost dimensions are
+// not padded: init_tensor with the padded shape, fill with the pad value, and
+// a generic that copies the function argument into the filled tensor and
+// whose result is the returned value.
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 + 2, d2 + 2, d3)>
+// CHECK-LABEL: func @pad_tensor_detailed
+func @pad_tensor_detailed(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<1x28x28x1xf32> to tensor<1x32x32x1xf32>
+  return %0 : tensor<1x32x32x1xf32>
+}
+
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32>
+// CHECK: %[[CTE:.+]] = constant 0.000000e+00 : f32
+// CHECK: %[[TMP:.+]] = linalg.init_tensor [1, 32, 32, 1] : tensor<1x32x32x1xf32>
+// CHECK: %[[R1c:.+]] = linalg.fill
+// CHECK: %[[R2c:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP4]], #[[$MAP5]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK: ins(%[[ARG0]] : tensor<1x28x28x1xf32>) outs(%[[R1c]] : tensor<1x32x32x1xf32>)
+// CHECK: ^bb0(%[[VAL:[a-zA-Z0-9_]+]]: f32, %{{[a-zA-Z0-9_]+}}: f32)
+// CHECK: linalg.yield %[[VAL]] : f32
+// CHECK: return %[[R2c]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index c8ef2d419065..181e1d93b250 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -87,6 +87,10 @@ struct TestLinalgTransforms
Option<int> testHoistPadding{*this, "test-hoist-padding",
llvm::cl::desc("Test hoist padding"),
llvm::cl::init(0)};
+ // When set, runOnFunction dispatches to applyPadTensorToGenericPatterns,
+ // which greedily applies PadTensorOpTransformationPattern to the function.
+ Option<bool> testTransformPadTensor{
+ *this, "test-transform-pad-tensor",
+ llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
+ llvm::cl::init(false)};
ListOption<int64_t> tileSizesForPadding{
*this, "tile-sizes-for-padding",
llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
@@ -508,6 +512,12 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
+/// Greedily applies PadTensorOpTransformationPattern to every op in `funcOp`.
+/// The result of the greedy driver is discarded.
+static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
+  auto *ctx = funcOp.getContext();
+  RewritePatternSet patterns(ctx);
+  patterns.add<PadTensorOpTransformationPattern>(ctx);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
RewritePatternSet foldPattern(funcOp.getContext());
foldPattern.add<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
@@ -583,6 +593,8 @@ void TestLinalgTransforms::runOnFunction() {
return applyVectorTransferForwardingPatterns(getFunction());
if (testGenericToVectorPattern)
return applyLinalgToVectorPatterns(getFunction());
+ if (testTransformPadTensor)
+ return applyPadTensorToGenericPatterns(getFunction());
if (testAffineMinSCFCanonicalizationPatterns)
return applyAffineMinSCFCanonicalizationPatterns(getFunction());
if (testTileAndPadPattern)
More information about the Mlir-commits
mailing list