[Mlir-commits] [mlir] 0804a88 - [mlir][linalg] Transform PadTensorOp into InitOp, FillOp, GenericOp

Matthias Springer llvmlistbot at llvm.org
Thu Jun 3 06:13:26 PDT 2021


Author: Nicolas Agostini
Date: 2021-06-03T22:09:09+09:00
New Revision: 0804a88e48ac23bcf73d6b985ef755559419ee11

URL: https://github.com/llvm/llvm-project/commit/0804a88e48ac23bcf73d6b985ef755559419ee11
DIFF: https://github.com/llvm/llvm-project/commit/0804a88e48ac23bcf73d6b985ef755559419ee11.diff

LOG: [mlir][linalg] Transform PadTensorOp into InitOp, FillOp, GenericOp

Introduces a rewrite pattern, exercised through a new test-pass option, that rewrites PadTensorOps with static shapes as a sequence of:

```
linalg.init_tensor // to create output
linalg.fill        // to initialize with padding value
linalg.generic     // to copy the original contents to the padded tensor
```

The pass can be triggered with:

- `--test-linalg-transform-patterns="test-transform-pad-tensor"`

Differential Revision: https://reviews.llvm.org/D102804

Added: 
    mlir/test/Dialect/Linalg/lower-pad-tensor.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d6cb5cb0e39b..25f4b3627d6e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -871,6 +871,15 @@ void populateLinalgDistributeTiledLoopPattern(
 // Op-specific patterns.
 //===----------------------------------------------------------------------===//
 
+/// Since PadTensorOp is not canonicalized away yet, rewrite static-shaped ones
+/// as linalg.init_tensor + linalg.fill + a copying `linalg.generic`.
+struct PadTensorOpTransformationPattern : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 /// PadTensorOp does not implement the LinalgStructuredOpInterface `LinalgOp`,
 /// it needs a specific pattern to vectorize.
 struct PadTensorOpVectorizationPattern : public OpRewritePattern<PadTensorOp> {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 15420cc302da..5b42aa394163 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -637,3 +637,81 @@ LogicalResult AffineMinRangeCanonicalizationPattern::matchAndRewrite(
 
   return failure();
 }
+
+/// Returns `nParallelLoops` copies of the "parallel" iterator type name.
+static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
+  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
+}
+
+/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to initialize
+/// with the pad value) and GenericOp (to copy the source contents at an offset
+/// of `static_low`).
+///
+/// Only matches when the source and result shapes are static, all low padding
+/// amounts are static, and the yielded pad value is not defined in the pad
+/// region's own block. All checks run before any IR is created, so a failed
+/// match leaves the IR untouched.
+LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
+    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
+  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
+  auto resultShapedType = padOp.result().getType().cast<ShapedType>();
+
+  // Bail on non-static shapes.
+  if (!inputShapedType.hasStaticShape())
+    return failure();
+  if (!resultShapedType.hasStaticShape())
+    return failure();
+
+  // Bail on dynamic low padding: the copy offsets below are read from
+  // `static_low`, which only holds real offsets when no low padding amount is
+  // passed as an SSA value.
+  if (!padOp.low().empty())
+    return failure();
+
+  // Only support padding with a constant for now, i.e. either:
+  //   1. A BBarg from a different block.
+  //   2. A value defined outside of the current block.
+  Block &block = padOp.region().front();
+  auto yieldOp = cast<YieldOp>(block.getTerminator());
+  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
+  Value padValue = yieldOp.values().front();
+  Operation *definingOp = padValue.getDefiningOp();
+  if (definingOp && definingOp->getBlock() == &block)
+    return failure();
+  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
+    return failure();
+
+  // Create tensor with the padded shape.
+  Location loc = padOp.getLoc();
+  Value initTensor = rewriter.create<InitTensorOp>(
+      loc, resultShapedType.getShape(), resultShapedType.getElementType());
+
+  // Initialize tensor with the pad value.
+  Value tmpTensor =
+      rewriter.create<linalg::FillOp>(loc, initTensor, padValue).result();
+
+  // Copy original contents into new tensor: the input map is the identity and
+  // the output map shifts dim k by low[k], so source element (i0, ..., iN)
+  // lands at (i0 + low0, ..., iN + lowN) in the result.
+  // Uses linalg.generic, but could be done with std.subtensor_insert.
+  SmallVector<AffineExpr, 4> outputExprs;
+  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
+    outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
+                          padOp.static_low()[i].cast<IntegerAttr>().getInt());
+  }
+
+  SmallVector<AffineMap, 2> transferMaps = {
+      rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
+      AffineMap::get(resultShapedType.getRank(),
+                     /*symbolCount=*/0, outputExprs, rewriter.getContext())};
+
+  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+      padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
+      getNParallelLoopsAttrs(resultShapedType.getRank()),
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        // Forward the source element unchanged into the output.
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
+      });
+
+  return success();
+}

diff  --git a/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir b/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir
new file mode 100644
index 000000000000..98eb27fbe00d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/lower-pad-tensor.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-transform-pad-tensor"  %s | FileCheck --check-prefix=CHECK %s
+
+// CHECK-DAG:   #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG:   #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 + 1, d1 + 1, d2 + 1, d3 + 2)>
+// CHECK-LABEL: func @pad_tensor_with_memrefs
+func @pad_tensor_with_memrefs(%arg0: memref<1x28x28x1xf32>) -> memref<2x31x31x3xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %0 = memref.tensor_load %arg0 : memref<1x28x28x1xf32>
+  %1 = linalg.pad_tensor %0 low[1, 1, 1, 2] high[0, 2, 2, 0]  {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<1x28x28x1xf32> to tensor<2x31x31x3xf32>
+  %2 = memref.buffer_cast %1 : memref<2x31x31x3xf32>
+  return %2 : memref<2x31x31x3xf32>
+}
+
+// CHECK:       linalg.init_tensor [2, 31, 31, 3] : tensor<2x31x31x3xf32>
+// CHECK:       linalg.fill
+// CHECK:       linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+
+// -----
+
+// CHECK-DAG:   #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG:   #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0 + 1, d1 + 2, d2 + 2)>
+// CHECK-LABEL: func @pad_tensor_no_memrefs
+func @pad_tensor_no_memrefs(%arg0: tensor<1x28x28xf32>) -> tensor<2x32x32xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.pad_tensor %arg0 low[1, 2, 2] high[0, 2, 2]  {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<1x28x28xf32> to tensor<2x32x32xf32>
+  return %0 : tensor<2x32x32xf32>
+}
+
+// CHECK:       linalg.init_tensor [2, 32, 32] : tensor<2x32x32xf32>
+// CHECK:       linalg.fill
+// CHECK:       linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP2]], #[[$MAP3]]]
+
+// -----
+
+// CHECK-DAG:   #[[$MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG:   #[[$MAP5:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 + 2, d2 + 2, d3)>
+// CHECK-LABEL: func @pad_tensor_detailed
+func @pad_tensor_detailed(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0]  {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<1x28x28x1xf32> to tensor<1x32x32x1xf32>
+  return %0 : tensor<1x32x32x1xf32>
+}
+
+// CHECK:      %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32>
+// CHECK:      %[[CTE:.+]] = constant 0.000000e+00 : f32
+// CHECK:      %[[TMP:.+]] = linalg.init_tensor [1, 32, 32, 1] : tensor<1x32x32x1xf32>
+// CHECK:      %[[R1c:.+]] = linalg.fill
+// CHECK:      %[[R2c:.+]] = linalg.generic
+// CHECK-SAME:   indexing_maps = [#[[$MAP4]], #[[$MAP5]]]
+// CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK:        ins(%[[ARG0]] : tensor<1x28x28x1xf32>) outs(%[[R1c]] : tensor<1x32x32x1xf32>)
+// CHECK:      ^bb0(%[[VAL:[a-zA-Z0-9_]+]]: f32, %{{[a-zA-Z0-9_]+}}: f32)
+// CHECK:        linalg.yield %[[VAL]] : f32
+// CHECK:      return %[[R2c]]

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index c8ef2d419065..181e1d93b250 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -87,6 +87,10 @@ struct TestLinalgTransforms
   Option<int> testHoistPadding{*this, "test-hoist-padding",
                                llvm::cl::desc("Test hoist padding"),
                                llvm::cl::init(0)};
+  // Rewrites linalg.pad_tensor ops with PadTensorOpTransformationPattern.
+  Option<bool> testTransformPadTensor{
+      *this, "test-transform-pad-tensor", llvm::cl::init(false),
+      llvm::cl::desc("Test transform pad tensor by copying with generic ops")};
   ListOption<int64_t> tileSizesForPadding{
       *this, "tile-sizes-for-padding",
       llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
@@ -508,6 +512,12 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
+  RewritePatternSet padPatterns(funcOp.getContext());
+  padPatterns.add<PadTensorOpTransformationPattern>(padPatterns.getContext());
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(padPatterns));
+}
+
 static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
   RewritePatternSet foldPattern(funcOp.getContext());
   foldPattern.add<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
@@ -583,6 +593,8 @@ void TestLinalgTransforms::runOnFunction() {
     return applyVectorTransferForwardingPatterns(getFunction());
   if (testGenericToVectorPattern)
     return applyLinalgToVectorPatterns(getFunction());
+  if (testTransformPadTensor)
+    return applyPadTensorToGenericPatterns(getFunction());
   if (testAffineMinSCFCanonicalizationPatterns)
     return applyAffineMinSCFCanonicalizationPatterns(getFunction());
   if (testTileAndPadPattern)


        


More information about the Mlir-commits mailing list