[Mlir-commits] [mlir] 08a1b07 - [mlir] Partially port splitting transform to TilingInterface

Alex Zinenko llvmlistbot at llvm.org
Wed Jul 27 01:52:23 PDT 2022


Author: Alex Zinenko
Date: 2022-07-27T08:52:08Z
New Revision: 08a1b07e7c19d14896f8be501c763ba8aff5b427

URL: https://github.com/llvm/llvm-project/commit/08a1b07e7c19d14896f8be501c763ba8aff5b427
DIFF: https://github.com/llvm/llvm-project/commit/08a1b07e7c19d14896f8be501c763ba8aff5b427.diff

LOG: [mlir] Partially port splitting transform to TilingInterface

The structured op splitting transformation is conceptually similar to
tiling in the sense that it decomposes the iteration space of the
original op into several parts. Therefore, it is possible to implement
it using the TilingInterface to operate on iteration spaces and their
parts. However, the implementation also requires passing updated input
operands, which is not supported by the interface, so the implementation
currently remains Linalg-specific.
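
As a rough illustration of how one dimension of the iteration space is
decomposed, the sketch below replays the affine expressions from the new
Split.cpp on plain integers. It is not part of the commit and does not use
MLIR: `DimRange`, `SplitResult` and `splitDimension` are hypothetical names
introduced only for this example, and sizes are assumed to be statically
known. Offsets are expressed in the coordinate system of the original
iteration space, as the updated TilingInterface documentation requires.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for one dimension of an iteration domain (not an
// MLIR type): the half-open range [offset, offset + size).
struct DimRange {
  int64_t offset;
  int64_t size;
};

// Hypothetical result type: the first part is always present; the second is
// absent when nothing remains after the (clamped) split point.
struct SplitResult {
  DimRange first;
  std::optional<DimRange> second;
};

// Replays, on plain integers, the expressions built in linalg::splitOp.
inline SplitResult splitDimension(DimRange dim, int64_t splitPoint) {
  assert(dim.size >= 0 && "expected a non-negative size");

  // min(d0, d1 + d2) applied to {splitPoint, offset, size}: clamp the split
  // point so that it does not overflow the dimension.
  int64_t minSplitPoint = std::min(splitPoint, dim.offset + dim.size);

  // d0 + d1 - d2 applied to {offset, size, minSplitPoint}: size of the
  // second part.
  int64_t remainingSize = dim.offset + dim.size - minSplitPoint;

  // Nothing is left for a second part, so the original range is kept whole.
  if (remainingSize == 0)
    return {dim, std::nullopt};

  // The first part starts at the original offset; the second part starts at
  // d0 + d1 applied to {offset, minSplitPoint}.
  DimRange first{dim.offset, minSplitPoint};
  DimRange second{dim.offset + minSplitPoint, remainingSize};
  return {first, second};
}
```

For a dimension starting at offset 0, `splitDimension({0, 100}, 42)` gives
the parts [0, 42) and [42, 100), while `splitDimension({0, 10}, 42)` clamps
the split point to 10 and produces no second part, which corresponds to the
static-overflow case where the op is left unsplit.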

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D129564

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/include/mlir/Interfaces/TilingInterface.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Split.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
    mlir/test/Dialect/Linalg/transform-op-split.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 569faf383a0ee..e74905edadffc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallSet.h"
@@ -134,9 +135,10 @@ void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
 ///
 /// Note that there is no simplification other than constant propagation applied
 /// to slice extraction and insertion.
-std::pair<LinalgOp, LinalgOp> splitOp(RewriterBase &rewriter, LinalgOp op,
-                                      unsigned dimension,
-                                      OpFoldResult splitPoint);
+std::pair<TilingInterface, TilingInterface> splitOp(RewriterBase &rewriter,
+                                                    TilingInterface op,
+                                                    unsigned dimension,
+                                                    OpFoldResult splitPoint);
 
 /// Perform standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 6905b49535f41..beee22c493f56 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -177,12 +177,6 @@ bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
 bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
                    Value consumedView, LinalgOp producer);
 
-/// Creates either a memref.subview or a tensor.extract_slice with the given
-/// offsets/sizes/strides based on the type of `value`.
-Value createSlice(OpBuilder &builder, Location loc, Value value,
-                  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
-                  ArrayRef<OpFoldResult> strides);
-
 /// Computes tile offsets, given a list of loop `ivs` and `tileSizes`. In case a
 /// tile size is zero (i.e., no tiling), the corresponding offset is also zero.
 SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,

diff  --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index f3fdc30168b28..ee998530d3d8e 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -76,8 +76,11 @@ def TilingInterface : OpInterface<"TilingInterface"> {
             operation is to be inserted into. The type of the `dest`
             Values is same as the types returned by
             `getDestinationOperands` method.
-          - `offsets` provides the offset of the tile within the
-            iteration space
+          - `offsets` provides the offset of the tile in the coordinate system
+            of the original iteration space, i.e., if an iteration space
+            dimension had non-zero offset, it must be included in the offset
+            provided here (as opposed to zero-based offset "relative" to the
+            iteration space).
           - `sizes` provides the size of the tile.
           - `tileDestOperands` specifies whether to also tile `dest` operands
             or not. Avoiding tiling `dest` operands can be useful for 
@@ -141,8 +144,11 @@ def TilingInterface : OpInterface<"TilingInterface"> {
             operation is to be inserted into. The type of the `dest`
             Values is same as the types returned by
             `getDestinationOperands` method.
-          - `offsets` provides the offset of the tile within the
-            iteration space
+          - `offsets` provides the offset of the tile in the coordinate system
+            of the original iteration space, i.e., if an iteration space
+            dimension had non-zero offset, it must be included in the offset
+            provided here (as opposed to zero-based offset "relative" to the
+            iteration space).
           - `sizes` provides the size of the tile.
           - `tileDestOperands` specifies whether to also tile `dest` operands
             or not. Avoiding tiling `dest` operands can be useful for 

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 1cd57176d398f..d8b9187bb6a3d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -740,8 +740,9 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
     }
 
     rewriter.setInsertionPoint(linalgOp);
-    std::tie(first.emplace_back(), second.emplace_back()) =
-        linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair));
+    std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
+        rewriter, cast<TilingInterface>(linalgOp.getOperation()),
+        getDimension(), std::get<1>(pair));
   }
 
   results.set(getFirst().cast<OpResult>(), first);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
index 8849e7f964b3a..d735c671ab49c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -8,147 +8,124 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/TilingInterface.h"
 
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
 
-/// Extract the slices of `operands` supplied to the given operation `op` such
-/// that they are sufficient to execute the op for the subset of its iteration
-/// space defined by `splitIterationSpace`. The subset is a part of the original
-/// iteration space split at the given `dimension`. If `offset` is provided, it
-/// indicates the iterator value at which the dimension has been split and
-/// requires the "high" part starting at the given offset of the operands to be
-/// generated; otherwise, the "low" part with no offset is generated. Note that
-/// `operands` are not necessarily the actual operands of `op`.
-static SmallVector<Value>
-getOperandSlices(RewriterBase &b, Location loc, LinalgOp op,
-                 ValueRange splitIterationSpace, ValueRange operands,
-                 unsigned dimension, Value offset = nullptr) {
-  SmallVector<Value> slices;
-  slices.reserve(op.getNumInputsAndOutputs());
-  for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
-    auto type = opOperand->get().getType().dyn_cast<ShapedType>();
-    AffineMap indexing = op.getTiedIndexingMap(opOperand);
-
-    // If the type is not sliceable, or the slice is requested along the
-    // dimension that is not used in indexing this type, just use the entire
-    // operand.
-    if (!type || dimension >= indexing.getNumDims() ||
-        !indexing.isFunctionOfDim(dimension)) {
-      slices.push_back(opOperand->get());
-      continue;
-    }
-
-    SmallVector<OpFoldResult> sizes;
-    sizes.reserve(indexing.getNumResults());
-    for (AffineExpr dimIndexing : indexing.getResults()) {
-      sizes.push_back(makeComposedFoldedAffineApply(
-          b, loc, dimIndexing,
-          getAsOpFoldResult(llvm::to_vector(splitIterationSpace))));
-    }
-    SmallVector<OpFoldResult> offsets(type.getRank(), b.getIndexAttr(0));
-    SmallVector<OpFoldResult> strides(type.getRank(), b.getIndexAttr(1));
-
-    if (offset) {
-      offsets[dimension] = offset;
-      offsets = applyMapToValues(b, loc, indexing, offsets);
-    }
-
-    slices.push_back(createSlice(b, loc,
-                                 operands[opOperand->getOperandNumber()],
-                                 offsets, sizes, strides));
-  }
-
-  return slices;
-}
-
 /// Creates a part of the given `op` split along the iteration space `dimension`
 /// with the given `size` and an optional `offset` (default 0). Makes slices
 /// of operands, using the input operands of the original op and the output
-/// operands provided as `resultOperands`. Expects `splitIterationSpace` to be
-/// a list of values representing the shape of the iteration space of the
-/// original op and updates it to be the iteration space of the curent part.
-/// Returns the split-out op as well as the output operand values updated with
-/// the partial results produced by this op through `results`.
-static LinalgOp
-createSplitPart(RewriterBase &b, Location loc, LinalgOp op,
-                ValueRange resultOperands,
-                llvm::MutableArrayRef<Value> splitIterationSpace,
-                unsigned dimension, OpFoldResult size,
-                SmallVectorImpl<Value> &results, Value offset = nullptr) {
-  ImplicitLocOpBuilder implicit(op.getLoc(), b);
-  splitIterationSpace[dimension] = materializeOpFoldResult(implicit, size);
-  SmallVector<Value> operands = llvm::to_vector(
-      llvm::map_range(op.getInputOperands(),
-                      [](OpOperand *opOperand) { return opOperand->get(); }));
-  llvm::append_range(operands, resultOperands);
-  operands = getOperandSlices(b, loc, op, splitIterationSpace, operands,
-                              dimension, offset);
-  Operation *part =
-      op.clone(b, loc, getTensorOutputTypes(op, operands), operands);
-  results = insertSlicesBack(b, loc, op, operands, part->getResults());
-  return cast<LinalgOp>(part);
+/// operands provided as `resultOperands`. Expects `offsets` and `sizes` to
+/// define the shape of the iteration space of the original op. Returns the
+/// split-out op as well as the output operand values updated with the partial
+/// results produced by this op through `results`.
+static TilingInterface
+createSplitPart(RewriterBase &b, Location loc, TilingInterface op,
+                ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+                ValueRange resultOperands, unsigned dimension,
+                OpFoldResult size, OpFoldResult offset,
+                SmallVectorImpl<Value> &results) {
+  // Iteration space of the current part.
+  SmallVector<OpFoldResult> sizesCopy = llvm::to_vector(sizes);
+  SmallVector<OpFoldResult> offsetsCopy = llvm::to_vector(offsets);
+  sizesCopy[dimension] = size;
+  offsetsCopy[dimension] = offset;
+
+  // Create the part as if it were a single tile.
+  SmallVector<Operation *> tiled =
+      op.getTiledImplementation(b, resultOperands, offsetsCopy, sizesCopy,
+                                /*tileDestOperands=*/true);
+  assert(tiled.size() == 1 && "expected a single result from tiling");
+  auto part = cast<TilingInterface>(tiled.front());
+
+  // Insert the results back and populate the `results` list.
+  for (auto i : llvm::seq<unsigned>(0, part->getNumResults())) {
+    SmallVector<OpFoldResult> resultOffsets, resultSizes;
+    if (failed(op.getResultTilePosition(b, i, offsetsCopy, sizesCopy,
+                                        resultOffsets, resultSizes)))
+      return nullptr;
+    SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
+                                            b.getIndexAttr(1));
+    Value inserted = b.create<tensor::InsertSliceOp>(
+        loc, part->getResult(i), resultOperands[i], resultOffsets, resultSizes,
+        resultStrides);
+    results.push_back(inserted);
+  }
+
+  return part;
 }
 
-std::pair<LinalgOp, LinalgOp> linalg::splitOp(RewriterBase &rewriter,
-                                              LinalgOp op, unsigned dimension,
-                                              OpFoldResult splitPoint) {
+std::pair<TilingInterface, TilingInterface>
+linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
+                OpFoldResult splitPoint) {
+  // Compute the iteration space.
+  SmallVector<Range> iterationSpace = op.getIterationDomain(rewriter);
+
   // Bail out on dimension overflow.
-  if (dimension >= op.getNumLoops())
-    return std::make_pair(op, LinalgOp());
-
-  // Compute the iteration space size as values.
-  SmallVector<Value, 4> allShapes =
-      op.createFlatListOfOperandDims(rewriter, op.getLoc());
-  AffineMap shapesToLoops = op.getShapesToLoopsMap();
-  SmallVector<Value, 4> iterationSpaceShapes =
-      applyMapToValues(rewriter, op.getLoc(), shapesToLoops, allShapes);
-
-  // Update the iteration space to have `splitPoint` as the size of `dimension`
-  // and use it to slice operands and results for a new, smaller instance of the
-  // `op`. Adjust the size if necessary to prevent overflows. Insert the partial
-  // results back.
-  OpFoldResult dimSize = getAsOpFoldResult(iterationSpaceShapes[dimension]);
+  if (dimension >= iterationSpace.size())
+    return std::make_pair(op, TilingInterface());
+
+  SmallVector<OpFoldResult> offsets =
+      getAsOpFoldResult(llvm::to_vector(llvm::map_range(
+          iterationSpace, [](const Range &range) { return range.offset; })));
+  SmallVector<OpFoldResult> sizes =
+      getAsOpFoldResult(llvm::to_vector(llvm::map_range(
+          iterationSpace, [](const Range &range) { return range.size; })));
+
+  // Adjust the split point so that it doesn't overflow the size.
+  AffineExpr d0, d1, d2;
+  bindDims(rewriter.getContext(), d0, d1, d2);
   OpFoldResult minSplitPoint = makeComposedFoldedAffineMin(
-      rewriter, op->getLoc(),
-      AffineMap::getMultiDimIdentityMap(/*numDims=*/2, rewriter.getContext()),
-      {splitPoint, dimSize});
-  SmallVector<Value> splitIterationSpace =
-      llvm::to_vector(iterationSpaceShapes);
-  SmallVector<Value> originalResults = llvm::to_vector(
-      llvm::map_range(op.getOutputOperands(),
-                      [](OpOperand *opOperand) { return opOperand->get(); }));
-  SmallVector<Value> firstResults;
-  LinalgOp first = createSplitPart(rewriter, op.getLoc(), op, originalResults,
-                                   splitIterationSpace, dimension,
-                                   minSplitPoint, firstResults);
-
-  // Update the iteration space to cover the remaining part of the original
-  // space, then create another instance of the `op` in that space. The size of
-  // the remaining part may become zero, but is never negative because of the
-  // adjustment above.
-  AffineExpr d0 = rewriter.getAffineDimExpr(0);
-  AffineExpr d1 = rewriter.getAffineDimExpr(1);
+      rewriter, op.getLoc(),
+      AffineMap::inferFromExprList(ArrayRef<AffineExpr>{d0, d1 + d2}).front(),
+      {splitPoint, offsets[dimension], sizes[dimension]});
+
+  // Compute the size of the second part. Return early if the second part would
+  // have an empty iteration space.
   OpFoldResult remainingSize = makeComposedFoldedAffineApply(
-      rewriter, op.getLoc(), d0 - d1, {dimSize, minSplitPoint});
+      rewriter, op.getLoc(), d0 + d1 - d2,
+      {iterationSpace[dimension].offset, iterationSpace[dimension].size,
+       minSplitPoint});
+  if (auto attr = remainingSize.dyn_cast<Attribute>()) {
+    if (attr.cast<IntegerAttr>().getValue().isZero())
+      return {op, TilingInterface()};
+  }
+
+  // Create the first part.
+  SmallVector<Value> firstResults;
+  TilingInterface firstPart = createSplitPart(
+      rewriter, op.getLoc(), op, offsets, sizes,
+      op.getDestinationOperands(rewriter), dimension, minSplitPoint,
+      getAsOpFoldResult(iterationSpace[dimension].offset), firstResults);
+
+  // Need to pretend that the original op now takes as operands firstResults,
+  // otherwise tiling interface implementation will take the wrong value to
+  // produce data tiles.
+  rewriter.updateRootInPlace(op, [&]() {
+    unsigned numTotalOperands = op->getNumOperands();
+    unsigned numOutputOperands = firstResults.size();
+    op->setOperands(numTotalOperands - numOutputOperands, numOutputOperands,
+                    firstResults);
+  });
+
+  // Create the second part.
+  OpFoldResult totalOffset = makeComposedFoldedAffineApply(
+      rewriter, op.getLoc(), d0 + d1, {offsets[dimension], minSplitPoint});
   SmallVector<Value> secondResults;
-  ImplicitLocOpBuilder implicit(op.getLoc(), rewriter);
-  Value splitPointValue = materializeOpFoldResult(implicit, minSplitPoint);
-  LinalgOp second = createSplitPart(
-      rewriter, op.getLoc(), op, firstResults, splitIterationSpace, dimension,
-      remainingSize, secondResults, splitPointValue);
-
-  // Fixup the linalg.index results in the second part.
-  SmallVector<Value> ivAdditions;
-  ivAdditions.resize(splitIterationSpace.size());
-  ivAdditions[dimension] = splitPointValue;
-  linalg::offsetIndices(rewriter, cast<LinalgOp>(second), ivAdditions);
+  TilingInterface secondPart =
+      createSplitPart(rewriter, op.getLoc(), op, offsets, sizes, firstResults,
+                      dimension, remainingSize, totalOffset, secondResults);
 
   // Replace the original op with the results of the two newly created ops.
   rewriter.replaceOp(op, secondResults);
-  return std::make_pair(first, second);
+  return {firstPart, secondPart};
 }

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 6007ccde25442..d26eb484d0955 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -913,21 +913,6 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
   return sliceOp->getResult(0);
 }
 
-Value createSlice(OpBuilder &builder, Location loc, Value value,
-                  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
-                  ArrayRef<OpFoldResult> strides) {
-  if (value.getType().isa<MemRefType>()) {
-    return builder.create<memref::SubViewOp>(loc, value, offsets, sizes,
-                                             strides);
-  }
-
-  // This intentionally does not attempt to compose the extractslice operations.
-  assert(value.getType().isa<RankedTensorType>() &&
-         "expected a ranked tensor type");
-  return builder.create<tensor::ExtractSliceOp>(loc, value, offsets, sizes,
-                                                strides);
-}
-
 SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
                                       ValueRange ivs, ValueRange tileSizes) {
   SmallVector<Value> offsets;

diff  --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
index 2492f097ca21e..0651e76f590d0 100644
--- a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
+++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
@@ -49,18 +49,17 @@ func.func @two_d(%arg0: tensor<10x34xf32>,
 
   // CHECK:      %[[SLICE_1:.+]] = tensor.extract_slice %[[OUT]][0, 0] [4, 34] [1, 1]
   // CHECK:      scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_1:.+]] = %[[SLICE_1]])
-  // CHECK:        %[[INSLICE_1:.+]] = tensor.extract_slice %[[IN]][%[[I1]], 0] [2, 34] [1, 1]
   // CHECK:        %[[OUTSLICE_1:.+]] = tensor.extract_slice %[[ITERARG_1]][%[[I1]], 0] [2, 34] [1, 1]
 
-  // CHECK:        %[[SLICE_2:.+]] = tensor.extract_slice %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
+  // CHECK:        %[[SLICE_2:.+]] = tensor.extract_slice %[[ITERARG_1]][%[[I1]], 0] [2, 16] [1, 1]
   // CHECK:        %[[LOOPRES:.+]] = scf.for %[[I2:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_2:.+]] = %[[SLICE_2]])
-  // CHECK:          %[[INSLICE_2:.+]] = tensor.extract_slice %[[INSLICE_1]][0, %[[I2]]] [2, 8] [1, 1]
+  // CHECK:          %[[INSLICE_2:.+]] = tensor.extract_slice %[[IN]][%[[I1]], %[[I2]]] [2, 8] [1, 1]
   // CHECK:          %[[OUTSLICE_2:.+]] = tensor.extract_slice %[[ITERARG_2]][0, %[[I2]]] [2, 8] [1, 1]
   // CHECK:          %[[RESSLICE_1:.+]] = linalg.generic {{.*}} ins(%[[INSLICE_2]] : tensor<2x8xf32>) outs(%[[OUTSLICE_2]] : tensor<2x8xf32>)
   // CHECK:          %[[RESPARTIAL:.+]] = tensor.insert_slice %[[RESSLICE_1]] into %[[ITERARG_2]]
   // CHECK:          scf.yield %[[RESPARTIAL]]
 
-  // CHECK:        %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
+  // CHECK:        %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][%[[I1]], 0] [2, 16] [1, 1]
   // CHECK:        %[[OUTSLICE_3:.+]] = tensor.extract_slice %[[INSERTED]][0, 16] [2, 18] [1, 1]
   // CHECK:        scf.for %{{.*}} iter_args(%{{.*}} = %[[OUTSLICE_3]])
   // CHECK-COUNT-2:  tensor.extract_slice
@@ -74,7 +73,7 @@ func.func @two_d(%arg0: tensor<10x34xf32>,
   // CHECK:        tensor.insert_slice
   // CHECK:        tensor.extract_slice
   // CHECK:        scf.for
-  // CHECK-COUNT-3:  tensor.extract_slice
+  // CHECK-COUNT-2:  tensor.extract_slice
   // CHECK:          scf.for
   // CHECK-COUNT-2:    tensor.extract_slice
   // CHECK:            linalg.generic {{.*}} ins(%{{.*}} : tensor<3x8xf32>)

diff  --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir
index 9ff3f8002c37b..212712446ac5b 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir
@@ -1,5 +1,4 @@
 // RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file -verify-diagnostics | FileCheck %s
-// RUN: mlir-opt %s --test-transform-dialect-interpreter --canonicalize --split-input-file -verify-diagnostics | FileCheck %s --check-prefix=CANON
 
 transform.with_pdl_patterns {
 ^bb0(%arg0: !pdl.operation):
@@ -13,7 +12,6 @@ transform.with_pdl_patterns {
 func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
 
 // CHECK: #[[$ADD_42_MAP:.+]] = affine_map<(d0) -> (d0 + 42)>
-// CHECK: #[[$ADD_10_MAP:.+]] = affine_map<(d0) -> (d0 + 10)>
 
 // CHECK-LABEL: @one_d_static
 // CHECK-SAME:  %[[IN:.+]]: tensor<100xf32>, %[[OUT:.+]]: tensor<100xf32>
@@ -53,37 +51,14 @@ func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tenso
 
 // CHECK-LABEL: @one_d_static_overflow
 // CHECK-SAME:  %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
-// CANON-LABEL:  @one_d_static_overflow
-// CANON-SAME:  %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
 func.func @one_d_static_overflow(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> {
-  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [10] [1] : tensor<10xf32> to tensor<10xf32>
-  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [10] [1] : tensor<10xf32> to tensor<10xf32>
+  // Folding is sufficiently powerful to detect the static overflow and avoid
+  // the splitting altogether.
   // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
-  // CHECK:   ins(%[[IN_SLICE_LOW]]
-  // CHECK:   outs(%[[OUT_SLICE_LOW]]
+  // CHECK:   ins(%[[IN]]
+  // CHECK:   outs(%[[OUT]]
   // CHECK:   linalg.index 0
   // CHECK:   func.call @elem
-  // CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [10] [1]
-  //
-  // Due to overflow, the first part of the split computes everything and the
-  // insert/extract slices are folded away by the canonicalizer.
-  // CANON: %[[RES_PARTIAL:.+]] = linalg.generic
-  // CANON:   ins(%[[IN]]
-  // CANON:   outs(%[[OUT]]
-  // CANON:   linalg.index 0
-  // CANON:   func.call @elem
-  // The second part operates on zero-sized slices that are not currently
-  // folded away.
-  //
-  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][10] [0] [1] : tensor<10xf32> to tensor<0xf32>
-  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][10] [0] [1] : tensor<10xf32> to tensor<0xf32>
-  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
-  // CHECK:   ins(%[[IN_SLICE_HIGH]]
-  // CHECK:   outs(%[[OUT_SLICE_HIGH]]
-  // CHECK:   %[[IDX:.+]] = linalg.index 0
-  // CHECK:   affine.apply #[[$ADD_10_MAP]](%[[IDX]])
-  // CHECK:   func.call @elem
-  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[RES_PARTIAL]][10] [0] [1]
   %0 = linalg.generic {
     indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
     iterator_types = ["parallel"]
@@ -118,6 +93,7 @@ func.func private @get_size() -> index
 func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
   // CHECK: %[[SPLIT:.+]] = call @get_size
   // CHECK: %[[SPLIT_LOW:.+]] = affine.min #[[$MAP_MIN_100]]()[%[[SPLIT]]
+  // CHECK: %[[SPLIT_HIGH_1:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
   // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
   // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
   // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
@@ -125,7 +101,6 @@ func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100
   // CHECK:   outs(%[[OUT_SLICE_LOW]]
   // CHECK: %[[PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [%[[SPLIT_LOW]]] [1]
   //
-  // CHECK: %[[SPLIT_HIGH_1:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
   // CHECK: %[[SPLIT_HIGH_2:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
   // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN:.+]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_2]]] [1] : tensor<100xf32> to tensor<?xf32>
   // CHECK: %[[SPLIT_HIGH_3:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
@@ -133,7 +108,8 @@ func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100
   // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
   // CHECK:   ins(%[[IN_SLICE_HIGH]]
   // CHECK:   outs(%[[OUT_SLICE_HIGH]]
-  // CHECK: tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[PARTIAL]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1]
+  // CHECK: %[[SPLIT_HIGH_4:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
+  // CHECK: tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[PARTIAL]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_4]]] [1]
   %0 = func.call @get_size() : () -> index
   %1 = linalg.generic {
     indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
@@ -175,14 +151,16 @@ func.func @two_d(%arg0: tensor<10x34xf32>,
   //
   // CHECK:      %[[IN_2:.+]] = tensor.extract_slice %[[IN]]
   // CHECK:      %[[OUT_2:.+]] = tensor.extract_slice %[[PARTIAL_1]]
-  // CHECK:      %[[IN_21:.+]] = tensor.extract_slice %[[IN_2]]
-  // CHECK:      %[[OUT_21:.+]] = tensor.extract_slice %[[OUT_2]]
+  // Note that `extract_slice` taking a slice from another `extract_slice` result
+  // is folded to use the operand of the first `extract_slice`.
+  // CHECK:      %[[IN_21:.+]] = tensor.extract_slice %[[IN]]
+  // CHECK:      %[[OUT_21:.+]] = tensor.extract_slice %[[PARTIAL_1]]
   // CHECK:      %[[RES_21:.+]] = linalg.generic
   // CHECK-SAME:   ins(%[[IN_21]] : tensor<6x16xf32>)
   // CHECK-SAME:   outs(%[[OUT_21]] : tensor<6x16xf32>)
   // CHECK:      %[[PARTIAL_21:.+]] = tensor.insert_slice %[[RES_21]] into %[[OUT_2]]
   //
-  // CHECK:      %[[IN_22:.+]] = tensor.extract_slice %[[IN_2]]
+  // CHECK:      %[[IN_22:.+]] = tensor.extract_slice %[[IN]]
   // CHECK:      %[[OUT_22:.+]] = tensor.extract_slice %[[PARTIAL_21]]
   // CHECK:      %[[RES_22:.+]] = linalg.generic
   // CHECK-SAME:   ins(%[[IN_22]] : tensor<6x18xf32>)


        

