[Mlir-commits] [mlir] ddf93ab - [mlir][linalg] NFC: Move makeTiledShapes into Utils.{h|cpp}

Lei Zhang llvmlistbot at llvm.org
Wed Mar 24 15:19:31 PDT 2021


Author: Lei Zhang
Date: 2021-03-24T18:17:57-04:00
New Revision: ddf93abf49f7b753e8554fa47a4aaf811f40210a

URL: https://github.com/llvm/llvm-project/commit/ddf93abf49f7b753e8554fa47a4aaf811f40210a
DIFF: https://github.com/llvm/llvm-project/commit/ddf93abf49f7b753e8554fa47a4aaf811f40210a.diff

LOG: [mlir][linalg] NFC: Move makeTiledShapes into Utils.{h|cpp}

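This is a preparation step to reuse makeTiledShapes in tensor
fusion. Along the way, this also does some lightweight cleanups: it groups
the Utils.h declarations into sections and drops the header's file-scope
`using` declarations.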

Differential Revision: https://reviews.llvm.org/D99013

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 2dc208f429f4..33efeddadc9e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -17,13 +17,9 @@
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
-
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SetVector.h"
 
-using mlir::edsc::intrinsics::AffineIndexedValue;
-using mlir::edsc::intrinsics::MemRefIndexedValue;
-
 namespace mlir {
 class AffineExpr;
 class AffineForOp;
@@ -34,33 +30,32 @@ class PatternRewriter;
 namespace linalg {
 class LinalgDependenceGraph;
 
-/// A struct containing the Linalg producer before and after fusion.
-/// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` op
-/// before the consumer Linalg op, until enough canonicalizations have applied.
-struct FusionInfo {
-  LinalgOp originalProducer;
-  LinalgOp fusedProducer;
-};
+//===----------------------------------------------------------------------===//
+// General utilities
+//===----------------------------------------------------------------------===//
 
-/// A struct containing common matchers over linalg op's region.
-struct RegionMatcher {
-  enum class BinaryOpKind {
-    IAdd,
-  };
+/// Apply the permutation defined by `permutation` to `inVec`.
+/// Element `permutation[i]` of `inVec` is moved to location `i`.
+/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
+/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
+template <typename T, unsigned N>
+void applyPermutationToVector(SmallVector<T, N> &inVec,
+                              ArrayRef<unsigned> permutation) {
+  SmallVector<T, N> auxVec(inVec.size());
+  for (unsigned i = 0; i < permutation.size(); ++i)
+    auxVec[i] = inVec[permutation[i]];
+  inVec = auxVec;
+}
 
-  /// Matches the given linalg op if its body is performing binary operation on
-  /// int or float scalar values and returns the binary op kind.
-  ///
-  /// The linalg op's region is expected to be
-  /// ```
-  /// {
-  ///   ^bb(%a: <scalar-type>, %b: <scalar-type>):
-  ///     %0 = <binary-op> %a, %b: <scalar-type>
-  ///     linalg.yield %0: <scalar-type>
-  /// }
-  /// ```
-  static Optional<BinaryOpKind> matchAsScalarBinaryOp(GenericOp op);
-};
+/// If `size` comes from an AffineMinOp and one of the values of AffineMinOp
+/// is a constant then return an attribute holding the smallest such constant.
+/// If `size` comes from a ConstantOp, return the constant.
+/// Otherwise return a null IntegerAttr.
+IntegerAttr getSmallestBoundingIndex(Value size);
+
+//===----------------------------------------------------------------------===//
+// Iterator type utilities
+//===----------------------------------------------------------------------===//
 
 /// Checks if an iterator_type attribute is parallel.
 bool isParallelIteratorType(Attribute attr);
@@ -71,6 +66,10 @@ bool isReductionIteratorType(Attribute attr);
 /// Checks if an iterator_type attribute is parallel.
 bool isWindowIteratorType(Attribute attr);
 
+//===----------------------------------------------------------------------===//
+// Fusion utilities
+//===----------------------------------------------------------------------===//
+
 /// Checks whether the specific `producer` is the last write to exactly the
 /// whole `consumedView`. This checks structural dominance, that the dependence
 /// is a RAW without any interleaved write to any piece of `consumedView`.
@@ -84,6 +83,21 @@ bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
 bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
                    Value consumedView, LinalgOp producer);
 
+/// Creates subtensor/subview ops with `builder` for those `tiledOperands` of
+/// `linalgOp` whose indexing maps depend on a tiled loop (other operands are
+/// returned unchanged), assuming `linalgOp` is being fused into a loop nest for
+/// tiling with the given induction variables `ivs` and tile sizes `tileSizes`.
+/// `sizeBounds` are the iteration space bounds for *all* the implicit loops.
+///
+/// Note that a constant zero in `tileSizes` means no tiling at that implicit
+/// loop. The number of non-zero values in `tileSizes` should be equal to the
+/// number of values in `ivs`.
+SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
+                                      LinalgOp linalgOp,
+                                      ArrayRef<Value> tiledOperands,
+                                      ValueRange ivs, ValueRange tileSizes,
+                                      ArrayRef<Value> sizeBounds);
+
 using FusableOpDependencesTy = llvm::MapVector<
     Operation *,
     SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
@@ -91,6 +105,14 @@ FusableOpDependencesTy
 findAllFusableDependences(ArrayRef<LinalgOp> ops,
                           const LinalgDependenceGraph &dependenceGraph);
 
+/// A struct containing the Linalg producer before and after fusion.
+/// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` op
+/// before the consumer Linalg op, until enough canonicalizations have been applied.
+struct FusionInfo {
+  LinalgOp originalProducer;
+  LinalgOp fusedProducer;
+};
+
 /// Fuses producer into consumer if the producer is structurally feasible and
 /// the fusion would not violate dependencies.
 /// Implements the fusion part of the "tileAndFuse on buffers" transformation
@@ -119,24 +141,9 @@ Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
 Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
                                               OpOperand &consumerOpOperand);
 
-/// Apply the permutation defined by `permutation` to `inVec`.
-/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
-/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
-/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
-template <typename T, unsigned N>
-void applyPermutationToVector(SmallVector<T, N> &inVec,
-                              ArrayRef<unsigned> permutation) {
-  SmallVector<T, N> auxVec(inVec.size());
-  for (unsigned i = 0; i < permutation.size(); ++i)
-    auxVec[i] = inVec[permutation[i]];
-  inVec = auxVec;
-}
-
-/// If `size` comes from an AffineMinOp and one of the values of AffineMinOp
-/// is a constant then return a new value set to the smallest such constant.
-/// If `size` comes from a ConstantOp, return the constant.
-/// Otherwise return nullptr.
-IntegerAttr getSmallestBoundingIndex(Value size);
+//===----------------------------------------------------------------------===//
+// Distribution utilities
+//===----------------------------------------------------------------------===//
 
 /// Scheme used to distribute loops to processors.
 enum class DistributionMethod {
@@ -206,6 +213,34 @@ struct LinalgLoopDistributionOptions {
   SmallVector<DistributionMethod, 0> distributionMethod = {};
 };
 
+//===----------------------------------------------------------------------===//
+// Generic op region utilities
+//===----------------------------------------------------------------------===//
+
+/// A struct containing common matchers over linalg op's region.
+struct RegionMatcher {
+  enum class BinaryOpKind {
+    IAdd,
+  };
+
+  /// Matches the given linalg op if its body is performing binary operation on
+  /// int or float scalar values and returns the binary op kind.
+  ///
+  /// The linalg op's region is expected to be
+  /// ```
+  /// {
+  ///   ^bb(%a: <scalar-type>, %b: <scalar-type>):
+  ///     %0 = <binary-op> %a, %b: <scalar-type>
+  ///     linalg.yield %0: <scalar-type>
+  /// }
+  /// ```
+  static Optional<BinaryOpKind> matchAsScalarBinaryOp(GenericOp op);
+};
+
+//===----------------------------------------------------------------------===//
+// Loop nest utilities
+//===----------------------------------------------------------------------===//
+
 /// Utility class used to generate nested loops with ranges described by
 /// `loopRanges` and loop type described by the `iteratorTypes`. `bodyBuilderFn`
 /// is used to generate the body of the innermost loop. It is passed a range
@@ -214,7 +249,8 @@ template <typename LoopTy>
 struct GenerateLoopNest {
   using IndexedValueTy =
       typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
-                                AffineIndexedValue, MemRefIndexedValue>::type;
+                                edsc::intrinsics::AffineIndexedValue,
+                                edsc::intrinsics::MemRefIndexedValue>::type;
 
   static void
   doit(ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index aaf00721732d..0c29bc05cb66 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -23,7 +23,6 @@
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -82,34 +81,6 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
         Range{std_constant_index(0), shapeSizes[idx], tileSizes[idx]});
   return std::make_tuple(res, loopIndexToRangeIndex);
 }
-namespace {
-
-// Helper visitor to determine whether an AffineExpr is tiled.
-// This is achieved by traversing every AffineDimExpr with position `pos` and
-// checking whether the corresponding `tileSizes[pos]` is non-zero.
-// This also enforces only positive coefficients occur in multiplications.
-//
-// Example:
-//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
-//
-struct TileCheck : public AffineExprVisitor<TileCheck> {
-  TileCheck(ValueRange tileSizes) : isTiled(false), tileSizes(tileSizes) {}
-
-  void visitDimExpr(AffineDimExpr expr) {
-    isTiled |= !isZero(tileSizes[expr.getPosition()]);
-  }
-  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
-    visit(expr.getLHS());
-    visit(expr.getRHS());
-    if (expr.getKind() == mlir::AffineExprKind::Mul)
-      assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
-             "nonpositive multiplying coefficient");
-  }
-  bool isTiled;
-  ValueRange tileSizes;
-};
-
-} // namespace
 
 // IndexedGenericOp explicitly uses induction variables in the loop body. The
 // values of the indices that are used in the loop body for any given access of
@@ -201,117 +172,6 @@ static void transformIndexedGenericOpIndices(
   }
 }
 
-static bool isTiled(AffineExpr expr, ValueRange tileSizes) {
-  if (!expr)
-    return false;
-  TileCheck t(tileSizes);
-  t.visit(expr);
-  return t.isTiled;
-}
-
-// Checks whether the `map  varies with respect to a non-zero `tileSize`.
-static bool isTiled(AffineMap map, ValueRange tileSizes) {
-  if (!map)
-    return false;
-  for (unsigned r = 0; r < map.getNumResults(); ++r)
-    if (isTiled(map.getResult(r), tileSizes))
-      return true;
-  return false;
-}
-
-static SmallVector<Value, 4>
-makeTiledShapes(OpBuilder &b, Location loc, LinalgOp linalgOp,
-                ArrayRef<Value> tiledOperands, AffineMap map, ValueRange ivs,
-                ValueRange tileSizes, ValueRange allShapeSizes) {
-  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
-                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
-                           [](Value v) { return !isZero(v); })) &&
-         "expected as many ivs as non-zero sizes");
-
-  using namespace edsc::op;
-
-  auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
-  // Construct (potentially temporary) mins and maxes on which to apply maps
-  // that define tile subshapes.
-  SmallVector<Value, 8> lbs, subShapeSizes;
-  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
-    bool isTiled = !isZero(tileSizes[idx]);
-    lbs.push_back(isTiled ? ivs[idxIvs++] : (Value)std_constant_index(0));
-    // Before composing, we need to make range a closed interval.
-    Value size = isTiled ? tileSizes[idx] : shapeSizes[idx];
-    subShapeSizes.push_back(size - std_constant_index(1));
-  }
-
-  SmallVector<Value, 4> res;
-  res.reserve(tiledOperands.size());
-  for (auto en : llvm::enumerate(tiledOperands)) {
-    Value shapedOp = en.value();
-    ShapedType shapedType = shapedOp.getType().cast<ShapedType>();
-    unsigned rank = shapedType.getRank();
-    AffineMap map = linalgOp.getIndexingMap(en.index());
-    // If the shape is not tiled, we can use it as is.
-    if (!isTiled(map, tileSizes)) {
-      res.push_back(shapedOp);
-      continue;
-    }
-
-    // Construct a new subview / subtensor for the tile.
-    SmallVector<OpFoldResult, 4> offsets, sizes, strides;
-    offsets.reserve(rank);
-    sizes.reserve(rank);
-    strides.reserve(rank);
-    for (unsigned r = 0; r < rank; ++r) {
-      if (!isTiled(map.getSubMap({r}), tileSizes)) {
-        offsets.push_back(b.getIndexAttr(0));
-        sizes.push_back(memref_dim(shapedOp, r).value);
-        strides.push_back(b.getIndexAttr(1));
-        continue;
-      }
-
-      // Tiling creates a new slice at the proper index, the slice step is 1
-      // (i.e. the op does not subsample, stepping occurs in the loop).
-      auto m = map.getSubMap({r});
-      auto offset = applyMapToValues(b, loc, m, lbs).front();
-      offsets.push_back(offset);
-      auto closedIntSize = applyMapToValues(b, loc, m, subShapeSizes).front();
-      // Resulting size needs to be made half open interval again.
-      auto size = closedIntSize + std_constant_index(1);
-
-      // The size of the subview / subtensor should be trimmed to avoid
-      // out-of-bounds accesses, unless we statically know the subshape size
-      // divides the shape size evenly.
-      int64_t shapeSize = shapedType.getDimSize(r);
-      auto sizeCst = size.getDefiningOp<ConstantIndexOp>();
-      if (ShapedType::isDynamic(shapeSize) || !sizeCst ||
-          (shapeSize % sizeCst.getValue()) != 0) {
-        // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
-        auto minMap = AffineMap::get(
-            /*dimCount=*/3, /*symbolCount=*/0,
-            {getAffineDimExpr(/*position=*/0, b.getContext()),
-             getAffineDimExpr(/*position=*/1, b.getContext()) -
-                 getAffineDimExpr(/*position=*/2, b.getContext())},
-            b.getContext());
-        Value d = memref_dim(shapedOp, r);
-        SmallVector<Value, 4> operands{size, d, offset};
-        fullyComposeAffineMapAndOperands(&minMap, &operands);
-        size = affine_min(b.getIndexType(), minMap, operands);
-      }
-
-      sizes.push_back(size);
-      strides.push_back(b.getIndexAttr(1));
-    }
-
-    if (shapedType.isa<MemRefType>())
-      res.push_back(
-          b.create<memref::SubViewOp>(loc, shapedOp, offsets, sizes, strides));
-    else
-      res.push_back(
-          b.create<SubTensorOp>(loc, shapedOp, offsets, sizes, strides));
-  }
-
-  return res;
-}
-
 template <typename LoopTy>
 static Optional<TiledLinalgOp>
 tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
@@ -401,9 +261,10 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
         assert(outputBuffers.empty() || iterArgs.empty());
         operands.append(outputBuffers.begin(), outputBuffers.end());
         operands.append(iterArgs.begin(), iterArgs.end());
-        SmallVector<Value, 4> tiledOperands =
-            makeTiledShapes(b, loc, op, operands, shapeSizesToLoopsMap,
-                            interchangedIvs, tileSizes, allShapeSizes);
+        auto sizeBounds =
+            applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
+        SmallVector<Value, 4> tiledOperands = makeTiledShapes(
+            b, loc, op, operands, interchangedIvs, tileSizes, sizeBounds);
         auto nonShapedOperands = op.getAssumedNonShapedOperands();
         tiledOperands.append(nonShapedOperands.begin(),
                              nonShapedOperands.end());

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 8d91bd74712b..8fe3d8530c62 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -18,8 +18,10 @@
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/SCF/EDSC/Builders.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
@@ -27,9 +29,64 @@
 #include "mlir/Transforms/LoopUtils.h"
 
 using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
 using namespace mlir::linalg;
 using namespace mlir::scf;
 
+static bool isZero(Value v) {
+  if (auto cst = v.getDefiningOp<ConstantIndexOp>())
+    return cst.getValue() == 0;
+  return false;
+}
+
+namespace {
+
+// Helper visitor to determine whether an AffineExpr is tiled.
+// This is achieved by traversing every AffineDimExpr with position `pos` and
+// checking whether the corresponding `tileSizes[pos]` is non-zero.
+// This also asserts that only positive coefficients occur in multiplications.
+//
+// Example:
+//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
+//
+struct TileCheck : public AffineExprVisitor<TileCheck> {
+  TileCheck(ValueRange tileSizes) : isTiled(false), tileSizes(tileSizes) {}
+
+  void visitDimExpr(AffineDimExpr expr) {
+    isTiled |= !isZero(tileSizes[expr.getPosition()]);
+  }
+  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
+    visit(expr.getLHS());
+    visit(expr.getRHS());
+    if (expr.getKind() == mlir::AffineExprKind::Mul)
+      assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
+             "nonpositive multiplying coefficient");
+  }
+  bool isTiled;
+  ValueRange tileSizes;
+};
+
+} // namespace
+
+static bool isTiled(AffineExpr expr, ValueRange tileSizes) {
+  if (!expr)
+    return false;
+  TileCheck t(tileSizes);
+  t.visit(expr);
+  return t.isTiled;
+}
+
+// Checks whether any result of `map` varies with a dim whose tile size is non-zero.
+static bool isTiled(AffineMap map, ValueRange tileSizes) {
+  if (!map)
+    return false;
+  for (unsigned r = 0; r < map.getNumResults(); ++r)
+    if (isTiled(map.getResult(r), tileSizes))
+      return true;
+  return false;
+}
+
 Optional<RegionMatcher::BinaryOpKind>
 RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
   auto &region = op.region();
@@ -374,5 +431,98 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
   assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
 }
 
+SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
+                                      LinalgOp linalgOp,
+                                      ArrayRef<Value> tiledOperands,
+                                      ValueRange ivs, ValueRange tileSizes,
+                                      ArrayRef<Value> sizeBounds) {
+  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
+                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
+                           [](Value v) { return !isZero(v); })) &&
+         "expected as many ivs as non-zero sizes");
+
+  using namespace edsc::op;
+
+  // Construct (potentially temporary) mins and maxes on which to apply maps
+  // that define tile subshapes.
+  SmallVector<Value, 8> lbs, subShapeSizes;
+  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
+    bool isTiled = !isZero(tileSizes[idx]);
+    lbs.push_back(isTiled ? ivs[idxIvs++] : (Value)std_constant_index(0));
+    // Before composing, we need to make range a closed interval.
+    Value size = isTiled ? tileSizes[idx] : sizeBounds[idx];
+    subShapeSizes.push_back(size - std_constant_index(1));
+  }
+
+  MLIRContext *context = builder.getContext();
+  SmallVector<Value, 4> tiledShapes;
+  tiledShapes.reserve(tiledOperands.size());
+  for (auto en : llvm::enumerate(tiledOperands)) {
+    Value shapedOp = en.value();
+    ShapedType shapedType = shapedOp.getType().cast<ShapedType>();
+    unsigned rank = shapedType.getRank();
+    AffineMap map = linalgOp.getIndexingMap(en.index());
+    // If the shape is not tiled, we can use it as is.
+    if (!isTiled(map, tileSizes)) {
+      tiledShapes.push_back(shapedOp);
+      continue;
+    }
+
+    // Construct a new subview / subtensor for the tile.
+    SmallVector<OpFoldResult, 4> offsets, sizes, strides;
+    offsets.reserve(rank);
+    sizes.reserve(rank);
+    strides.reserve(rank);
+    for (unsigned r = 0; r < rank; ++r) {
+      if (!isTiled(map.getSubMap({r}), tileSizes)) {
+        offsets.push_back(builder.getIndexAttr(0));
+        sizes.push_back(memref_dim(shapedOp, r).value);
+        strides.push_back(builder.getIndexAttr(1));
+        continue;
+      }
+
+      // Tiling creates a new slice at the proper index, the slice step is 1
+      // (i.e. the op does not subsample, stepping occurs in the loop).
+      auto m = map.getSubMap({r});
+      auto offset = applyMapToValues(builder, loc, m, lbs).front();
+      offsets.push_back(offset);
+      auto closedIntSize =
+          applyMapToValues(builder, loc, m, subShapeSizes).front();
+      // Resulting size needs to be made half open interval again.
+      auto size = closedIntSize + std_constant_index(1);
+
+      // The size of the subview / subtensor should be trimmed to avoid
+      // out-of-bounds accesses, unless we statically know the subshape size
+      // divides the shape size evenly.
+      int64_t shapeSize = shapedType.getDimSize(r);
+      auto sizeCst = size.getDefiningOp<ConstantIndexOp>();
+      if (ShapedType::isDynamic(shapeSize) || !sizeCst ||
+          (shapeSize % sizeCst.getValue()) != 0) {
+        AffineExpr dim0, dim1, dim2;
+        bindDims(context, dim0, dim1, dim2);
+        // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
+        auto minMap = AffineMap::get(
+            /*dimCount=*/3, /*symbolCount=*/0, {dim0, dim1 - dim2}, context);
+        Value d = memref_dim(shapedOp, r);
+        SmallVector<Value, 4> operands{size, d, offset};
+        fullyComposeAffineMapAndOperands(&minMap, &operands);
+        size = affine_min(builder.getIndexType(), minMap, operands);
+      }
+
+      sizes.push_back(size);
+      strides.push_back(builder.getIndexAttr(1));
+    }
+
+    if (shapedType.isa<MemRefType>())
+      tiledShapes.push_back(builder.create<memref::SubViewOp>(
+          loc, shapedOp, offsets, sizes, strides));
+    else
+      tiledShapes.push_back(
+          builder.create<SubTensorOp>(loc, shapedOp, offsets, sizes, strides));
+  }
+
+  return tiledShapes;
+}
+
 } // namespace linalg
 } // namespace mlir


        


More information about the Mlir-commits mailing list