[Mlir-commits] [mlir] 65bdedd - [mlir] Bubble up tensor.extract_slice above linalg operation

Okwan Kwon llvmlistbot at llvm.org
Thu Mar 31 09:58:35 PDT 2022


Author: Okwan Kwon
Date: 2022-03-31T16:48:38Z
New Revision: 65bdeddb1e5c0d27be0397379131b2d712c7a227

URL: https://github.com/llvm/llvm-project/commit/65bdeddb1e5c0d27be0397379131b2d712c7a227
DIFF: https://github.com/llvm/llvm-project/commit/65bdeddb1e5c0d27be0397379131b2d712c7a227.diff

LOG: [mlir] Bubble up tensor.extract_slice above linalg operation

Bubble up extract_slice above Linalg operation.

A sequence of operations

    %0 = linalg.<op> ... arg0, arg1, ...
    %1 = tensor.extract_slice %0 ...

can be replaced with

    %0 = tensor.extract_slice %arg0
    %1 = tensor.extract_slice %arg1
    %2 = linalg.<op> ... %0, %1, ...

This reduces the computation performed by the linalg operation.

The implementation uses the tiling utility functions. One difference
from the tiling process is that we don't need to insert checking
code for out-of-bound accesses: the presence of the slice itself
indicates that the code writer has already guaranteed the boundary
condition. To avoid adding the boundary condition check code,
`omitPartialTileCheck` is introduced for the tiling utility functions.

Differential Revision: https://reviews.llvm.org/D122437

Added: 
    mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
    mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/lib/IR/AffineMap.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index ab892fd1fcb85..5b6f99f74a38e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -119,6 +119,9 @@ void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
 /// Patterns that are used to inline constant operands into linalg generic ops.
 void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
 
+/// Patterns that are used to bubble up extract slice op above linalg op.
+void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
+
 /// Options that control fusion of elementwise operations.
 struct LinalgElementwiseFusionOptions {
   /// Enable fusion of reshapes into the shape with elementwise operations. By

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 4f991588d1d43..4f707dd85dc18 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -166,16 +166,21 @@ SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc, ValueRange ivs,
 
 /// Creates an extract_slice/subview op for a single `valueToTile` with
 /// `builder`. This new operation extracts a tile of `valueToTile`, starting
-/// at offsets `lbs` and with sizes `subShapeSizes`.
+/// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
+/// controls whether to omit the partial/boundary tile condition check in cases
+/// where we statically know that it is unnecessary.
 Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                      ValueRange tileSizes, AffineMap map, ValueRange lbs,
-                     ValueRange ubs, ValueRange subShapeSizes);
+                     ValueRange ubs, ValueRange subShapeSizes,
+                     bool omitPartialTileCheck);
 
 /// Creates extract_slice/subview ops for all `valuesToTile` of the given
 /// `linalgOp` with `builder`, assuming `linalgOp` is being fused into a loop
 /// nest for tiling with the given induction variables `ivs` and tile sizes
 /// `tileSizes`. `sizeBounds` are the iteration space bounds for *all* the
-/// implicit loops in `linalgOp`.
+/// implicit loops in `linalgOp`. `omitPartialTileCheck` controls whether to
+/// omit the partial/boundary tile condition check in cases where we statically
+/// know that it is unnecessary.
 ///
 /// Note that a constant zero in `tileSizes` means no tiling at that implicit
 /// loop. The number of non-zero values in `tileSizes` should be equal to the
@@ -184,7 +189,8 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
                                       LinalgOp linalgOp,
                                       ArrayRef<Value> valuesToTile,
                                       ValueRange ivs, ValueRange tileSizes,
-                                      ArrayRef<Value> sizeBounds);
+                                      ArrayRef<Value> sizeBounds,
+                                      bool omitPartialTileCheck);
 
 /// Add the tile loop induction variables `ivs` to the IndexOp results found in
 /// the body of the `tiledOp` to account for the tile offset.

diff  --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 82a33b0654402..87ac693492113 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -480,7 +480,7 @@ AffineMap inversePermutation(AffineMap map);
 /// ```mlir
 ///    affine_map<(d0, d1) -> (d0, 0, 0)>
 /// ```
-AffineMap inverseAndBroadcastProjectedPermuation(AffineMap map);
+AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map);
 
 /// Concatenates a list of `maps` into a single AffineMap, stepping over
 /// potentially empty maps. Assumes each of the underlying map has 0 symbols.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
new file mode 100644
index 0000000000000..53200d86511ce
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
@@ -0,0 +1,160 @@
+//===- BubbleUpExtractSlice.cpp - bubble up tensor.extract_slice ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns that transform linalg.<op> +
+// tensor.extract_slice into tensor.extract_slice + linalg.<op> to reduce
+// the computation for the linalg op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Bubble up extract_slice above Linalg operation.
+///
+/// A sequence of operations
+///
+/// ```mlir
+/// %0 = linalg.<op> ... arg0, arg1, ...
+/// %1 = tensor.extract_slice %0 ...
+/// ```
+///
+/// can be replaced with
+///
+/// ```mlir
+/// %0 = tensor.extract_slice %arg0
+/// %1 = tensor.extract_slice %arg1
+/// %2 = linalg.<op> ... %0, %1, ...
+/// ```
+///
+/// This reduces the computation of the linalg operation to the part of its
+/// result that is actually read by the slice.
+///
+/// The pattern only fires when the linalg op is a single-output op with
+/// tensor semantics whose only use is the slice, the slice has unit strides
+/// and is not rank-reducing, and the output indexing map is a projected
+/// permutation. All of these are checked before any IR is created.
+///
+/// No out-of-bounds checks are emitted for the new slices: the original
+/// extract_slice is taken as evidence that the accessed region is in bounds.
+struct BubbleUpExtractSliceOpPattern
+    : OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const final {
+    Value source = sliceOp.source();
+    auto linalgOp = source.getDefiningOp<LinalgOp>();
+    if (!linalgOp) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "expected source to be linalg op");
+    }
+
+    // TODO: we might relax this if we want heuristics to detect that all uses
+    // are small portion of the output.
+    if (!linalgOp->hasOneUse()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "expected single use of linalg op");
+    }
+
+    if (linalgOp.getNumOutputs() != 1) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "expected single output of linalg op");
+    }
+
+    if (!linalgOp.hasTensorSemantics()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "expected tensor of linalg op");
+    }
+
+    if (!sliceOp.hasUnitStride())
+      return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
+
+    // The cloned linalg op yields a result with the same rank as the original
+    // one, so it can only replace a slice that keeps every dimension.
+    if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "expected no rank reduction for the slice op");
+    }
+
+    OpOperand *outOperand = linalgOp.getOutputOperand(0);
+    AffineMap indexingMap = linalgOp.getTiedIndexingMap(outOperand);
+    if (!indexingMap.isProjectedPermutation()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "expected a projected permutation for output");
+    }
+
+    // Query the loops map before materializing any operand dims: a pattern
+    // must not modify the IR and then report a match failure.
+    AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
+    if (!shapeSizesToLoopsMap) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "failed to get loops map from shape sizes");
+    }
+
+    auto linalgLoc = linalgOp.getLoc();
+    auto allShapeSizes =
+        linalgOp.createFlatListOfOperandDims(rewriter, linalgLoc);
+    auto sizeBounds = applyMapToValues(rewriter, linalgLoc,
+                                       shapeSizesToLoopsMap, allShapeSizes);
+
+    auto sliceLoc = sliceOp.getLoc();
+    auto offsetVals = getValueOrCreateConstantIndexOp(
+        rewriter, sliceLoc, sliceOp.getMixedOffsets());
+    auto sizeVals = getValueOrCreateConstantIndexOp(rewriter, sliceLoc,
+                                                    sliceOp.getMixedSizes());
+
+    // The offsets and sizes from the slice operation only give you the tile
+    // size of the output. Use them to compute the tile sizes and offsets of
+    // the loops. For loops not used to access the output, set the tile sizes
+    // to the loop bounds and set the offsets to 0.
+    Value zero = rewriter.create<arith::ConstantIndexOp>(linalgLoc, 0);
+    SmallVector<Value, 4> tileOffsets(sizeBounds.size(), zero);
+    SmallVector<Value, 4> tileSizes = sizeBounds;
+    for (auto const &result : enumerate(indexingMap.getResults())) {
+      unsigned position = result.value().cast<AffineDimExpr>().getPosition();
+      tileOffsets[position] = offsetVals[result.index()];
+      tileSizes[position] = sizeVals[result.index()];
+    }
+
+    SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
+
+    // The original slice vouches for the bounds, so skip the partial-tile
+    // (boundary) checks that regular tiling would insert.
+    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
+        rewriter, linalgLoc, linalgOp, valuesToTile, tileOffsets, tileSizes,
+        sizeBounds, /*omitPartialTileCheck=*/true);
+
+    SmallVector<Type, 4> resultTensorTypes;
+    for (OpOperand *opOperand : linalgOp.getOutputTensorOperands())
+      resultTensorTypes.push_back(
+          tiledOperands[opOperand->getOperandNumber()].getType());
+
+    Operation *newOp =
+        linalgOp.clone(rewriter, linalgLoc, resultTensorTypes, tiledOperands);
+    rewriter.replaceOp(sliceOp, newOp->getResults());
+    return success();
+  }
+};
+} // namespace
+
+/// Registers BubbleUpExtractSliceOpPattern in `patterns`.
+void mlir::linalg::populateBubbleUpExtractSliceOpPatterns(
+    RewritePatternSet &patterns) {
+  auto *context = patterns.getContext();
+  patterns.add<BubbleUpExtractSliceOpPattern>(context);
+}

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 77457dac3113e..19955d450f248 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRLinalgTransforms
+  BubbleUpExtractSlice.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   CodegenStrategy.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 94edb8b630876..066008c114525 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -142,9 +142,9 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
   clonedShapes.reserve(producer.getNumInputsAndOutputs());
 
   // Compute subranges for all tensor input/output operands.
-  clonedShapes.append(makeTiledShapes(b, loc, producer,
-                                      getTiledOperands(producer), ivs,
-                                      tileSizes, sizeBounds));
+  clonedShapes.append(makeTiledShapes(
+      b, loc, producer, getTiledOperands(producer), ivs, tileSizes, sizeBounds,
+      /*omitPartialTileCheck=*/false));
 
   // Iterate over the results in order.
   // Extract the subtensor type from the linearized range.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 3ce6570f4bf53..9c962eb4b02c0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -163,7 +163,8 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
   erase_value(tileIvs, nullptr);
   SmallVector<Value> tiledOperands = producerOp.getInputAndOutputOperands();
   tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs,
-                                  tileSizes, producerLoopBounds);
+                                  tileSizes, producerLoopBounds,
+                                  /*omitPartialTileCheck=*/false);
 
   // Output fusion has to update the iteration arguments of the tile loop nest.
   // In particular, the iteration argument of the outermost tile loop needs to

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 4f863298ba422..b9a069c1d1e50 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -178,8 +178,9 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ValueRange tileSizes,
     SmallVector<Value> valuesToTile = operandValuesToUse;
     auto sizeBounds =
         applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
-    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
-        b, loc, op, valuesToTile, interchangedIvs, tileSizes, sizeBounds);
+    SmallVector<Value, 4> tiledOperands =
+        makeTiledShapes(b, loc, op, valuesToTile, interchangedIvs, tileSizes,
+                        sizeBounds, /*omitPartialTileCheck=*/false);
 
     // TODO: use an interface/adaptor to avoid leaking position in
     // `tiledOperands`.
@@ -325,9 +326,9 @@ static LogicalResult tilePadOp(RewriterBase &builder, tensor::PadOp op,
         // Note: The tensor::PadOp is located outside of the loop nest. It is
         // later moved inside by ExtractSliceOfPadTensorSwapPattern.
         auto map = AffineMap::getMultiDimIdentityMap(rank, b.getContext());
-        Value tiledOutput =
-            makeTiledShape(b, loc, newPadOp->getResult(0), tileSizes, map,
-                           offsets, allDims, sizes);
+        Value tiledOutput = makeTiledShape(
+            b, loc, newPadOp->getResult(0), tileSizes, map, offsets, allDims,
+            sizes, /*omitPartialTileCheck=*/false);
         auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
         assert(sliceOp && "expected ExtractSliceOp");
         // Insert the tile into the output tensor.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 4660baadf7e25..845c6434994e7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -504,7 +504,7 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
     //   readType = VectorType::get({}, bbarg.getType());
     // } else {
     if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
-      map = inverseAndBroadcastProjectedPermuation(
+      map = inverseAndBroadcastProjectedPermutation(
           linalgOp.getTiedIndexingMap(opOperand));
       readType = VectorType::get(commonVectorShape,
                                  getElementTypeOrSelf(opOperand->get()));

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 0d2a61713a5ea..c4b73a61417a1 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -745,7 +745,8 @@ static Value fullyComposeAndAffineApply(OpBuilder &b, Location loc,
 
 Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                      ValueRange tileSizes, AffineMap map, ValueRange lbs,
-                     ValueRange ubs, ValueRange subShapeSizes) {
+                     ValueRange ubs, ValueRange subShapeSizes,
+                     bool omitPartialTileCheck) {
   auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
   assert(shapedType && "only shaped types can be tiled");
   ArrayRef<int64_t> shape = shapedType.getShape();
@@ -773,7 +774,7 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
     auto m = map.getSubMap({r});
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: submap: " << m << "\n");
     auto offset = applyMapToValues(builder, loc, m, lbs).front();
-    offsets.push_back(offset);
+    offsets.push_back(getAsOpFoldResult(offset));
     auto closedIntSize =
         applyMapToValues(builder, loc, m, subShapeSizes).front();
     // Resulting size needs to be made half open interval again.
@@ -781,6 +782,17 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
     Value size =
         fullyComposeAndAffineApply(builder, loc, s0 + 1, closedIntSize);
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: raw size: " << size << "\n");
+    LLVM_DEBUG(llvm::dbgs()
+               << "makeTiledShape: new offset: " << offset << "\n");
+    strides.push_back(builder.getIndexAttr(1));
+
+    if (omitPartialTileCheck) {
+      // We statically know that the partial/boundary tile condition is
+      // unnecessary.
+      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
+      sizes.push_back(getAsOpFoldResult(size));
+      continue;
+    }
 
     // The size of the subview / extract_slice should be trimmed to avoid
     // out-of-bounds accesses, unless:
@@ -829,12 +841,8 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
       size = builder.create<AffineMinOp>(loc, builder.getIndexType(), minMap,
                                          operands);
     }
-
-    sizes.push_back(size);
-    LLVM_DEBUG(llvm::dbgs()
-               << "makeTiledShape: new offset: " << offset << "\n");
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
-    strides.push_back(builder.getIndexAttr(1));
+    sizes.push_back(getAsOpFoldResult(size));
   }
 
   auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
@@ -886,7 +894,8 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
                                       LinalgOp linalgOp,
                                       ArrayRef<Value> valuesToTile,
                                       ValueRange ivs, ValueRange tileSizes,
-                                      ArrayRef<Value> sizeBounds) {
+                                      ArrayRef<Value> sizeBounds,
+                                      bool omitPartialTileCheck) {
   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                            llvm::make_range(tileSizes.begin(), tileSizes.end()),
                            [](Value v) { return !isZero(v); })) &&
@@ -921,7 +930,8 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
     LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
 
     tiledShapes.push_back(makeTiledShape(b, loc, shapedOp, tileSizes, map, lbs,
-                                         sizeBounds, subShapeSizes));
+                                         sizeBounds, subShapeSizes,
+                                         omitPartialTileCheck));
   }
 
   return tiledShapes;

diff  --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 32eb07cf1a2f4..c93b4c28d769a 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -679,7 +679,7 @@ AffineMap mlir::inversePermutation(AffineMap map) {
   return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext());
 }
 
-AffineMap mlir::inverseAndBroadcastProjectedPermuation(AffineMap map) {
+AffineMap mlir::inverseAndBroadcastProjectedPermutation(AffineMap map) {
   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true));
   MLIRContext *context = map.getContext();
   AffineExpr zero = mlir::getAffineConstantExpr(0, context);

diff  --git a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
new file mode 100644
index 0000000000000..234c2b8fec30f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
@@ -0,0 +1,158 @@
+//RUN: mlir-opt -test-linalg-transform-patterns=test-bubble-up-extract-slice-op-pattern -split-input-file %s | FileCheck %s
+
+func @dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>, %arg2: index, %arg3: index, %arg4: index, %arg5:index) -> tensor<?x?xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%arg0 : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %add = arith.addf %b0, %b1 : f32
+      linalg.yield %add : f32
+  } -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0 [%arg2, %arg3] [%arg4, %arg5] [1, 1]
+    : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+//      CHECK: func @dynamic
+//      CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[%arg3] [%arg5] [1] : tensor<?xf32> to tensor<?xf32>
+//      CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[SLICE2]] : tensor<?x?xf32>)
+//      CHECK: return %[[GENERIC]] : tensor<?x?xf32>
+
+//-----
+
+func @static(%arg0: tensor<16x8xf32>, %arg1: tensor<8xf32>) -> tensor<4x2xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<16x8xf32>, tensor<8xf32>)
+    outs(%arg0 : tensor<16x8xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %add = arith.addf %b0, %b1 : f32
+      linalg.yield %add : f32
+  } -> tensor<16x8xf32>
+  %1 = tensor.extract_slice %0 [8, 4] [4, 2] [1, 1]
+    : tensor<16x8xf32> to tensor<4x2xf32>
+  return %1 : tensor<4x2xf32>
+}
+
+//      CHECK: func @static
+//      CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<16x8xf32> to tensor<4x2xf32>
+//      CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[4] [2] [1] : tensor<8xf32> to tensor<2xf32>
+//      CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<16x8xf32> to tensor<4x2xf32>
+//      CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<4x2xf32>, tensor<2xf32>) outs(%[[SLICE2]] : tensor<4x2xf32>)
+//      CHECK: return %[[GENERIC]] : tensor<4x2xf32>
+
+//-----
+
+func @mixed(%arg0: tensor<?x8xf32>, %arg1: tensor<8xf32>, %arg2: index, %arg3: index) -> tensor<?x2xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<?x8xf32>, tensor<8xf32>)
+    outs(%arg0 : tensor<?x8xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %add = arith.addf %b0, %b1 : f32
+      linalg.yield %add : f32
+  } -> tensor<?x8xf32>
+  %1 = tensor.extract_slice %0 [8, %arg2] [%arg3, 2] [1, 1]
+    : tensor<?x8xf32> to tensor<?x2xf32>
+  return %1 : tensor<?x2xf32>
+}
+
+//      CHECK: func @mixed
+//      CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[8, %arg2] [%arg3, 2] [1, 1] : tensor<?x8xf32> to tensor<?x2xf32>
+//      CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[%arg2] [2] [1] : tensor<8xf32> to tensor<2xf32>
+//      CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[8, %arg2] [%arg3, 2] [1, 1] : tensor<?x8xf32> to tensor<?x2xf32>
+//      CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<?x2xf32>, tensor<2xf32>) outs(%[[SLICE2]] : tensor<?x2xf32>)
+//      CHECK: return %[[GENERIC]] : tensor<?x2xf32>
+
+//-----
+
+func @dynamic_to_static(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<4x2xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%arg0 : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %add = arith.addf %b0, %b1 : f32
+      linalg.yield %add : f32
+  } -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0 [8, 4] [4, 2] [1, 1]
+    : tensor<?x?xf32> to tensor<4x2xf32>
+  return %1 : tensor<4x2xf32>
+}
+
+//      CHECK: func @dynamic_to_static
+//      CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<?x?xf32> to tensor<4x2xf32>
+//      CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[4] [2] [1] : tensor<?xf32> to tensor<2xf32>
+//      CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<?x?xf32> to tensor<4x2xf32>
+//      CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<4x2xf32>, tensor<2xf32>) outs(%[[SLICE2]] : tensor<4x2xf32>)
+//      CHECK: return %[[GENERIC]] : tensor<4x2xf32>
+
+//-----
+
+func @matmul_slice() -> tensor<2x2xf32> {
+    %lhs = arith.constant dense<1.0> : tensor<4x4xf32>
+    %rhs = arith.constant dense<1.0> : tensor<4x4xf32>
+    %dst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0], [8.0, 9.0, 10.0, 11.0], [12.0, 13.0, 14.0, 15.0]]> : tensor<4x4xf32>
+    %0 = linalg.matmul ins(%lhs, %rhs : tensor<4x4xf32>, tensor<4x4xf32>) outs(%dst : tensor<4x4xf32>) -> tensor<4x4xf32>
+    %1 = tensor.extract_slice %0[1,1][2,2][1,1] : tensor<4x4xf32> to tensor<2x2xf32>
+    return %1 : tensor<2x2xf32>
+}
+
+// CHECK: func @matmul_slice
+// CHECK: %[[SLICE0:.+]] = arith.constant dense<1.000000e+00> : tensor<2x4xf32>
+// CHECK: %[[SLICE1:.+]] = arith.constant dense<1.000000e+00> : tensor<4x2xf32>
+// CHECK: %[[SLICE3:.+]] = tensor.extract_slice %[[CST:.+]][1, 1] [2, 2] [1, 1] : tensor<4x4xf32> to tensor<2x2xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[SLICE0]], %[[SLICE1]] : tensor<2x4xf32>, tensor<4x2xf32>) outs(%[[SLICE3]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+// CHECK: return %[[MATMUL]] : tensor<2x2xf32>
+
+//-----
+
+func @conv_slice(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32xf32>) -> tensor<1x32x32x16xf32> {
+  %c112 = arith.constant 112 : index
+  %c32 = arith.constant 32 : index
+  %c16 = arith.constant 16 : index
+  %c8 = arith.constant 8 : index
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f32
+
+  %init = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
+
+  %conv = linalg.conv_2d_nhwc_hwcf
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+    ins(%input, %filter : tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>)
+    outs(%fill : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
+
+  %slice = tensor.extract_slice %conv [0, 64, 64, 16] [1, 32, 32, 16] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x32x32x16xf32>
+
+  return %slice : tensor<1x32x32x16xf32>
+}
+
+// CHECK: func @conv_slice
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
+// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[0, 128, 128, 0] [1, 65, 65, 3] [1, 1, 1, 1] : tensor<1x225x225x3xf32> to tensor<1x65x65x3xf32>
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[0, 0, 0, 16] [3, 3, 3, 16] [1, 1, 1, 1] : tensor<3x3x3x32xf32> to tensor<3x3x3x16xf32>
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[INIT]][0, 64, 64, 16] [1, 32, 32, 16] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x32x32x16xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST:.+]] : f32) outs(%[[SLICE2]] : tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
+// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[SLICE0]], %[[SLICE1]] : tensor<1x65x65x3xf32>, tensor<3x3x3x16xf32>) outs(%[[FILL]] : tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
+// CHECK: return %[[CONV]] : tensor<1x32x32x16xf32>

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index e6786b24f5939..06d1d81a38592 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -132,6 +132,11 @@ struct TestLinalgTransforms
       llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
                      "tiled_loop"),
       llvm::cl::init("for")};
+  Option<bool> testBubbleUpExtractSliceOpPattern{
+      *this, "test-bubble-up-extract-slice-op-pattern",
+      llvm::cl::desc("Test rewrite of linalgOp + extract_slice into "
+                     "extract_slice + linalgOp"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -635,6 +640,12 @@ static void applySplitReduction(FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyBubbleUpExtractSliceOpPattern(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateBubbleUpExtractSliceOpPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   auto lambda = [&](void *) {
@@ -686,6 +697,8 @@ void TestLinalgTransforms::runOnOperation() {
                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
   if (testSplitReduction)
     return applySplitReduction(getOperation());
+  if (testBubbleUpExtractSliceOpPattern)
+    return applyBubbleUpExtractSliceOpPattern(getOperation());
 }
 
 namespace mlir {


        


More information about the Mlir-commits mailing list