[Mlir-commits] [mlir] 01bdb0f - [mlir][linalg] Improve implementation of hoist padding.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Jul 15 05:10:37 PDT 2021


Author: Nicolas Vasilache
Date: 2021-07-15T12:10:31Z
New Revision: 01bdb0f75efb2bb795a79cea9f3f918136d13a7f

URL: https://github.com/llvm/llvm-project/commit/01bdb0f75efb2bb795a79cea9f3f918136d13a7f
DIFF: https://github.com/llvm/llvm-project/commit/01bdb0f75efb2bb795a79cea9f3f918136d13a7f.diff

LOG: [mlir][linalg] Improve implementation of hoist padding.

Instead of relying on ad hoc bounds calculations, use a projection-based
implementation. This simplifies the code and finds more static constant
sizes than the previous approach.
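
To make the projection idea concrete, here is a minimal editorial sketch
(not part of the patch) using the FlatAffineConstraints API this change
builds on: loop bounds such as `0 <= i < min(8, N - j)` become inequality
rows, values that are not loop-independent get projected out, and constant
extents can then be read off the remaining system.

  #include "mlir/Analysis/AffineStructures.h"
  using namespace mlir;

  void projectionSketch() {
    // Columns: [i, j, N, const] (2 dims, 1 symbol).
    FlatAffineConstraints cst(/*numDims=*/2, /*numSymbols=*/1);
    cst.addInequality({1, 0, 0, 0});    // i >= 0
    cst.addInequality({-1, 0, 0, 7});   // i <= 7 (the `8` term, exclusive)
    cst.addInequality({-1, -1, 1, -1}); // i <= N - j - 1
    // Eliminate j; the surviving rows bound i independently of the loop.
    cst.projectOut(/*pos=*/1, /*num=*/1);
    Optional<int64_t> ub = cst.getConstantUpperBound(/*pos=*/0); // 7
    (void)ub;
  }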

Differential Revision: https://reviews.llvm.org/D106054

Added: 
    mlir/include/mlir/Dialect/Linalg/Analysis/ConstraintsSet.h
    mlir/lib/Dialect/Linalg/Analysis/ConstraintsSet.cpp

Modified: 
    mlir/lib/Dialect/Linalg/Analysis/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
    mlir/test/Dialect/Linalg/hoist-padding.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Analysis/ConstraintsSet.h b/mlir/include/mlir/Dialect/Linalg/Analysis/ConstraintsSet.h
new file mode 100644
index 000000000000..c1ed91123884
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/Analysis/ConstraintsSet.h
@@ -0,0 +1,67 @@
+//===- ConstraintsSet.h - Extensions for FlatAffineConstraints --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Linalg-specific constraints set extensions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_ANALYSIS_CONSTRAINTS_SET_H_
+#define MLIR_DIALECT_LINALG_ANALYSIS_CONSTRAINTS_SET_H_
+
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/IR/AffineMap.h"
+
+namespace mlir {
+class ValueRange;
+
+/// Linalg-specific constraints set extensions.
+class ConstraintsSet : public FlatAffineConstraints {
+public:
+  ConstraintsSet() : FlatAffineConstraints() {}
+
+  /// Assuming `val` is defined by `val = affine.min map (operands)`, introduce
+  /// all the constraints `val <= expr_i(operands)`, where expr_i are all the
+  /// results of `map`.
+  // This API avoids taking a dependence on the AffineMinOp definition.
+  LogicalResult composeMin(Value val, AffineMap map, ValueRange operands) {
+    return composeMinOrMaxMapAndOperands(val, map, operands, /*min=*/true);
+  }
+
+  /// Assuming `val` is defined by `val = affine.max map (operands)`, introduce
+  /// all the constraints `val >= expr_i(operands)`, where expr_i are all the
+  /// results of `map`.
+  // This API avoids taking a dependence on the AffineMaxOp definition.
+  LogicalResult composeMax(Value val, AffineMap map, ValueRange operands) {
+    return composeMinOrMaxMapAndOperands(val, map, operands, /*min=*/false);
+  }
+
+  /// Assuming `val` is defined by `val = affine.apply map (operands)`, call
+  /// composeMap.
+  // This API avoids taking a dependence on the AffineApplyOp definition.
+  LogicalResult composeAffineApply(Value val, AffineMap map,
+                                   ValueRange operands);
+
+  /// Asserts the identifier `id` is in the constraints set and returns it.
+  unsigned lookupPos(Value id) const;
+
+  /// If v is not in the constraint set, insert it as a dim or symbol depending
+  /// on `asDim`.
+  /// Return success if v is of dim id type when `asDim` is true and of symbol
+  /// id type when `asDim` is false.
+  /// Return failure otherwise.
+  LogicalResult ensureIdOfType(Value v, bool asDim);
+
+private:
+  /// Implementation detail for composeMin/Max.
+  LogicalResult composeMinOrMaxMapAndOperands(Value val, AffineMap map,
+                                              ValueRange operands, bool min);
+};
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_ANALYSIS_CONSTRAINTS_SET_H_
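
As a usage illustration (an editorial sketch; `constantUpperBound` is a
hypothetical helper, and the operands of `minOp` are assumed to be valid
symbol ids), composing one affine.min and querying a constant bound on its
result could look like:

  #include "mlir/Dialect/Affine/IR/AffineOps.h"
  #include "mlir/Dialect/Linalg/Analysis/ConstraintsSet.h"
  using namespace mlir;

  static Optional<int64_t> constantUpperBound(AffineMinOp minOp) {
    ConstraintsSet cst;
    // Register the result as a dim id and the operands as symbol ids.
    if (failed(cst.ensureIdOfType(minOp.getResult(), /*asDim=*/true)))
      return llvm::None;
    for (Value operand : minOp.operands())
      if (failed(cst.ensureIdOfType(operand, /*asDim=*/false)))
        return llvm::None;
    // Add one upper-bound row per result of the min map.
    if (failed(cst.composeMin(minOp.getResult(), minOp.getAffineMap(),
                              minOp.operands())))
      return llvm::None;
    return cst.getConstantUpperBound(cst.lookupPos(minOp.getResult()));
  }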

diff  --git a/mlir/lib/Dialect/Linalg/Analysis/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Analysis/CMakeLists.txt
index 14a161b13325..966d6a7bc956 100644
--- a/mlir/lib/Dialect/Linalg/Analysis/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Analysis/CMakeLists.txt
@@ -1,12 +1,15 @@
 add_mlir_dialect_library(MLIRLinalgAnalysis
+  ConstraintsSet.cpp
   DependenceAnalysis.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
 
   LINK_LIBS PUBLIC
+  MLIRAnalysis
   MLIRIR
   MLIRLinalg
+  MLIRLoopAnalysis
   MLIRMemRef
   MLIRStandard
   )

diff  --git a/mlir/lib/Dialect/Linalg/Analysis/ConstraintsSet.cpp b/mlir/lib/Dialect/Linalg/Analysis/ConstraintsSet.cpp
new file mode 100644
index 000000000000..7e37da5546a7
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Analysis/ConstraintsSet.cpp
@@ -0,0 +1,87 @@
+//===- ConstraintsSet.cpp - Extensions for FlatAffineConstraints ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Linalg-specific constraints set extensions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Analysis/ConstraintsSet.h"
+#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
+#include "mlir/IR/AffineMap.h"
+
+using namespace mlir;
+
+unsigned ConstraintsSet::lookupPos(Value id) const {
+  unsigned pos;
+  if (!findId(id, &pos)) {
+    llvm::errs() << "Lookup failed: " << id << "\n";
+    llvm_unreachable("Lookup failed");
+  }
+  return pos;
+}
+
+LogicalResult ConstraintsSet::ensureIdOfType(Value v, bool asDim) {
+  if (!containsId(v)) {
+    if (asDim)
+      addDimId(getNumDimIds(), v);
+    else
+      addSymbolId(getNumSymbolIds(), v);
+    return success();
+  }
+  unsigned pos = lookupPos(v);
+  return success((asDim && pos < getNumDimIds()) ||
+                 (!asDim && getNumDimIds() <= pos &&
+                  pos < getNumDimIds() + getNumSymbolIds()));
+}
+
+LogicalResult ConstraintsSet::composeAffineApply(Value val, AffineMap map,
+                                                 ValueRange operands) {
+  AffineValueMap avm(map, operands, val);
+  return composeMap(&avm);
+}
+
+LogicalResult ConstraintsSet::composeMinOrMaxMapAndOperands(Value val,
+                                                            AffineMap map,
+                                                            ValueRange operands,
+                                                            bool min) {
+  ConstraintsSet localCst;
+  std::vector<SmallVector<int64_t, 8>> flatExprs;
+  if (failed(getFlattenedAffineExprs(map, &flatExprs, &localCst)))
+    return failure();
+  assert(flatExprs.size() == map.getNumResults() &&
+         "incorrect number of flattened expressiosn");
+
+  // Local vars are not supported yet; add support on a per-need basis.
+  if (localCst.getNumLocalIds() != 0)
+    return failure();
+
+  // Add one inequality for each result connecting `val` to the other ids in
+  // `operands`. For instance, if the expression is:
+  //   `16 * i0 + i1` and
+  //   `min` is true
+  // add:
+  //  -d_val + 16 * i0 + i1 - 1 >= 0.
+  for (const auto &flatExpr : flatExprs) {
+    assert(flatExpr.size() >= operands.size() + 1);
+    SmallVector<int64_t, 8> ineq(getNumCols(), 0);
+    for (unsigned i = 0, e = operands.size(); i < e; i++)
+      ineq[lookupPos(operands[i])] = min ? flatExpr[i] : -flatExpr[i];
+
+    // Set the coefficient for `d_val`.
+    ineq[lookupPos(val)] = min ? -1 : 1;
+
+    // Set the constant term (upper bound in flatExpr is exclusive).
+    ineq[getNumCols() - 1] = min ? flatExpr[flatExpr.size() - 1] - 1
+                                 : -flatExpr[flatExpr.size() - 1];
+
+    // Add the inequality connecting the result of the map to the rest.
+    addInequality(ineq);
+  }
+
+  return success();
+}
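
To make the row encoding concrete, an editorial example (not commit code):
for `val = affine.min affine_map<(i0, i1) -> (16 * i0 + i1, 42)>(%a, %b)`,
assuming the columns happen to be ordered [d_val, i0, i1, const], the `min`
case adds one row `expr - d_val - 1 >= 0` per map result:

  // Editorial sketch of the rows composeMin would add for the map above.
  int64_t rows[2][4] = {
      {-1, 16, 1, -1}, // -d_val + 16 * i0 + i1 - 1 >= 0
      {-1, 0, 0, 41},  // -d_val + 42 - 1 >= 0
  };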

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 07a5fc5c7502..b6829171c580 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -12,8 +12,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
+#include "mlir/Analysis/AffineStructures.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Linalg/Analysis/ConstraintsSet.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/SCF.h"
@@ -530,97 +532,6 @@ bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) {
   return outer.isDefinedOutsideOfLoop(v) || v.getDefiningOp<ConstantOp>();
 }
 
-/// Compute the tightest lower bound with quantities that are all defined
-/// outside of `outer`.
-/// Return null if such a bound cannot be computed.
-Value computeLoopIndependentLowerBound(OpBuilder &b, scf::ForOp outer,
-                                       Value v) {
-  if (isDefinedOutsideOrConstant(outer, v))
-    return v;
-  return Value();
-}
-
-/// Compute the tightest upper bound with quantities that are all defined
-/// outside of `outer`.
-/// Expects all ops in the backward slice of `v` up to `outer` to be either
-/// scf.for, affine.min or affine.apply.
-static Value computeLoopIndependentUpperBound(OpBuilder &b, scf::ForOp outer,
-                                              Value v) {
-  if (isDefinedOutsideOrConstant(outer, v))
-    return v;
-
-  LLVM_DEBUG(DBGS() << "Begin loopIndependentUpperBound for: " << v << "\n");
-
-  bool ok =
-      backwardsSliceOnlyHasOpsOfType<scf::ForOp, AffineMinOp, AffineApplyOp>(
-          outer, v);
-  assert(ok && "expected to only be defined by scf::ForOp and AffineMinOp");
-  (void)ok;
-
-  // Compute a backward slice up to, but not including, `outer`.
-  SetVector<Operation *> backwardSlice;
-  getBackwardSlice(v, &backwardSlice,
-                   [&](Operation *op) { return outer->isProperAncestor(op); });
-  backwardSlice.insert(v.getDefiningOp());
-
-  OpBuilder::InsertionGuard g(b);
-  b.setInsertionPoint(outer);
-  Value res = v;
-  BlockAndValueMapping bvm;
-  for (Operation *op : backwardSlice) {
-    if (isa<scf::ForOp>(op))
-      continue;
-    if (isa<AffineApplyOp>(op)) {
-      b.clone(*op, bvm);
-      continue;
-    }
-    auto sliceMinOp = cast<AffineMinOp>(op);
-    GetMinMaxExprFn getSCFMinMax = [&](Value value,
-                                       SmallVectorImpl<Value> &dims,
-                                       SmallVectorImpl<Value> &symbols) {
-      return getSCFMinMaxExpr(value, dims, symbols, [&](Operation *op) {
-        return outer->isAncestor(op);
-      });
-    };
-    // Perform the substitution of the operands of AffineMinOp.
-    auto mapAndOperands = substituteMin(sliceMinOp, getSCFMinMax);
-    SmallVector<Value> resultOperands = mapAndOperands.dims;
-    llvm::append_range(resultOperands, mapAndOperands.symbols);
-    AffineMap map = mapAndOperands.map;
-    canonicalizeMapAndOperands(&map, &resultOperands);
-    map = simplifyAffineMap(map);
-    res = b.create<AffineMinOp>(
-        outer->getLoc(), map,
-        llvm::to_vector<4>(llvm::map_range(resultOperands, [&](Value operand) {
-          return bvm.lookupOrDefault(operand);
-        })));
-    bvm.map(sliceMinOp, res);
-  }
-  LLVM_DEBUG(DBGS() << "End loopIndependentUpperBound with: " << res << "\n");
-  return res;
-}
-
-/// Return the number of iterations in the loop (ub - lb).ceilDiv(step).
-/// The returned Value is guaranteed not to depend on any loop comprised in
-/// [`outer`, `forOp`].
-/// Return null if such a loop-independent quantity cannot be computed.
-static Value buildLoopTripCount(OpBuilder &b, scf::ForOp outer,
-                                scf::ForOp forOp) {
-  MLIRContext *ctx = forOp->getContext();
-  AffineExpr lb, ub, step;
-  bindDims(ctx, lb, ub);
-  bindSymbols(ctx, step);
-  Value lbVal = computeLoopIndependentLowerBound(b, outer, forOp.lowerBound()),
-        ubVal = computeLoopIndependentUpperBound(b, outer, forOp.upperBound()),
-        stepVal = forOp.step();
-  if (!lbVal || !ubVal || !stepVal)
-    return Value();
-  auto loc = forOp->getLoc();
-  Value res = b.create<AffineApplyOp>(loc, (ub - lb).ceilDiv(step),
-                                      ValueRange{lbVal, ubVal, stepVal});
-  return res;
-}
-
 /// Return the current iteration number in the loop (iv - lb).ceilDiv(step).
 /// The returned Value is guaranteed not to depend on any loop comprised in
 /// [`outer`, `forOp`].
@@ -631,14 +542,135 @@ static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp outer,
   AffineExpr iv, lb, step;
   bindDims(ctx, iv, lb);
   bindSymbols(ctx, step);
-  Value ivVal = forOp.getInductionVar(),
-        lbVal = computeLoopIndependentLowerBound(b, outer, forOp.lowerBound()),
-        stepVal = forOp.step();
-  if (!ivVal || !lbVal || !stepVal)
+  if (!isDefinedOutsideOrConstant(outer, forOp.lowerBound()) ||
+      !isDefinedOutsideOrConstant(outer, forOp.step()))
     return Value();
+  Value ivVal = forOp.getInductionVar(), lbVal = forOp.lowerBound(),
+        stepVal = forOp.step();
   auto loc = forOp->getLoc();
-  return b.create<AffineApplyOp>(loc, (iv - lb).ceilDiv(step),
-                                 ValueRange{ivVal, lbVal, stepVal});
+  return b.createOrFold<AffineApplyOp>(loc, (iv - lb).ceilDiv(step),
+                                       ValueRange{ivVal, lbVal, stepVal});
+}
+
+/// Given a set of loops, assumed to be scf::ForOp, create a constraint set
+/// containing the inequalities `iv - lb >= 0` and `-iv + ub - 1 >= 0` for each
+/// loop.
+static ConstraintsSet initLoopIvsAndBounds(ArrayRef<Operation *> loops) {
+  ConstraintsSet constraints;
+  for (Operation *op : loops)
+    constraints.addDimId(constraints.getNumDimIds(),
+                         cast<scf::ForOp>(op).getInductionVar());
+  for (Operation *op : loops)
+    constraints.addDimId(constraints.getNumDimIds(),
+                         cast<scf::ForOp>(op).lowerBound());
+  for (Operation *op : loops)
+    constraints.addDimId(constraints.getNumDimIds(),
+                         cast<scf::ForOp>(op).upperBound());
+  unsigned numLoops = loops.size();
+  for (unsigned ivIdx = 0, e = numLoops; ivIdx < e; ++ivIdx) {
+    // iv - lb >= 0
+    SmallVector<int64_t, 8> ineqLb(constraints.getNumCols(), 0);
+    ineqLb[ivIdx] = 1;
+    ineqLb[ivIdx + numLoops] = -1;
+    // -iv + ub >= 0
+    SmallVector<int64_t, 8> ineqUb(constraints.getNumCols(), 0);
+    ineqUb[ivIdx] = -1;
+    ineqUb[ivIdx + 2 * numLoops] = 1;
+    ineqUb[constraints.getNumCols() - 1] = -1;
+    constraints.addInequality(ineqLb);
+    constraints.addInequality(ineqUb);
+  }
+  return constraints;
+}
+
+/// For each loop in `loops`, determine the ops involved in the construction of
+/// its upper bound, up to the `outerLimit` loop, and fold them as new
+/// inequalities in the constraint set.
+/// This is achieved by computing the backwardSlice of the loop's upper bound
+/// and iteratively folding each op in reverse topological order to guarantee
+/// use-def ordering.
+/// As operations are folded in, their result is projected out of the
+/// constraints set.
+/// The following operations are supported:
+///   - scf::ForOp are simply skipped.
+///   - AffineApplyOp are composed to replace the result by an equality.
+///   - AffineMinOp are composed by adding each entry as an upper bound.
+/// If any other operation is met, return failure.
+// TODO: extend on a per-need basis.
+static LogicalResult
+foldUpperBoundsIntoConstraintsSet(ConstraintsSet &constraints,
+                                  scf::ForOp outerLimit,
+                                  ArrayRef<Operation *> loops) {
+  SetVector<Value> toProjectOut;
+  for (Operation *loop : loops) {
+    auto ub = cast<scf::ForOp>(loop).upperBound();
+    if (isDefinedOutsideOrConstant(outerLimit, ub))
+      continue;
+
+    // Compute a backward slice up to, but not including, `outerLimit`.
+    SetVector<Operation *> backwardSlice;
+    getBackwardSlice(ub, &backwardSlice, [&](Operation *op) {
+      return outerLimit->isProperAncestor(op);
+    });
+    backwardSlice.insert(ub.getDefiningOp());
+
+    // Iterate over all ops in the slice and compose them in the constraints.
+    for (Operation *op : llvm::reverse(backwardSlice)) {
+      if (!isa<scf::ForOp, AffineApplyOp, AffineMinOp>(op))
+        return failure();
+      if (isa<scf::ForOp>(op))
+        continue;
+      // Returns true if `v` cannot be registered as a dim id.
+      auto ensureIdFailed = [&](Value v) {
+        return failed(constraints.ensureIdOfType(v, /*asDim=*/true));
+      };
+
+      // Ensure all ids exist and add results for later projection.
+      if (llvm::any_of(op->getResults(), ensureIdFailed) ||
+          llvm::any_of(op->getOperands(), ensureIdFailed))
+        return failure();
+
+      // All supported ops have 1 result.
+      // TODO: extend when needed.
+      toProjectOut.insert(op->getResult(0));
+
+      // Compose supported ops.
+      if (auto affineApplyOp = dyn_cast<AffineApplyOp>(op)) {
+        if (failed(constraints.composeAffineApply(affineApplyOp.getResult(),
+                                                  affineApplyOp.getAffineMap(),
+                                                  affineApplyOp.getOperands())))
+          return failure();
+        continue;
+      }
+      auto affineMinOp = cast<AffineMinOp>(op);
+      if (failed(constraints.composeMin(affineMinOp.getResult(),
+                                        affineMinOp.getAffineMap(),
+                                        affineMinOp.operands())))
+        return failure();
+    }
+  }
+  for (Value v : toProjectOut)
+    constraints.projectOut(v);
+  return success();
+}
+
+/// Compute dynamic tensor sizes, independent of any value defined inside
+/// `outer` and such that every n-D iteration of the packingLoops has its own
+/// space (so that each packed buffer has a storage location). This is achieved
+/// by computing the extent for each of the packing loops.
+static LogicalResult computeBounds(scf::ForOp outer,
+                                   ArrayRef<Operation *> packingLoops,
+                                   SmallVector<AffineMap> &lbs,
+                                   SmallVector<AffineMap> &ubs) {
+  // Packing loop IVs are introduced as the first positions.
+  ConstraintsSet constraints = initLoopIvsAndBounds(packingLoops);
+  if (failed(
+          foldUpperBoundsIntoConstraintsSet(constraints, outer, packingLoops)))
+    return failure();
+  // Compute the bounds of the first positions, assuming the others are fixed.
+  constraints.getSliceBounds(/*pos=*/0, /*num=*/packingLoops.size(),
+                             outer->getContext(), &lbs, &ubs);
+  return success();
 }
 
 /// Ensure prerequisites that guarantee pad op hoisting can occur.
@@ -725,28 +757,49 @@ hoistPaddingOnTensorsPrerequisites(linalg::PadTensorOp padTensorOp, int nLevels,
   assert(outermostEnclosingForOp == backwardSlice.front());
 
   scf::ForOp outer = cast<scf::ForOp>(outermostEnclosingForOp);
-  if (llvm::any_of(packingLoops, [&](Operation *op) {
-        scf::ForOp forOp = cast<scf::ForOp>(op);
-        Value lb = forOp.lowerBound(), ub = forOp.upperBound(),
-              step = forOp.step();
-        return !isDefinedOutsideOrConstant(outer, lb) ||
-               !(isDefinedOutsideOrConstant(outer, ub) ||
-                 backwardsSliceOnlyHasOpsOfType<scf::ForOp, AffineMinOp,
-                                                AffineApplyOp>(outer, ub)) ||
-               !isDefinedOutsideOrConstant(outer, step);
-      }))
+
+  ConstraintsSet constraints = initLoopIvsAndBounds(packingLoops.getArrayRef());
+  if (failed(foldUpperBoundsIntoConstraintsSet(constraints, outer,
+                                               packingLoops.getArrayRef())))
+    return failure();
+
+  unsigned numLoops = packingLoops.size();
+  SmallVector<AffineMap> lbs(numLoops), ubs(numLoops);
+  if (failed(computeBounds(outer, packingLoops.getArrayRef(), lbs, ubs)))
     return failure();
 
+  SmallVector<Value> allValues;
+  constraints.getAllIdValues(&allValues);
+  SmallVector<Value> allNonLoopValues(allValues.begin() + numLoops,
+                                      allValues.end());
+
+  // For each packingLoop, create the extent by (ub - lb).ceilDiv(step).
   // IP just before the outermost loop considered that we hoist above.
-  OpBuilder b(outermostEnclosingForOp);
-  dynamicTensorSizes =
-      llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *op) {
-        return buildLoopTripCount(b, cast<scf::ForOp>(outermostEnclosingForOp),
-                                  cast<scf::ForOp>(op));
-      }));
-  // Assert all loop trip counts can be computed.
-  if (!llvm::all_of(dynamicTensorSizes, [](Value v) { return v; }))
-    llvm_unreachable("loop independence prerequisite not met");
+  ImplicitLocOpBuilder b(outer->getLoc(), outer);
+  assert(packingLoops.size() == lbs.size() && "expected matching lb sizes");
+  assert(packingLoops.size() == ubs.size() && "expected matching ub sizes");
+  for (auto it : llvm::zip(packingLoops, lbs, ubs)) {
+    scf::ForOp loop = cast<scf::ForOp>(std::get<0>(it));
+    AffineMap lbMap = std::get<1>(it);
+    AffineMap ubMap = std::get<2>(it);
+    SmallVector<Value> lbOperands(allNonLoopValues);
+    canonicalizeMapAndOperands(&lbMap, &lbOperands);
+    Value lbVal = b.createOrFold<AffineMaxOp>(lbMap, lbOperands);
+
+    SmallVector<Value> ubOperands(allNonLoopValues);
+    canonicalizeMapAndOperands(&ubMap, &ubOperands);
+    Value ubVal = b.createOrFold<AffineMinOp>(ubMap, ubOperands);
+
+    AffineExpr lb, ub, step;
+    bindDims(b.getContext(), lb, ub);
+    bindSymbols(b.getContext(), step);
+    Value res = b.createOrFold<AffineApplyOp>(
+        (ub - lb).ceilDiv(step),
+        ValueRange{lbVal, ubVal, cast<scf::ForOp>(loop).step()});
+
+    dynamicTensorSizes.push_back(res);
+  }
+
   return success();
 }
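
For reference, the extent computation above is the classic
(ub - lb).ceilDiv(step) pattern; an editorial sketch, assuming an OpBuilder
`b`, a Location `loc`, and already-materialized lbVal/ubVal/stepVal:

  // Editorial sketch: build the extent of one packing loop.
  AffineExpr lb, ub, step;
  bindDims(b.getContext(), lb, ub);
  bindSymbols(b.getContext(), step);
  Value extent = b.createOrFold<AffineApplyOp>(
      loc, (ub - lb).ceilDiv(step), ValueRange{lbVal, ubVal, stepVal});
  // E.g. lb = 0, ub = 10, step = 3 yields an extent of 4.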
 

diff  --git a/mlir/test/Dialect/Linalg/hoist-padding.mlir b/mlir/test/Dialect/Linalg/hoist-padding.mlir
index f345051827ab..6a51215be2f2 100644
--- a/mlir/test/Dialect/Linalg/hoist-padding.mlir
+++ b/mlir/test/Dialect/Linalg/hoist-padding.mlir
@@ -141,8 +141,10 @@ func @matmul_tensors(
 
 // -----
 
+
 // CHECK-DAG: #[[$MIN_REST8:[0-9a-z]+]] = affine_map<(d0)[s0] -> (8, -d0 + s0)>
-// CHECK-DAG: #[[$MIN_MOD4:[0-9a-z]+]] = affine_map<(d0) -> (4, d0 - ((d0 - 1) floordiv 4) * 4)>
+// CHECK-DAG: #[[$MIN_REST4:[0-9a-z]+]] = affine_map<(d0, d1) -> (4, d0 - d1)>
+// CHECK-DAG: #[[$MIN_REST2:[0-9a-z]+]] = affine_map<(d0, d1) -> (2, d0 - d1)>
 // CHECK-DAG: #[[$DIV4:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 4)>
 // CHECK-DAG: #[[$DIV2:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 2)>
 #map0 = affine_map<(d0)[s0] -> (8, -d0 + s0)>
@@ -167,20 +169,18 @@ func @dot(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<f32>)
   //
   //      CHECK:   %[[MR8:.*]] = affine.min #[[$MIN_REST8]](%[[I]])
   //      CHECK:   %[[D0:.*]] = affine.apply #[[$DIV4]](%[[MR8]])
-  //      CHECK:   %[[MM4:.*]] = affine.min #[[$MIN_MOD4]](%[[MR8]])
-  //      CHECK:   %[[D1:.*]] = affine.apply #[[$DIV2]](%[[MM4]])
   // Init tensor and pack.
-  //      CHECK:   %[[INIT_PACKED_A:.*]] = linalg.init_tensor [%[[D0]], %[[D1]], 2] : tensor<?x?x2xf32>
-  //      CHECK:   %[[PACKED_A:.*]] = scf.for %[[II:[0-9a-z]+]] = {{.*}} iter_args(%{{.*}} = %[[INIT_PACKED_A]]) -> (tensor<?x?x2xf32>) {
+  //      CHECK:   %[[INIT_PACKED_A:.*]] = linalg.init_tensor [%[[D0]], 2, 2] : tensor<?x2x2xf32>
+  //      CHECK:   %[[CAST_INIT_PACKED_A:.*]] = tensor.cast %[[INIT_PACKED_A]] : tensor<?x2x2xf32> to tensor<?x?x2xf32>
+  //      CHECK:   %[[PACKED_A:.*]] = scf.for %[[II:[0-9a-z]+]] = {{.*}} iter_args(%{{.*}} = %[[CAST_INIT_PACKED_A]]) -> (tensor<?x?x2xf32>) {
   //      CHECK:     scf.for %[[III:[0-9a-z]+]] =
   //      CHECK:       tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}, 0] [1, 1, 2] [1, 1, 1] : tensor<2xf32> into tensor<?x?x2xf32>
   //
   //      CHECK:   %[[D0_2:.*]] = affine.apply #[[$DIV4]](%[[MR8]])
-  //      CHECK:   %[[MM4_2:.*]] = affine.min #[[$MIN_MOD4]](%[[MR8]])
-  //      CHECK:   %[[D1_2:.*]] = affine.apply #[[$DIV2]](%[[MM4_2]])
   // Init tensor and pack.
-  //      CHECK:   %[[INIT_PACKED_B:.*]] = linalg.init_tensor [%[[D0_2]], %[[D1_2]], 2] : tensor<?x?x2xf32>
-  //      CHECK:   %[[PACKED_B:.*]] = scf.for %[[II_2:[0-9a-z]+]] = {{.*}} iter_args(%{{.*}} = %[[INIT_PACKED_B]]) -> (tensor<?x?x2xf32>) {
+  //      CHECK:   %[[INIT_PACKED_B:.*]] = linalg.init_tensor [%[[D0_2]], 2, 2] : tensor<?x2x2xf32>
+  //      CHECK:   %[[CAST_INIT_PACKED_B:.*]] = tensor.cast %[[INIT_PACKED_B]] : tensor<?x2x2xf32> to tensor<?x?x2xf32>
+  //      CHECK:   %[[PACKED_B:.*]] = scf.for %[[II_2:[0-9a-z]+]] = {{.*}} iter_args(%{{.*}} = %[[CAST_INIT_PACKED_B]]) -> (tensor<?x?x2xf32>) {
   //      CHECK:     scf.for %[[III_2:[0-9a-z]+]] =
   //      CHECK:       tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}, 0] [1, 1, 2] [1, 1, 1] : tensor<2xf32> into tensor<?x?x2xf32>
   // Compute.


More information about the Mlir-commits mailing list