[Mlir-commits] [mlir] a5c802a - [mlir] fold more eagerly in structured op splitting

Alex Zinenko llvmlistbot at llvm.org
Tue Jul 12 08:06:59 PDT 2022


Author: Alex Zinenko
Date: 2022-07-12T15:06:55Z
New Revision: a5c802a429e2746c3d5190b2f3ed781911c62ed8

URL: https://github.com/llvm/llvm-project/commit/a5c802a429e2746c3d5190b2f3ed781911c62ed8
DIFF: https://github.com/llvm/llvm-project/commit/a5c802a429e2746c3d5190b2f3ed781911c62ed8.diff

LOG: [mlir] fold more eagerly in structured op splitting

The existing implementation of structured op splitting creates several
affine.apply and affine.min operations in its subshape computation.
As these shapes are further used in data slice extraction, this may lead
to slice shapes being dynamic even when the original shapes and the
splitting point are static. This is particularly visible when splitting
is combined with further subsetting transformations such as tiling. Use
composition and folding more aggressively in splitting to avoid this.

In particular, introduce a `makeComposedAffineMin` function that composes
the affine map used in "min" with the maps used by any `affine.apply` that
may be feeding the operands to the "min". This enables production of
more static shapes. Also introduce a `makeComposedFoldedAffineApply`
function (and its counterpart `makeComposedFoldedAffineMin`) that combines
the existing `makeComposedAffineApply` with in-place folding to propagate
constants produced by zero-input affine maps. Using these when splitting
allows the subsequent canonicalizer pass to recover static shapes for
structured ops.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D129379

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Split.cpp
    mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
    mlir/test/Dialect/Linalg/transform-op-split.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 6c1d1fef4ee50..a48c48a4a91cf 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -25,7 +25,7 @@ namespace mlir {
 class AffineApplyOp;
 class AffineBound;
 class AffineValueMap;
-class IRRewriter;
+class RewriterBase;
 
 /// TODO: These should be renamed if they are on the mlir namespace.
 ///       Ideally, they should go in a mlir::affine:: namespace.
@@ -381,13 +381,37 @@ AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
 AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
                                       ValueRange values);
 
+/// Constructs an AffineApplyOp that applies `map` to `operands` after composing
+/// the map with the maps of any other AffineApplyOp supplying the operands,
+/// then immediately attempts to fold it. If folding results in a constant
+/// value, erases all created ops. The `map` must be a single-result affine map.
+OpFoldResult makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
+                                           AffineMap map,
+                                           ArrayRef<OpFoldResult> operands);
+/// Variant of `makeComposedFoldedAffineApply` that applies to an expression.
+OpFoldResult makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
+                                           AffineExpr expr,
+                                           ArrayRef<OpFoldResult> operands);
+
+/// Returns an AffineMinOp obtained by composing `map` and `operands` with
+/// AffineApplyOps supplying those operands.
+Value makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
+                            ValueRange operands);
+
+/// Constructs an AffineMinOp that computes a minimum across the results of
+/// applying `map` to `operands`, then immediately attempts to fold it. If
+/// folding results in a constant value, erases all created ops.
+OpFoldResult makeComposedFoldedAffineMin(RewriterBase &b, Location loc,
+                                         AffineMap map,
+                                         ArrayRef<OpFoldResult> operands);
+
 /// Returns the values obtained by applying `map` to the list of values.
 SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
                                        AffineMap map, ValueRange values);
 
 /// Returns the values obtained by applying `map` to the list of values, which
 /// may be known constants.
-SmallVector<OpFoldResult> applyMapToValues(IRRewriter &b, Location loc,
+SmallVector<OpFoldResult> applyMapToValues(RewriterBase &b, Location loc,
                                            AffineMap map,
                                            ArrayRef<OpFoldResult> values);
 

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 41d66a3aecbd7..e674e8b2585f5 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -14,6 +14,7 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -588,7 +589,7 @@ OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
 /// AffineSymbolExpr@[pos - dims.size()] is replaced.
 /// Mutate `map`,`dims` and `syms` in place as follows:
 ///   1. `dims` and `syms` are only appended to.
-///   2. `map` dim and symbols are gradually shifted to higer positions.
+///   2. `map` dim and symbols are gradually shifted to higher positions.
 ///   3. Old `dim` and `sym` entries are replaced by nullptr
 /// This avoids the need for any bookkeeping.
 static LogicalResult replaceDimOrSym(AffineMap *map,
@@ -705,6 +706,68 @@ void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
   }
 }
 
+/// Given a list of `OpFoldResult`, build the necessary operations to populate
+/// `actualValues` with values produced by operations. In particular, for any
+/// attribute-typed element in `values`, call the constant materializer
+/// associated with the Affine dialect to produce an operation.
+static void materializeConstants(OpBuilder &b, Location loc,
+                                 ArrayRef<OpFoldResult> values,
+                                 SmallVectorImpl<Operation *> &constants,
+                                 SmallVectorImpl<Value> &actualValues) {
+  actualValues.reserve(values.size());
+  auto *dialect = b.getContext()->getLoadedDialect<AffineDialect>();
+  for (OpFoldResult ofr : values) {
+    if (auto value = ofr.dyn_cast<Value>()) {
+      actualValues.push_back(value);
+      continue;
+    }
+    constants.push_back(dialect->materializeConstant(b, ofr.get<Attribute>(),
+                                                     b.getIndexType(), loc));
+    actualValues.push_back(constants.back()->getResult(0));
+  }
+}
+
+/// Create an operation of the type provided as template argument and attempt to
+/// fold it immediately. The operation is expected to have a builder taking
+/// arbitrary `leadingArguments`, followed by a list of Value-typed `operands`.
+/// The operation is also expected to always produce a single result. Return an
+/// `OpFoldResult` containing the Attribute representing the folded constant if
+/// complete folding was possible and a Value produced by the created operation
+/// otherwise.
+template <typename OpTy, typename... Args>
+static std::enable_if_t<OpTy::template hasTrait<OpTrait::OneResult>(),
+                        OpFoldResult>
+createOrFold(RewriterBase &b, Location loc, ValueRange operands,
+             Args &&...leadingArguments) {
+  // Identify the constant operands and extract their values as attributes.
+  // Note that we cannot use the original values directly because the list of
+  // operands may have changed due to canonicalization and composition.
+  SmallVector<Attribute> constantOperands;
+  constantOperands.reserve(operands.size());
+  for (Value operand : operands) {
+    IntegerAttr attr;
+    if (matchPattern(operand, m_Constant(&attr)))
+      constantOperands.push_back(attr);
+    else
+      constantOperands.push_back(nullptr);
+  }
+
+  // Create the operation and immediately attempt to fold it. On success,
+  // delete the operation and prepare the (unmaterialized) value for being
+  // returned. On failure, return the operation result value.
+  // TODO: arguably, the main folder (createOrFold) API should support this use
+  // case instead of indiscriminately materializing constants.
+  OpTy op =
+      b.create<OpTy>(loc, std::forward<Args>(leadingArguments)..., operands);
+  SmallVector<OpFoldResult, 1> foldResults;
+  if (succeeded(op->fold(constantOperands, foldResults)) &&
+      !foldResults.empty()) {
+    b.eraseOp(op);
+    return foldResults.front();
+  }
+  return op->getResult(0);
+}
+
 AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
                                             AffineMap map,
                                             ValueRange operands) {
@@ -722,6 +785,86 @@ AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
       values);
 }
 
+OpFoldResult
+mlir::makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
+                                    AffineMap map,
+                                    ArrayRef<OpFoldResult> operands) {
+  assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
+
+  SmallVector<Operation *> constants;
+  SmallVector<Value> actualValues;
+  materializeConstants(b, loc, operands, constants, actualValues);
+  composeAffineMapAndOperands(&map, &actualValues);
+  OpFoldResult result = createOrFold<AffineApplyOp>(b, loc, actualValues, map);
+  if (result.is<Attribute>()) {
+    for (Operation *op : constants)
+      b.eraseOp(op);
+  }
+  return result;
+}
+
+OpFoldResult
+mlir::makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
+                                    AffineExpr expr,
+                                    ArrayRef<OpFoldResult> operands) {
+  return makeComposedFoldedAffineApply(
+      b, loc, AffineMap::inferFromExprList(ArrayRef<AffineExpr>{expr}).front(),
+      operands);
+}
+
+/// Composes the given affine map with the given list of operands, pulling in
+/// the maps from any affine.apply operations that supply the operands.
+static void composeMultiResultAffineMap(AffineMap &map,
+                                        SmallVectorImpl<Value> &operands) {
+  // Compose and canonicalize each expression in the map individually because
+  // composition only applies to single-result maps, collecting potentially
+  // duplicate operands in a single list with shifted dimensions and symbols.
+  SmallVector<Value> dims, symbols;
+  SmallVector<AffineExpr> exprs;
+  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
+    SmallVector<Value> submapOperands(operands.begin(), operands.end());
+    AffineMap submap = map.getSubMap({i});
+    fullyComposeAffineMapAndOperands(&submap, &submapOperands);
+    canonicalizeMapAndOperands(&submap, &submapOperands);
+    unsigned numNewDims = submap.getNumDims();
+    submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
+    llvm::append_range(dims,
+                       ArrayRef<Value>(submapOperands).take_front(numNewDims));
+    llvm::append_range(symbols,
+                       ArrayRef<Value>(submapOperands).drop_front(numNewDims));
+    exprs.push_back(submap.getResult(0));
+  }
+
+  // Canonicalize the map created from composed expressions to deduplicate the
+  // dimension and symbol operands.
+  operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
+  map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
+  canonicalizeMapAndOperands(&map, &operands);
+}
+
+Value mlir::makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
+                                  ValueRange operands) {
+  SmallVector<Value> allOperands = llvm::to_vector(operands);
+  composeMultiResultAffineMap(map, allOperands);
+  return b.createOrFold<AffineMinOp>(loc, b.getIndexType(), map, allOperands);
+}
+
+OpFoldResult
+mlir::makeComposedFoldedAffineMin(RewriterBase &b, Location loc, AffineMap map,
+                                  ArrayRef<OpFoldResult> operands) {
+  SmallVector<Operation *> constants;
+  SmallVector<Value> actualValues;
+  materializeConstants(b, loc, operands, constants, actualValues);
+  composeMultiResultAffineMap(map, actualValues);
+  OpFoldResult result =
+      createOrFold<AffineMinOp>(b, loc, actualValues, b.getIndexType(), map);
+  if (result.is<Attribute>()) {
+    for (Operation *op : constants)
+      b.eraseOp(op);
+  }
+  return result;
+}
+
 /// Fully compose map with operands and canonicalize the result.
 /// Return the `createOrFold`'ed AffineApply op.
 static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
@@ -749,23 +892,13 @@ SmallVector<Value, 4> mlir::applyMapToValues(OpBuilder &b, Location loc,
 }
 
 SmallVector<OpFoldResult>
-mlir::applyMapToValues(IRRewriter &b, Location loc, AffineMap map,
+mlir::applyMapToValues(RewriterBase &b, Location loc, AffineMap map,
                        ArrayRef<OpFoldResult> values) {
   // Materialize constants and keep track of produced operations so we can clean
   // them up later.
   SmallVector<Operation *> constants;
   SmallVector<Value> actualValues;
-  actualValues.reserve(values.size());
-  auto *dialect = b.getContext()->getLoadedDialect<AffineDialect>();
-  for (OpFoldResult ofr : values) {
-    if (auto value = ofr.dyn_cast<Value>()) {
-      actualValues.push_back(value);
-      continue;
-    }
-    constants.push_back(dialect->materializeConstant(b, ofr.get<Attribute>(),
-                                                     b.getIndexType(), loc));
-    actualValues.push_back(constants.back()->getResult(0));
-  }
+  materializeConstants(b, loc, values, constants, actualValues);
 
   // Compose, fold and construct maps for each result independently because they
   // may simplify more effectively.
@@ -777,35 +910,9 @@ mlir::applyMapToValues(IRRewriter &b, Location loc, AffineMap map,
     SmallVector<Value> operands = actualValues;
     fullyComposeAffineMapAndOperands(&submap, &operands);
     canonicalizeMapAndOperands(&submap, &operands);
-
-    // Identify the constant operands and extract their values as attributes.
-    // Note that we cannot use the original values directly because the list of
-    // operands may have changed due to canonicalization and composition.
-    SmallVector<Attribute> constantOperands;
-    constantOperands.reserve(operands.size());
-    for (Value operand : operands) {
-      IntegerAttr attr;
-      if (matchPattern(operand, m_Constant(&attr)))
-        constantOperands.push_back(attr);
-      else
-        constantOperands.push_back(nullptr);
-    }
-
-    // Create an apply operation and immediately attempt to fold it. On sucess,
-    // delete the operation and prepare the (unmaterialized) value for being
-    // returned. On failure, return the function result.
-    // TODO: arguably, the main folder (createOrFold) API should support this
-    // use case instead of indiscriminately materializing constants.
-    auto apply = b.create<AffineApplyOp>(loc, submap, operands);
-    SmallVector<OpFoldResult, 1> foldResult;
-    if (succeeded(apply->fold(constantOperands, foldResult))) {
-      assert(foldResult.size() == 1 && "expected single-result map");
-      b.eraseOp(apply);
-      results.push_back(foldResult.front());
-    } else {
-      results.push_back(apply.getResult());
+    results.push_back(createOrFold<AffineApplyOp>(b, loc, operands, submap));
+    if (!results.back().is<Attribute>())
       foldedAll = false;
-    }
   }
 
   // If the entire map could be folded, remove the constants that were used in

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 18ff6c45e2caf..3f17942a355eb 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -294,11 +294,11 @@ DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
     return emitSilenceableError() << "could not generate tile size computation";
   }
 
+  AffineExpr s0 = builder.getAffineSymbolExpr(0);
+  AffineExpr s1 = builder.getAffineSymbolExpr(1);
   Operation *splitPoint =
-      builder
-          .createOrFold<arith::MulIOp>(target.getLoc(), spec->lowTileSize,
-                                       spec->lowTripCount)
-          .getDefiningOp();
+      makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
+                              {spec->lowTileSize, spec->lowTripCount});
   Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
   Operation *highTileSize = spec->highTileSize.getDefiningOp();
   assert(lowTileSize && highTileSize && splitPoint &&

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
index 875c844ea7937..8849e7f964b3a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 
 #include "llvm/ADT/STLExtras.h"
 
@@ -24,7 +25,7 @@ using namespace mlir::linalg;
 /// generated; otherwise, the "low" part with no offset is generated. Note that
 /// `operands` are not necessarily the actual operands of `op`.
 static SmallVector<Value>
-getOperandSlices(ImplicitLocOpBuilder &builder, LinalgOp op,
+getOperandSlices(RewriterBase &b, Location loc, LinalgOp op,
                  ValueRange splitIterationSpace, ValueRange operands,
                  unsigned dimension, Value offset = nullptr) {
   SmallVector<Value> slices;
@@ -42,20 +43,24 @@ getOperandSlices(ImplicitLocOpBuilder &builder, LinalgOp op,
       continue;
     }
 
-    SmallVector<Value, 4> sizes =
-        applyMapToValues(builder, op.getLoc(), indexing, splitIterationSpace);
-    SmallVector<OpFoldResult> offsets(type.getRank(), builder.getIndexAttr(0));
-    SmallVector<OpFoldResult> strides(type.getRank(), builder.getIndexAttr(1));
+    SmallVector<OpFoldResult> sizes;
+    sizes.reserve(indexing.getNumResults());
+    for (AffineExpr dimIndexing : indexing.getResults()) {
+      sizes.push_back(makeComposedFoldedAffineApply(
+          b, loc, dimIndexing,
+          getAsOpFoldResult(llvm::to_vector(splitIterationSpace))));
+    }
+    SmallVector<OpFoldResult> offsets(type.getRank(), b.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(type.getRank(), b.getIndexAttr(1));
 
     if (offset) {
       offsets[dimension] = offset;
-      IRRewriter rewriter(builder);
-      offsets = applyMapToValues(rewriter, builder.getLoc(), indexing, offsets);
+      offsets = applyMapToValues(b, loc, indexing, offsets);
     }
 
-    slices.push_back(createSlice(builder, op.getLoc(),
+    slices.push_back(createSlice(b, loc,
                                  operands[opOperand->getOperandNumber()],
-                                 offsets, getAsOpFoldResult(sizes), strides));
+                                 offsets, sizes, strides));
   }
 
   return slices;
@@ -69,21 +74,23 @@ getOperandSlices(ImplicitLocOpBuilder &builder, LinalgOp op,
 /// original op and updates it to be the iteration space of the curent part.
 /// Returns the split-out op as well as the output operand values updated with
 /// the partial results produced by this op through `results`.
-static LinalgOp createSplitPart(
-    ImplicitLocOpBuilder &builder, LinalgOp op, ValueRange resultOperands,
-    llvm::MutableArrayRef<Value> splitIterationSpace, unsigned dimension,
-    Value size, SmallVectorImpl<Value> &results, Value offset = nullptr) {
-  splitIterationSpace[dimension] = size;
+static LinalgOp
+createSplitPart(RewriterBase &b, Location loc, LinalgOp op,
+                ValueRange resultOperands,
+                llvm::MutableArrayRef<Value> splitIterationSpace,
+                unsigned dimension, OpFoldResult size,
+                SmallVectorImpl<Value> &results, Value offset = nullptr) {
+  ImplicitLocOpBuilder implicit(op.getLoc(), b);
+  splitIterationSpace[dimension] = materializeOpFoldResult(implicit, size);
   SmallVector<Value> operands = llvm::to_vector(
       llvm::map_range(op.getInputOperands(),
                       [](OpOperand *opOperand) { return opOperand->get(); }));
   llvm::append_range(operands, resultOperands);
-  operands = getOperandSlices(builder, op, splitIterationSpace, operands,
+  operands = getOperandSlices(b, loc, op, splitIterationSpace, operands,
                               dimension, offset);
-  Operation *part = op.clone(builder, op.getLoc(),
-                             getTensorOutputTypes(op, operands), operands);
-  results = insertSlicesBack(builder, builder.getLoc(), op, operands,
-                             part->getResults());
+  Operation *part =
+      op.clone(b, loc, getTensorOutputTypes(op, operands), operands);
+  results = insertSlicesBack(b, loc, op, operands, part->getResults());
   return cast<LinalgOp>(part);
 }
 
@@ -95,45 +102,45 @@ std::pair<LinalgOp, LinalgOp> linalg::splitOp(RewriterBase &rewriter,
     return std::make_pair(op, LinalgOp());
 
   // Compute the iteration space size as values.
-  ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
   SmallVector<Value, 4> allShapes =
-      op.createFlatListOfOperandDims(builder, op.getLoc());
+      op.createFlatListOfOperandDims(rewriter, op.getLoc());
   AffineMap shapesToLoops = op.getShapesToLoopsMap();
   SmallVector<Value, 4> iterationSpaceShapes =
-      applyMapToValues(builder, op.getLoc(), shapesToLoops, allShapes);
+      applyMapToValues(rewriter, op.getLoc(), shapesToLoops, allShapes);
 
   // Update the iteration space to have `splitPoint` as the size of `dimension`
   // and use it to slice operands and results for a new, smaller instance of the
   // `op`. Adjust the size if necessary to prevent overflows. Insert the partial
   // results back.
-  Value splitPointValue = materializeOpFoldResult(builder, splitPoint);
-  splitPointValue = builder.createOrFold<AffineMinOp>(
-      builder.getIndexType(),
-      AffineMap::getMultiDimIdentityMap(/*numDims=*/2, builder.getContext()),
-      ValueRange({splitPointValue, iterationSpaceShapes[dimension]}));
+  OpFoldResult dimSize = getAsOpFoldResult(iterationSpaceShapes[dimension]);
+  OpFoldResult minSplitPoint = makeComposedFoldedAffineMin(
+      rewriter, op->getLoc(),
+      AffineMap::getMultiDimIdentityMap(/*numDims=*/2, rewriter.getContext()),
+      {splitPoint, dimSize});
   SmallVector<Value> splitIterationSpace =
       llvm::to_vector(iterationSpaceShapes);
   SmallVector<Value> originalResults = llvm::to_vector(
       llvm::map_range(op.getOutputOperands(),
                       [](OpOperand *opOperand) { return opOperand->get(); }));
   SmallVector<Value> firstResults;
-  LinalgOp first =
-      createSplitPart(builder, op, originalResults, splitIterationSpace,
-                      dimension, splitPointValue, firstResults);
+  LinalgOp first = createSplitPart(rewriter, op.getLoc(), op, originalResults,
+                                   splitIterationSpace, dimension,
+                                   minSplitPoint, firstResults);
 
   // Update the iteration space to cover the remaining part of the original
   // space, then create another instance of the `op` in that space. The size of
   // the remaining part may become zero, but is never negative because of the
   // adjustment above.
-  AffineExpr d0 = builder.getAffineDimExpr(0);
-  AffineExpr d1 = builder.getAffineDimExpr(1);
-  SmallVector<Value, 4> remainingSizes = applyMapToValues(
-      builder, op.getLoc(), AffineMap::inferFromExprList({d0 - d1}).front(),
-      {iterationSpaceShapes[dimension], splitPointValue});
+  AffineExpr d0 = rewriter.getAffineDimExpr(0);
+  AffineExpr d1 = rewriter.getAffineDimExpr(1);
+  OpFoldResult remainingSize = makeComposedFoldedAffineApply(
+      rewriter, op.getLoc(), d0 - d1, {dimSize, minSplitPoint});
   SmallVector<Value> secondResults;
-  LinalgOp second =
-      createSplitPart(builder, op, firstResults, splitIterationSpace, dimension,
-                      remainingSizes.front(), secondResults, splitPointValue);
+  ImplicitLocOpBuilder implicit(op.getLoc(), rewriter);
+  Value splitPointValue = materializeOpFoldResult(implicit, minSplitPoint);
+  LinalgOp second = createSplitPart(
+      rewriter, op.getLoc(), op, firstResults, splitIterationSpace, dimension,
+      remainingSize, secondResults, splitPointValue);
 
   // Fixup the linalg.index results in the second part.
   SmallVector<Value> ivAdditions;

diff  --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
index e30a140535fcd..f606c93ef265a 100644
--- a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
+++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
@@ -28,9 +28,6 @@ transform.with_pdl_patterns {
 
 func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
 
-// CHECK-DAG: #[[$MAP_MIN_4_2:.+]] = affine_map<(d0) -> (-d0 + 4, 2)>
-// CHECK-DAG: #[[$MAP_MIN_16_8:.+]] = affine_map<(d0) -> (-d0 + 16, 8)>
-
 // CHECK-LABEL: @two_d
 // CHECK-SAME: %[[IN:.+]]: tensor<10x34xf32>, %[[OUT:.+]]: tensor<10x34xf32>
 func.func @two_d(%arg0: tensor<10x34xf32>,
@@ -54,35 +51,27 @@ func.func @two_d(%arg0: tensor<10x34xf32>,
   // respectively, and in this order.
   // Check the full code for the first quadrant, the data flow for the second
   // quadrant and only the overall code structure for the remaining quadrants.
-  //
-  // TODO: unfortunately, the canonicalization is insufficiently powerful to
-  // remove the affine min for sizes, leading to dynamic sizes even when tiling
-  // statically-shaped operation with constant tile sizes.
+  // The canonicalizer is able to recover static shapes of for linalg.generic
+  // instances, use those to differentiate the quadrants.
 
   // CHECK:      %[[SLICE_1:.+]] = tensor.extract_slice %[[OUT]][0, 0] [4, 34] [1, 1]
   // CHECK:      scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_1:.+]] = %[[SLICE_1]])
-  // CHECK:        %[[SZ1:.+]] = affine.min #[[$MAP_MIN_4_2]](%[[I1]])
-  // CHECK:        %[[INSLICE_1:.+]] = tensor.extract_slice %[[IN]][%[[I1]], 0] [%[[SZ1]], 34] [1, 1]
-  // CHECK:        %[[SZ2:.+]] = affine.min #[[$MAP_MIN_4_2]](%[[I1]])
-  // CHECK:        %[[OUTSLICE_1:.+]] = tensor.extract_slice %[[ITERARG_1]][%[[I1]], 0] [%[[SZ2]], 34] [1, 1]
+  // CHECK:        %[[INSLICE_1:.+]] = tensor.extract_slice %[[IN]][%[[I1]], 0] [2, 34] [1, 1]
+  // CHECK:        %[[OUTSLICE_1:.+]] = tensor.extract_slice %[[ITERARG_1]][%[[I1]], 0] [2, 34] [1, 1]
 
-  // CHECK:        %[[SLICE_2:.+]] = tensor.extract_slice %[[OUTSLICE_1]][0, 0] [%[[SZ1]], 16] [1, 1]
+  // CHECK:        %[[SLICE_2:.+]] = tensor.extract_slice %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
   // CHECK:        %[[LOOPRES:.+]] = scf.for %[[I2:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_2:.+]] = %[[SLICE_2]])
-  // CHECK:          %[[SZ3:.+]] = affine.min #[[$MAP_MIN_16_8]](%[[I2]])
-  // CHECK:          %[[INSLICE_2:.+]] = tensor.extract_slice %[[INSLICE_1]][0, %[[I2]]] [%[[SZ1]], %[[SZ3]]] [1, 1]
-  // CHECK:          %[[SZ4:.+]] = tensor.dim %[[ITERARG_2]]
-  // CHECK:          %[[SZ5:.+]] = affine.min #[[$MAP_MIN_16_8]](%[[I2]])
-  // CHECK:          %[[OUTSLICE_2:.+]] = tensor.extract_slice %[[ITERARG_2]][0, %[[I2]]] [%[[SZ4]], %[[SZ5]]] [1, 1]
-
-  // CHECK:          %[[RESSLICE_1:.+]] = linalg.generic {{.*}} ins(%[[INSLICE_2]] : tensor<?x?xf32>) outs(%[[OUTSLICE_2]] : tensor<?x?xf32>)
+  // CHECK:          %[[INSLICE_2:.+]] = tensor.extract_slice %[[INSLICE_1]][0, %[[I2]]] [2, 8] [1, 1]
+  // CHECK:          %[[OUTSLICE_2:.+]] = tensor.extract_slice %[[ITERARG_2]][0, %[[I2]]] [2, 8] [1, 1]
+  // CHECK:          %[[RESSLICE_1:.+]] = linalg.generic {{.*}} ins(%[[INSLICE_2]] : tensor<2x8xf32>) outs(%[[OUTSLICE_2]] : tensor<2x8xf32>)
   // CHECK:          %[[RESPARTIAL:.+]] = tensor.insert_slice %[[RESSLICE_1]] into %[[ITERARG_2]]
   // CHECK:          scf.yield %[[RESPARTIAL]]
 
-  // CHECK:        %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [%[[SZ1]], 16] [1, 1]
-  // CHECK:        %[[OUTSLICE_3:.+]] = tensor.extract_slice %[[INSERTED]][0, 16] [%[[SZ1]], 18] [1, 1]
+  // CHECK:        %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
+  // CHECK:        %[[OUTSLICE_3:.+]] = tensor.extract_slice %[[INSERTED]][0, 16] [2, 18] [1, 1]
   // CHECK:        scf.for %{{.*}} iter_args(%{{.*}} = %[[OUTSLICE_3]])
   // CHECK-COUNT-2:  tensor.extract_slice
-  // CHECK:          linalg.generic
+  // CHECK:          linalg.generic {{.*}} ins(%{{.*}} : tensor<2x9xf32>)
   // CHECK:          tensor.insert_slice
   // CHECK:          scf.yield
   // CHECK:        %[[INSERTED_2:.+]] = tensor.insert_slice %{{.*}} into %[[INSERTED]]
@@ -95,14 +84,14 @@ func.func @two_d(%arg0: tensor<10x34xf32>,
   // CHECK-COUNT-3:  tensor.extract_slice
   // CHECK:          scf.for
   // CHECK-COUNT-2:    tensor.extract_slice
-  // CHECK:            linalg.generic
+  // CHECK:            linalg.generic {{.*}} ins(%{{.*}} : tensor<3x8xf32>)
   // CHECK:            tensor.insert_slice
   // CHECK:            scf.yield
   // CHECK:          tensor.insert_slice
   // CHECK:          tensor.extract_slice
   // CHECK:          scf.for
   // CHECK-COUNT-2:    tensor.extract_slice
-  // CHECK:            linalg.generic
+  // CHECK:            linalg.generic {{.*}} ins(%{{.*}} : tensor<3x9xf32>)
   // CHECK:            tensor.insert_slice
   // CHECK:            scf.yield
   // CHECK-COUNT-2:  tensor.insert_slice

diff  --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir
index 2eef84c82b4dd..9f896a4147933 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --canonicalize --split-input-file -verify-diagnostics | FileCheck %s --check-prefix=CANON
 
 transform.with_pdl_patterns {
 ^bb0(%arg0: !pdl.operation):
@@ -59,6 +60,8 @@ func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tenso
 
 // CHECK-LABEL: @one_d_static_overflow
 // CHECK-SAME:  %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
+// CANON-LABEL:  @one_d_static_overflow
+// CANON-SAME:  %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
 func.func @one_d_static_overflow(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> {
   // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [10] [1] : tensor<10xf32> to tensor<10xf32>
   // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [10] [1] : tensor<10xf32> to tensor<10xf32>
@@ -69,6 +72,16 @@ func.func @one_d_static_overflow(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -
   // CHECK:   func.call @elem
   // CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [10] [1]
   //
+  // Due to overflow, the first part of the split computes everything and the
+  // insert/extract slices are folded away by the canonicalizer.
+  // CANON: %[[RES_PARTIAL:.+]] = linalg.generic
+  // CANON:   ins(%[[IN]]
+  // CANON:   outs(%[[OUT]]
+  // CANON:   linalg.index 0
+  // CANON:   func.call @elem
+  // The second part operates on zero-sized slices that are not currently
+  // folded away.
+  //
   // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][10] [0] [1] : tensor<10xf32> to tensor<0xf32>
   // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][10] [0] [1] : tensor<10xf32> to tensor<0xf32>
   // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
@@ -118,13 +131,13 @@ transform.with_pdl_patterns {
 
 func.func private @get_size() -> index
 
-// CHECK: #[[$MAP_MIN_100:.+]] = affine_map<(d0, d1) -> (d0, 100)>
+// CHECK: #[[$MAP_MIN_100:.+]] = affine_map<()[s0] -> (s0, 100)>
 // CHECK: #[[$MAP_S_MINUS_100:.+]] = affine_map<()[s0] -> (-s0 + 100)>
 
 // CHECK-LABEL: @dynamic
 func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
   // CHECK: %[[SPLIT:.+]] = call @get_size
-  // CHECK: %[[SPLIT_LOW:.+]] = affine.min #[[$MAP_MIN_100]](%[[SPLIT]]
+  // CHECK: %[[SPLIT_LOW:.+]] = affine.min #[[$MAP_MIN_100]]()[%[[SPLIT]]
   // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
   // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
   // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
@@ -148,7 +161,8 @@ func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100
   }
   ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
   ^bb0(%3: f32, %4: f32):
-    linalg.yield %3 : f32
+    %5 = arith.addf %3, %4 : f32
+    linalg.yield %5 : f32
   } -> tensor<100xf32>
   return %1 : tensor<100xf32>
 }


        


More information about the Mlir-commits mailing list