[Mlir-commits] [mlir] 4064b6a - [mlir][linalg] Tile PadTensorOp

Matthias Springer llvmlistbot at llvm.org
Wed Jul 14 18:42:57 PDT 2021


Author: Matthias Springer
Date: 2021-07-15T10:42:32+09:00
New Revision: 4064b6a36348a0405a52b690437a1ae3004beec1

URL: https://github.com/llvm/llvm-project/commit/4064b6a36348a0405a52b690437a1ae3004beec1
DIFF: https://github.com/llvm/llvm-project/commit/4064b6a36348a0405a52b690437a1ae3004beec1.diff

LOG: [mlir][linalg] Tile PadTensorOp

Tiling can be enabled with `linalg-tile-pad-tensor-ops`. Only scf::ForOp loop nests can be generated at the moment.
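
For illustration, here is a sketch of what the transformation produces, reconstructed from the RUN and CHECK lines of the new test below (the input file name, SSA names, and tile sizes are placeholders):

  mlir-opt input.mlir -linalg-tile="linalg-tile-sizes=2,3" -cse

With tile sizes 2 and 3, a PadTensorOp that has an output operand, e.g.

  %0 = linalg.pad_tensor %input low[3, 4] high[5, 3] into %output {
  ^bb0(%i: index, %j: index):
    linalg.yield %pad_value : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>

is rewritten into an scf.for loop nest over the padded result. Each iteration extracts a tile from a clone of the PadTensorOp and inserts it into the output tensor; ExtractSliceOfPadTensorSwapPattern then swaps the extract_slice with the pad, so a tile that lies entirely in the padded region is materialized directly with tensor.generate (schematic IR, operands elided):

  scf.for %i = %c0 to %dim0 step %c2 {
    scf.for %j = %c0 to %dim1 step %c3 iter_args(%out = %output) {
      %tile = scf.if ... {
        // Tile is all padding: generate the pad value directly.
        ... tensor.generate ...
      } else {
        // Tile overlaps the input: pad a slice of it.
        %slice = tensor.extract_slice %input[...] [...] [1, 1]
        ... linalg.pad_tensor %slice low[...] high[...] ...
      }
      ... tensor.insert_slice %tile into %out[...] [...] [1, 1]
    }
  }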

Differential Revision: https://reviews.llvm.org/D105460

Added: 
    mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index b6420f7b104bc..4aa7792eca901 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -152,6 +152,18 @@ transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
   }
 }
 
+// Insert a tile `source` into the destination tensor `dest`. The position at
+// which the tile is inserted (as well as the size of the tile) is taken from
+// a given ExtractSliceOp `sliceOp`.
+static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
+                                   tensor::ExtractSliceOp sliceOp, Value source,
+                                   Value dest) {
+  return b.create<tensor::InsertSliceOp>(
+      loc, sliceOp.source().getType(), source, dest, sliceOp.offsets(),
+      sliceOp.sizes(), sliceOp.strides(), sliceOp.static_offsets(),
+      sliceOp.static_sizes(), sliceOp.static_strides());
+}
+
 template <typename LoopTy>
 static Optional<TiledLinalgOp>
 tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
@@ -259,11 +271,8 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
       // `tiledOperands`.
       Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
       if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
-        tensorResults.push_back(b.create<tensor::InsertSliceOp>(
-            loc, sliceOp.source().getType(), res->getResult(resultIdx),
-            sliceOp.source(), sliceOp.offsets(), sliceOp.sizes(),
-            sliceOp.strides(), sliceOp.static_offsets(), sliceOp.static_sizes(),
-            sliceOp.static_strides()));
+        tensorResults.push_back(insertSliceIntoTensor(
+            b, loc, sliceOp, res->getResult(resultIdx), sliceOp.source()));
       } else {
         tensorResults.push_back(res->getResult(resultIdx));
       }
@@ -341,6 +350,86 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
   return llvm::None;
 }
 
+/// Generate a loop nest around a given PadTensorOp (for tiling). `newPadOp`
+/// and `loopNest` are output parameters that return the new (tiled) PadTensorOp
+/// and the loop nest.
+static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op,
+                                     PadTensorOp &newPadOp, LoopNest &loopNest,
+                                     const LinalgTilingOptions &options) {
+  // Only a PadTensorOp that has an output operand can be tiled.
+  if (!op.output())
+    return failure();
+
+  Location loc = op.getLoc();
+  OpBuilder::InsertionGuard g(builder);
+  builder.setInsertionPoint(op);
+
+  // Clone PadTensorOp so that the existing op can be replaced more easily.
+  newPadOp = cast<PadTensorOp>(builder.clone(*op.getOperation()));
+  // Get rank and tile sizes.
+  int64_t rank = op.getResultType().getRank();
+  SmallVector<Value> tileSizes =
+      options.tileSizeComputationFunction(builder, op);
+  assert(static_cast<int64_t>(tileSizes.size()) == rank);
+  // Compute lower and upper bounds of the loop nest.
+  SmallVector<Value> lbs, dims, steps;
+  for (int64_t i = 0; i < rank; ++i) {
+    if (!isZero(tileSizes[i])) {
+      lbs.push_back(builder.create<ConstantIndexOp>(loc, 0));
+      dims.push_back(builder.create<tensor::DimOp>(loc, op.output(), i));
+      steps.push_back(tileSizes[i]);
+    }
+  }
+  // Generate the loop nest: one loop per tiled dimension.
+  loopNest = mlir::scf::buildLoopNest(
+      builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(op.output()),
+      [&](OpBuilder &b, Location loc, ValueRange localIvs,
+          ValueRange iterArgs) -> scf::ValueVector {
+        // Compute offsets and sizes of ExtractSliceOp.
+        SmallVector<Value> offsets =
+            computeTileOffsets(b, loc, localIvs, tileSizes);
+        SmallVector<Value> sizes =
+            computeTileSizes(b, loc, localIvs, tileSizes, dims);
+        // Create ExtractSliceOp: Extract a tile from the PadTensorOp.
+        // Note: The PadTensorOp is located outside of the loop nest. It is
+        // later moved inside by ExtractSliceOfPadTensorSwapPattern.
+        auto map = AffineMap::getMultiDimIdentityMap(rank, b.getContext());
+        Value tiledOutput = makeTiledShape(b, loc, newPadOp->getResult(0),
+                                           tileSizes, map, offsets, sizes);
+        auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
+        assert(sliceOp && "expected ExtractSliceOp");
+        // Insert the tile into the output tensor.
+        Value yieldValue =
+            insertSliceIntoTensor(b, loc, sliceOp, sliceOp, iterArgs[0]);
+        return scf::ValueVector({yieldValue});
+      });
+  return success();
+}
+
+namespace {
+struct PadTensorOpTilingPattern : public OpRewritePattern<PadTensorOp> {
+  PadTensorOpTilingPattern(MLIRContext *ctx, LinalgTilingOptions opt)
+      : OpRewritePattern<PadTensorOp>(ctx), options(opt) {}
+
+  LogicalResult matchAndRewrite(PadTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op->hasAttr(LinalgTransforms::kLinalgTransformMarker))
+      return failure();
+    PadTensorOp newPadOp;
+    LoopNest loopNest;
+    if (failed(tilePadTensorOp(rewriter, op, newPadOp, loopNest, options)))
+      return failure();
+    newPadOp->setAttr(LinalgTransforms::kLinalgTransformMarker,
+                      rewriter.getUnitAttr());
+    // Replace all uses of the original PadTensorOp.
+    rewriter.replaceOp(op, loopNest.getResults()[0]);
+    return success();
+  }
+
+  LinalgTilingOptions options;
+};
+} // namespace
+
 namespace {
 /// Helper classes for type list expansion.
 template <typename... OpTypes>
@@ -408,6 +497,7 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
   memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
   tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
   memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
+  PadTensorOp::getCanonicalizationPatterns(patterns, ctx);
   ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);
   CanonicalizationPatternList<
 #define GET_OP_LIST
@@ -422,6 +512,8 @@ static void insertTilingPatterns(RewritePatternSet &patterns,
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                      >::insert(patterns, options);
+  patterns.add<PadTensorOpTilingPattern>(patterns.getContext(), options);
+  patterns.add<ExtractSliceOfPadTensorSwapPattern>(patterns.getContext());
 }
 
 static void

diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 8d919e2d7e1bb..d582c5328e4ea 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
@@ -550,7 +551,7 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
     if (!isTiled(map.getSubMap({r}), tileSizes)) {
       offsets.push_back(builder.getIndexAttr(0));
       Value dim = createOrFoldDimOp(builder, loc, valueToTile, r);
-      sizes.push_back(dim);
+      sizes.push_back(getAsOpFoldResult(dim));
       strides.push_back(builder.getIndexAttr(1));
       LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
       continue;

diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
index f8b49c14ff7ef..9698ca62e0656 100644
--- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
@@ -105,14 +105,12 @@ func @matmul_partially_padded_tensors(
 //      CHECK-1DIM-TILE:        %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xi32>) {
 //      CHECK-1DIM-TILE:            %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xi32>) {
 //      CHECK-1DIM-TILE:                %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x8xi8> to tensor<?x8xi8>
-//      CHECK-1DIM-TILE:                %[[sTAc:.*]] = tensor.cast %[[sTA]] : tensor<?x8xi8> to tensor<?x?xi8>
 //      CHECK-1DIM-TILE:                %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<8x?xi8> to tensor<8x?xi8>
-//      CHECK-1DIM-TILE:                %[[sTBc:.*]] = tensor.cast %[[sTB]] : tensor<8x?xi8> to tensor<?x?xi8>
 //      CHECK-1DIM-TILE:                %[[sTC:.*]] = tensor.extract_slice %[[TC1]][{{.*}}] : tensor<?x?xi32> to tensor<?x?xi32>
-//      CHECK-1DIM-TILE:                %[[pA:.*]] = linalg.pad_tensor %[[sTAc]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
-//      CHECK-1DIM-TILE:                   : tensor<?x?xi8> to tensor<2x8xi8>
-//      CHECK-1DIM-TILE:                %[[pB:.*]] = linalg.pad_tensor %[[sTBc]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
-//      CHECK-1DIM-TILE:                   : tensor<?x?xi8> to tensor<8x3xi8>
+//      CHECK-1DIM-TILE:                %[[pA:.*]] = linalg.pad_tensor %[[sTA]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+//      CHECK-1DIM-TILE:                   : tensor<?x8xi8> to tensor<2x8xi8>
+//      CHECK-1DIM-TILE:                %[[pB:.*]] = linalg.pad_tensor %[[sTB]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+//      CHECK-1DIM-TILE:                   : tensor<8x?xi8> to tensor<8x3xi8>
 //      CHECK-1DIM-TILE:                %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK-1DIM-TILE:                   : tensor<?x?xi32> to tensor<2x3xi32>
 //      CHECK-1DIM-TILE:               %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>)

diff --git a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
new file mode 100644
index 0000000000000..10f4dc3e34c30
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
@@ -0,0 +1,94 @@
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3" -cse -split-input-file | \
+// RUN: FileCheck %s -check-prefix=TILE2
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,3" -cse -split-input-file | \
+// RUN: FileCheck %s -check-prefix=TILE1
+
+// TILE2-LABEL: func @dynamic_pad_tensor(
+//  TILE2-SAME:     %[[IN:.*]]: tensor<?x?xf32>, %[[OUT:.*]]: tensor<?x?xf32>
+//   TILE2-DAG:   %[[C0:.*]] = constant 0 : index
+//   TILE2-DAG:   %[[C1:.*]] = constant 1 : index
+//   TILE2-DAG:   %[[C2:.*]] = constant 2 : index
+//   TILE2-DAG:   %[[C3:.*]] = constant 3 : index
+//       TILE2:   %[[DIM0:.*]] = tensor.dim %[[OUT]], %[[C0]]
+//       TILE2:   %[[DIM1:.*]] = tensor.dim %[[OUT]], %[[C1]]
+//       TILE2:   %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[DIM0]] step %[[C2]]
+//       TILE2:     scf.for {{.*}} = %[[C0]] to %[[DIM1]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
+//       TILE2:       %[[SWAP_RESULT:.*]] = scf.if
+//       TILE2:         tensor.generate
+//       TILE2:       else
+//       TILE2:         %[[SLICE:.*]] = tensor.extract_slice %[[IN]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+//       TILE2:         %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]]
+//       TILE2:       tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+//       TILE2:   return %[[RESULT]]
+
+// TILE1-LABEL: func @dynamic_pad_tensor(
+//  TILE1-SAME:     %[[IN:.*]]: tensor<?x?xf32>, %[[OUT:.*]]: tensor<?x?xf32>
+//   TILE1-DAG:   %[[C0:.*]] = constant 0 : index
+//   TILE1-DAG:   %[[C1:.*]] = constant 1 : index
+//   TILE1-DAG:   %[[C3:.*]] = constant 3 : index
+//       TILE1:   %[[DIM1:.*]] = tensor.dim %[[OUT]], %[[C1]]
+//       TILE1:   %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[DIM1]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
+//       TILE1:     %[[DIM0:.*]] = tensor.dim %[[OUT]], %[[C0]]
+//       TILE1:     %[[SWAP_RESULT:.*]] = scf.if
+//       TILE1:       tensor.generate
+//       TILE1:     else
+//       TILE1:       %[[SLICE:.*]] = tensor.extract_slice %[[IN]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+//       TILE1:       %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]] low[3, %{{.*}}] high[{{.*}}, {{.*}}]
+//       TILE1:     tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][0, {{.*}}] [%[[DIM0]], {{.*}}] [1, 1]
+//       TILE1:   return %[[RESULT]]
+
+func @dynamic_pad_tensor(%input_tensor: tensor<?x?xf32>,
+                         %output_tensor: tensor<?x?xf32>,
+                         %pad_value: f32) -> tensor<?x?xf32> {
+  %0 = linalg.pad_tensor %input_tensor
+    low[3, 4] high[5, 3] into %output_tensor {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad_value : f32
+    } : tensor<?x?xf32> to tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+// TILE2-LABEL: func @static_pad_tensor(
+//  TILE2-SAME:     %[[IN:.*]]: tensor<7x9xf32>, %[[OUT:.*]]: tensor<15x16xf32>
+//   TILE2-DAG:   %[[C0:.*]] = constant 0 : index
+//   TILE2-DAG:   %[[C2:.*]] = constant 2 : index
+//   TILE2-DAG:   %[[C3:.*]] = constant 3 : index
+//   TILE2-DAG:   %[[C15:.*]] = constant 15 : index
+//   TILE2-DAG:   %[[C16:.*]] = constant 16 : index
+//       TILE2:   %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[C15]] step %[[C2]]
+//       TILE2:     scf.for {{.*}} = %[[C0]] to %[[C16]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
+//       TILE2:       %[[SWAP_RESULT:.*]] = scf.if
+//       TILE2:         tensor.generate
+//       TILE2:       else
+//       TILE2:         %[[SLICE:.*]] = tensor.extract_slice %[[IN]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+//       TILE2:         %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]]
+//       TILE2:       tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+//       TILE2:   return %[[RESULT]]
+
+
+// TILE1-LABEL: func @static_pad_tensor(
+//  TILE1-SAME:     %[[IN:.*]]: tensor<7x9xf32>, %[[OUT:.*]]: tensor<15x16xf32>
+//   TILE1-DAG:   %[[C0:.*]] = constant 0 : index
+//   TILE1-DAG:   %[[C3:.*]] = constant 3 : index
+//   TILE1-DAG:   %[[C16:.*]] = constant 16 : index
+//       TILE1:   %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[C16]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
+//       TILE1:     %[[SWAP_RESULT:.*]] = scf.if
+//       TILE1:       tensor.generate
+//       TILE1:     else
+//       TILE1:       %[[SLICE:.*]] = tensor.extract_slice %[[IN]][0, {{.*}}] [7, {{.*}}] [1, 1]
+//       TILE1:       %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]] low[3, %{{.*}}] high[5, {{.*}}]
+//       TILE1:     tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][0, {{.*}}] [15, {{.*}}] [1, 1]
+//       TILE1:   return %[[RESULT]]
+
+func @static_pad_tensor(%input_tensor: tensor<7x9xf32>,
+                        %output_tensor: tensor<15x16xf32>,
+                        %pad_value: f32) -> tensor<15x16xf32> {
+  %0 = linalg.pad_tensor %input_tensor
+    low[3, 4] high[5, 3] into %output_tensor {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad_value : f32
+    } : tensor<7x9xf32> to tensor<15x16xf32>
+  return %0 : tensor<15x16xf32>
+}
