[Mlir-commits] [mlir] a3adcba - [mlir][Linalg] Implement tiling on tensors

Nicolas Vasilache llvmlistbot at llvm.org
Tue Oct 6 10:53:42 PDT 2020


Author: Nicolas Vasilache
Date: 2020-10-06T17:51:11Z
New Revision: a3adcba645eec31b42ad0a1f727975c5c9c236f0

URL: https://github.com/llvm/llvm-project/commit/a3adcba645eec31b42ad0a1f727975c5c9c236f0
DIFF: https://github.com/llvm/llvm-project/commit/a3adcba645eec31b42ad0a1f727975c5c9c236f0.diff

LOG: [mlir][Linalg] Implement tiling on tensors

This revision implements tiling on tensors as described in:
https://llvm.discourse.group/t/an-update-on-linalg-on-tensors/1878/4

Differential revision: https://reviews.llvm.org/D88733

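For orientation, below is a minimal C++ sketch (not part of this commit) of driving the tensor path through the LinalgTilingPattern declared in Transforms.h; the helper name and the `ctx`/`patterns` parameters are illustrative, and the tile sizes match the new tile-tensors.mlir test:

  // Sketch only: register a pattern that tiles linalg.matmul on tensors
  // with tile sizes (2, 3, 4); a pass would then run it with the greedy
  // pattern rewrite driver.
  #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
  #include "mlir/IR/PatternMatch.h"

  using namespace mlir;

  static void populateExampleTilingPatterns(MLIRContext *ctx,
                                            OwningRewritePatternList &patterns) {
    patterns.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>>(
        ctx, linalg::LinalgTilingOptions().setTileSizes({2, 3, 4}));
  }

The same configuration is what the added test exercises from the command line via -linalg-tile="linalg-tile-sizes=2,3,4".
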
Added: 
    mlir/test/Dialect/Linalg/tile-tensors.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index c10a1e4f4e04..614fd8d2a7de 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -647,7 +647,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       res.reserve(nExtraOperands);
       for (unsigned i = 0; i < nExtraOperands; ++i) {
         res.push_back(getOperation()->getOperand(numShapedOperands + i));
-        assert((res.back().getType().isSignlessIntOrIndexOrFloat() 
+        assert((res.back().getType().isSignlessIntOrIndexOrFloat()
                 || res.back().getType().isa<VectorType>()) &&
                "expected scalar or vector type");
       }

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index e47dafc9bf52..2e566c941894 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -29,6 +29,7 @@ using LinalgLoops = SmallVector<Operation *, 4>;
 struct TiledLinalgOp {
   LinalgOp op;
   SmallVector<Operation *, 8> loops;
+  SmallVector<Value, 4> tensorResults;
 };
 
 struct TiledAndFusedLinalgOps {
@@ -371,8 +372,9 @@ struct LinalgBaseTilingPattern : public RewritePattern {
                           LinalgTilingOptions options,
                           LinalgMarker marker = LinalgMarker(),
                           PatternBenefit benefit = 1);
-  LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override;
+  LogicalResult
+  matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
+                      SmallVectorImpl<Value> &tensorResults) const;
 
 private:
   /// LinalgTransformMarker handles special attribute manipulations.
@@ -390,9 +392,14 @@ struct LinalgTilingPattern : public LinalgBaseTilingPattern {
                                 marker, benefit) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (failed(LinalgBaseTilingPattern::matchAndRewrite(op, rewriter)))
+    SmallVector<Value, 4> tensorResults;
+    if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
+                                                            tensorResults)))
       return failure();
-    rewriter.eraseOp(op);
+    if (tensorResults.empty())
+      rewriter.eraseOp(op);
+    else
+      rewriter.replaceOp(op, tensorResults);
     return success();
   }
 };

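The split into matchAndRewriteBase plus the tensorResults out-parameter is what lets a derived pattern choose between erasing the original op (buffer case) and replacing it (tensor case). A condensed, self-contained sketch of that flow using the public tileLinalgOp entry point (the helper name is illustrative):

  // Sketch only: tile `op` and forward SSA results when tiling on tensors
  // produced them; otherwise erase the op as in the buffer-semantics path.
  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"

  using namespace mlir;

  static LogicalResult
  tileAndForwardResults(Operation *op, PatternRewriter &rewriter,
                        const linalg::LinalgTilingOptions &options) {
    auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
    if (!linalgOp)
      return failure();
    Optional<linalg::TiledLinalgOp> tiled =
        linalg::tileLinalgOp(rewriter, linalgOp, options);
    if (!tiled)
      return failure();
    if (tiled->tensorResults.empty())
      rewriter.eraseOp(op);                         // memref operands only
    else
      rewriter.replaceOp(op, tiled->tensorResults); // tensor results
    return success();
  }
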
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index b4e5be58bad7..ffcac5f48aa4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -95,17 +95,17 @@ Operation *fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
                          unsigned consumerIdx,
                          OperationFolder *folder = nullptr);
 
-/// Returns the linearized list of all view dimensions in a `linalgOp`. Applying
-/// the inverse, concatenated loopToOperandRangeMaps to this list allows the
-/// derivation of loop ranges for any linalgOp.
-SmallVector<Value, 8> getViewSizes(OpBuilder &builder, LinalgOp linalgOp);
+/// Returns the linearized list of all shape dimensions in a `linalgOp`.
+/// Applying the inverse, concatenated loopToOperandRangeMaps to this list
+/// allows the derivation of loop ranges for any linalgOp.
+SmallVector<Value, 8> getShape(OpBuilder &builder, LinalgOp linalgOp);
 template <typename ConcreteOpTy>
-SmallVector<Value, 8> getViewSizes(OpBuilder &builder, ConcreteOpTy linalgOp) {
-  return getViewSizes(builder, cast<linalg::LinalgOp>(linalgOp.getOperation()));
+SmallVector<Value, 8> getShape(OpBuilder &builder, ConcreteOpTy linalgOp) {
+  return getShape(builder, cast<linalg::LinalgOp>(linalgOp.getOperation()));
 }
 
 /// Returns the loop ranges of the `linalgOp`. Applies the inverse of the
-/// concatenated indexing maps to the result of `getViewSizes`. Returns None if
+/// concatenated indexing maps to the result of `getShape`. Returns None if
 /// the bounds computation fails.
 Optional<SmallVector<Value, 4>>
 getLoopRanges(OpBuilder &builder, LinalgOp linalgOp,
@@ -119,11 +119,6 @@ SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
                                        AffineMap map, ValueRange values,
                                        OperationFolder *folder = nullptr);
 
-/// Returns all the operands of `linalgOp` that are not views.
-/// Asserts that these operands are value types to allow transformations like
-/// tiling to just use the values when cloning `linalgOp`.
-SmallVector<Value, 4> getAssumedNonViewOperands(LinalgOp linalgOp);
-
 /// Apply the permutation defined by `permutation` to `inVec`.
 /// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
 /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector

diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 409f54384aca..747a83414a08 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -315,7 +315,7 @@ ParseResult parseDimAndSymbolList(OpAsmParser &parser,
 /// source memref. This is useful to fold a memref_cast into a consuming op
 /// and implement canonicalization patterns for ops in different dialects that
 /// may consume the results of memref_cast operations. Such foldable memref_cast
-/// operations are typically inserted as `view` and `subview` ops are
+/// operations are typically inserted as `view` and `subview` ops and are
 /// canonicalized, to preserve the type compatibility of their uses.
 ///
 /// Returns true when all conditions are met:

diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index bd45c8d667f9..abfc0001ed3e 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -199,6 +199,11 @@ static bool isDimOpValidSymbol(DimOp dimOp, Region *region) {
   if (isTopLevelValue(dimOp.memrefOrTensor()))
     return true;
 
+  // Conservatively handle remaining BlockArguments as non-valid symbols.
+  // E.g. scf.for iterArgs.
+  if (dimOp.memrefOrTensor().isa<BlockArgument>())
+    return false;
+
   // The dim op is also okay if its operand memref/tensor is a view/subview
   // whose corresponding size is a valid symbol.
   Optional<int64_t> index = dimOp.getConstantIndex();

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 7b16a9197f11..585b8810fdc2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -97,7 +97,7 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
     clonedViews.push_back(
         b.create<SubViewOp>(loc, view, offsets, sizes, strides));
   }
-  auto operands = getAssumedNonViewOperands(op);
+  auto operands = op.getAssumedNonShapedOperands();
   clonedViews.append(operands.begin(), operands.end());
 
   Operation *clonedOp = op.clone(b, loc, /*resultTypes*/ {}, clonedViews);

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index a9e7a8660230..9e96c8cdc691 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -508,10 +508,10 @@ Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
       linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
   auto maps = llvm::to_vector<8>(
       llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
-  SmallVector<Value, 8> sizes = getViewSizes(builder, linalgOp);
+  SmallVector<Value, 8> sizes = getShape(builder, linalgOp);
   AffineMap map = concatAffineMaps(maps);
   auto loopRanges = emitLoopRanges(scope.getBuilderRef(), scope.getLocation(),
-                                   map, getViewSizes(builder, linalgOp));
+                                   map, getShape(builder, linalgOp));
   SmallVector<Value, 4> allIvs;
   GenerateLoopNest<LoopTy>::doit(
       loopRanges, /*iterInitArgs*/ {}, linalgOp.iterator_types().getValue(),

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 3e8e0b74c145..f7becae6e328 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -56,18 +56,17 @@ using LoopIndexToRangeIndexMap = DenseMap<int, int>;
 // indices of newly created loops.
 static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
 makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
-                    ArrayRef<Value> allViewSizes,
-                    ArrayRef<Value> allTileSizes) {
+                    ValueRange allShapeSizes, ValueRange allTileSizes) {
   assert(allTileSizes.size() == map.getNumResults());
-  // Apply `map` to get view sizes in loop order.
-  auto viewSizes = applyMapToValues(b, loc, map, allViewSizes);
+  // Apply `map` to get shape sizes in loop order.
+  auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
   SmallVector<Value, 4> tileSizes(allTileSizes.begin(), allTileSizes.end());
 
   // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
   for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
     if (isZero(tileSizes[idx - zerosCount])) {
-      viewSizes.erase(viewSizes.begin() + idx - zerosCount);
+      shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
       tileSizes.erase(tileSizes.begin() + idx - zerosCount);
       ++zerosCount;
       continue;
@@ -78,10 +77,10 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
   // Create a new range with the applied tile sizes.
   SmallVector<Range, 4> res;
   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
-    res.push_back(Range{std_constant_index(0), viewSizes[idx], tileSizes[idx]});
+    res.push_back(
+        Range{std_constant_index(0), shapeSizes[idx], tileSizes[idx]});
   return std::make_tuple(res, loopIndexToRangeIndex);
 }
-
 namespace {
 
 // Helper visitor to determine whether an AffineExpr is tiled.
@@ -93,7 +92,7 @@ namespace {
 //   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
 //
 struct TileCheck : public AffineExprVisitor<TileCheck> {
-  TileCheck(ArrayRef<Value> tileSizes) : isTiled(false), tileSizes(tileSizes) {}
+  TileCheck(ValueRange tileSizes) : isTiled(false), tileSizes(tileSizes) {}
 
   void visitDimExpr(AffineDimExpr expr) {
     isTiled |= !isZero(tileSizes[expr.getPosition()]);
@@ -106,7 +105,7 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
              "nonpositive multiplying coefficient");
   }
   bool isTiled;
-  ArrayRef<Value> tileSizes;
+  ValueRange tileSizes;
 };
 
 } // namespace
@@ -165,7 +164,6 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
 static void transformIndexedGenericOpIndices(
     OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
     const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
-  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
   auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op.getOperation());
   if (!indexedGenericOp)
     return;
@@ -202,7 +200,7 @@ static void transformIndexedGenericOpIndices(
   }
 }
 
-static bool isTiled(AffineExpr expr, ArrayRef<Value> tileSizes) {
+static bool isTiled(AffineExpr expr, ValueRange tileSizes) {
   if (!expr)
     return false;
   TileCheck t(tileSizes);
@@ -210,9 +208,8 @@ static bool isTiled(AffineExpr expr, ArrayRef<Value> tileSizes) {
   return t.isTiled;
 }
 
-// Checks whether the view with index `viewIndex` within `linalgOp` varies with
-// respect to a non-zero `tileSize`.
-static bool isTiled(AffineMap map, ArrayRef<Value> tileSizes) {
+// Checks whether the `map` varies with respect to a non-zero `tileSize`.
+static bool isTiled(AffineMap map, ValueRange tileSizes) {
   if (!map)
     return false;
   for (unsigned r = 0; r < map.getNumResults(); ++r)
@@ -221,13 +218,11 @@ static bool isTiled(AffineMap map, ArrayRef<Value> tileSizes) {
   return false;
 }
 
-static SmallVector<Value, 4> makeTiledViews(OpBuilder &b, Location loc,
-                                            LinalgOp linalgOp, AffineMap map,
-                                            ArrayRef<Value> ivs,
-                                            ArrayRef<Value> tileSizes,
-                                            ArrayRef<Value> allViewSizes) {
-  assert(linalgOp.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
+static SmallVector<Value, 4>
+makeTiledShapes(OpBuilder &b, Location loc, LinalgOp linalgOp,
+                ValueRange operands, AffineMap map, ValueRange ivs,
+                ValueRange tileSizes, ValueRange allShapeSizes) {
+  assert(operands.size() == linalgOp.getShapedOperands().size());
   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                            llvm::make_range(tileSizes.begin(), tileSizes.end()),
                            [](Value v) { return !isZero(v); })) &&
@@ -235,37 +230,34 @@ static SmallVector<Value, 4> makeTiledViews(OpBuilder &b, Location loc,
 
   using namespace edsc::op;
 
-  auto viewSizes = applyMapToValues(b, loc, map, allViewSizes);
+  auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
   // Construct (potentially temporary) mins and maxes on which to apply maps
-  // that define tile subviews.
-  SmallVector<Value, 8> lbs, subViewSizes;
+  // that define tile subshapes.
+  SmallVector<Value, 8> lbs, subShapeSizes;
   for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
     bool isTiled = !isZero(tileSizes[idx]);
     lbs.push_back(isTiled ? ivs[idxIvs++] : (Value)std_constant_index(0));
     // Before composing, we need to make range a closed interval.
-    Value size = isTiled ? tileSizes[idx] : viewSizes[idx];
-    subViewSizes.push_back(size - std_constant_index(1));
+    Value size = isTiled ? tileSizes[idx] : shapeSizes[idx];
+    subShapeSizes.push_back(size - std_constant_index(1));
   }
 
   auto *op = linalgOp.getOperation();
 
   SmallVector<Value, 4> res;
   res.reserve(op->getNumOperands());
-  auto viewIteratorBegin = linalgOp.getInputsAndOutputBuffers().begin();
-  for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs();
-       ++viewIndex) {
-    Value view = *(viewIteratorBegin + viewIndex);
-    auto viewType = view.getType().cast<MemRefType>();
-    unsigned rank = viewType.getRank();
-    auto mapAttr = linalgOp.indexing_maps()[viewIndex];
-    auto map = mapAttr.cast<AffineMapAttr>().getValue();
-    // If the view is not tiled, we can use it as is.
+  for (auto en : llvm::enumerate(operands)) {
+    Value shapedOp = en.value();
+    ShapedType shapedType = shapedOp.getType().cast<ShapedType>();
+    unsigned rank = shapedType.getRank();
+    AffineMap map = linalgOp.getIndexingMap(en.index());
+    // If the shape is not tiled, we can use it as is.
     if (!isTiled(map, tileSizes)) {
-      res.push_back(view);
+      res.push_back(shapedOp);
       continue;
     }
 
-    // Construct a new subview for the tile.
+    // Construct a new subview / subtensor for the tile.
     SmallVector<Value, 4> offsets, sizes, strides;
     offsets.reserve(rank);
     sizes.reserve(rank);
@@ -273,27 +265,27 @@ static SmallVector<Value, 4> makeTiledViews(OpBuilder &b, Location loc,
     for (unsigned r = 0; r < rank; ++r) {
       if (!isTiled(map.getSubMap({r}), tileSizes)) {
         offsets.push_back(std_constant_index(0));
-        sizes.push_back(std_dim(view, r));
+        sizes.push_back(std_dim(shapedOp, r));
         strides.push_back(std_constant_index(1));
         continue;
       }
 
       // Tiling creates a new slice at the proper index, the slice step is 1
-      // (i.e. the slice view does not subsample, stepping occurs in the loop).
+      // (i.e. the op does not subsample, stepping occurs in the loop).
       auto m = map.getSubMap({r});
       auto offset = applyMapToValues(b, loc, m, lbs).front();
       offsets.push_back(offset);
-      auto closedIntSize = applyMapToValues(b, loc, m, subViewSizes).front();
+      auto closedIntSize = applyMapToValues(b, loc, m, subShapeSizes).front();
       // Resulting size needs to be made half open interval again.
       auto size = closedIntSize + std_constant_index(1);
 
-      // The size of the subview should be trimmed to avoid out-of-bounds
-      // accesses, unless we statically know the subview size divides the view
-      // size evenly.
-      int64_t viewSize = viewType.getDimSize(r);
+      // The size of the subview / subtensor should be trimmed to avoid
+      // out-of-bounds accesses, unless we statically know the subshape size
+      // divides the shape size evenly.
+      int64_t shapeSize = shapedType.getDimSize(r);
       auto sizeCst = size.getDefiningOp<ConstantIndexOp>();
-      if (ShapedType::isDynamic(viewSize) || !sizeCst ||
-          (viewSize % sizeCst.getValue()) != 0) {
+      if (ShapedType::isDynamic(shapeSize) || !sizeCst ||
+          (shapeSize % sizeCst.getValue()) != 0) {
         // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
         auto minMap = AffineMap::get(
             /*dimCount=*/3, /*symbolCount=*/0,
@@ -301,7 +293,7 @@ static SmallVector<Value, 4> makeTiledViews(OpBuilder &b, Location loc,
              getAffineDimExpr(/*position=*/1, b.getContext()) -
                  getAffineDimExpr(/*position=*/2, b.getContext())},
             b.getContext());
-        auto d = std_dim(view, r);
+        auto d = std_dim(shapedOp, r);
         size =
             affine_min(b.getIndexType(), minMap, ValueRange{size, d, offset});
       }
@@ -310,7 +302,12 @@ static SmallVector<Value, 4> makeTiledViews(OpBuilder &b, Location loc,
       strides.push_back(std_constant_index(1));
     }
 
-    res.push_back(b.create<SubViewOp>(loc, view, offsets, sizes, strides));
+    if (shapedType.isa<MemRefType>())
+      res.push_back(
+          b.create<SubViewOp>(loc, shapedOp, offsets, sizes, strides));
+    else
+      res.push_back(
+          b.create<SubTensorOp>(loc, shapedOp, offsets, sizes, strides));
   }
 
   return res;
@@ -318,7 +315,7 @@ static SmallVector<Value, 4> makeTiledViews(OpBuilder &b, Location loc,
 
 template <typename LoopTy>
 static Optional<TiledLinalgOp>
-tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
+tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
                  const LinalgTilingOptions &options) {
   auto nLoops = op.getNumLoops();
   // Initial tile sizes may be too big, only take the first nLoops.
@@ -335,20 +332,20 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
   }
 
   // 1. Build the tiled loop ranges.
-  auto allViewSizes = getViewSizes(b, op);
+  auto allShapeSizes = getShape(b, op);
   // The flattened loopToOperandRangesMaps is expected to be an invertible
   // permutation map (asserted in the inverse calculation).
   auto mapsRange = op.indexing_maps().getAsRange<AffineMapAttr>();
   auto maps = llvm::to_vector<8>(
       llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
-  auto viewSizesToLoopsMap = inversePermutation(concatAffineMaps(maps));
-  if (!viewSizesToLoopsMap)
+  auto shapeSizesToLoopsMap = inversePermutation(concatAffineMaps(maps));
+  if (!shapeSizesToLoopsMap)
     return llvm::None;
 
   SmallVector<Range, 4> loopRanges;
   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
   std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
-      b, op.getLoc(), viewSizesToLoopsMap, allViewSizes, tileSizes);
+      b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
   SmallVector<Attribute, 4> iteratorTypes;
   for (auto attr :
        enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
@@ -380,29 +377,77 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
 
   // 2. Create the tiled loops.
   LinalgOp res = op;
-  SmallVector<Value, 4> ivs;
+  SmallVector<Value, 4> ivs, tensorResults;
+  auto initTensors = op.getInitTensors();
   GenerateLoopNest<LoopTy>::doit(
-      loopRanges, /*iterArgInitValues*/ {}, iteratorTypes,
+      loopRanges, /*iterArgInitValues*/ initTensors, iteratorTypes,
       [&](ValueRange localIvs, ValueRange iterArgs) -> scf::ValueVector {
         auto &b = ScopedContext::getBuilderRef();
         auto loc = ScopedContext::getLocation();
         ivs.assign(localIvs.begin(), localIvs.end());
-        SmallVector<Value, 4> ivValues(ivs.begin(), ivs.end());
 
-        // If we have to apply a permutation to the tiled loop nest, we have to
-        // reorder the induction variables This permutation is the right one
-        // assuming that loopRanges have previously been permuted by
-        // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation
-        // of that one: (d0,d1,d2)->(d2,d0,d1)
+        // When an `interchangeVector` is present, it has been applied to the
+        // loop ranges and the iterator types. Apply its inverse to the
+        // resulting loop `ivs` to match the op definition.
+        SmallVector<Value, 4> interchangedIvs;
         if (!options.interchangeVector.empty())
-          ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues);
-
-        auto views = makeTiledViews(b, loc, op, viewSizesToLoopsMap, ivValues,
-                                    tileSizes, allViewSizes);
-        auto operands = getAssumedNonViewOperands(op);
-        views.append(operands.begin(), operands.end());
-        res = op.clone(b, loc, /*resultTypes*/ {}, views);
-        return scf::ValueVector{};
+          interchangedIvs = applyMapToValues(b, loc, invPermutationMap, ivs);
+        else
+          interchangedIvs.assign(ivs.begin(), ivs.end());
+
+        assert(op.getNumInitTensors() == iterArgs.size() &&
+               "num init tensors must match number of loop iter arguments");
+        // This uses knowledge about position of the init tensor in the list
+        // of operands.
+        auto operands = llvm::to_vector<4>(op.getShapedOperands());
+        std::copy(iterArgs.begin(), iterArgs.end(),
+                  operands.begin() + op.getNumInputsAndOutputBuffers());
+
+        SmallVector<Value, 4> tiledOperands =
+            makeTiledShapes(b, loc, op, operands, shapeSizesToLoopsMap,
+                            interchangedIvs, tileSizes, allShapeSizes);
+        auto nonShapedOperands = op.getAssumedNonShapedOperands();
+        tiledOperands.append(nonShapedOperands.begin(),
+                             nonShapedOperands.end());
+
+        // If LinalgOp has results, they must all be tied to init tensors.
+        // We enforce this to ensure all tiled ops have been rewritten in
+        // "init tensor" form. This ensures tiling has anchor values into which
+        // to subtensor / subtensor_insert. Otherwise tiling would need to
+        // allocate which is not acceptable.
+        // This would not be the case with a special terminator op that
+        // generates the whole tensor (instead of inserting a subtensor). But
+        // the generator-based abstraction has other issues.
+        assert(op.getNumInitTensors() == op.getOperation()->getNumResults() &&
+               "expected same number of init tensors as number of results");
+
+        // Handle init tensor operands.
+        // This uses knowledge about position of the init tensor in the list
+        // of operands.
+        // TODO: InterfaceAdaptor ?
+        SmallVector<Type, 4> resultTensorTypes;
+        for (auto idx : llvm::seq<unsigned>(0, op.getNumInitTensors()))
+          resultTensorTypes.push_back(
+              tiledOperands[op.getNumInputsAndOutputBuffers() + idx].getType());
+
+        res = op.clone(b, loc, resultTensorTypes, tiledOperands);
+
+        // Insert a subtensor_insert for each init subtensor.
+        for (unsigned idx = 0, e = op.getNumInitTensors(); idx != e; ++idx) {
+          Value initTensor =
+              tiledOperands[op.getNumInputsAndOutputBuffers() + idx];
+          if (auto subtensor = initTensor.getDefiningOp<SubTensorOp>()) {
+            tensorResults.push_back(b.create<SubTensorInsertOp>(
+                loc, subtensor.source().getType(),
+                res.getOperation()->getResult(idx), subtensor.source(),
+                subtensor.offsets(), subtensor.sizes(), subtensor.strides(),
+                subtensor.static_offsets(), subtensor.static_sizes(),
+                subtensor.static_strides()));
+          } else {
+            tensorResults.push_back(res.getOperation()->getResult(idx));
+          }
+        }
+        return scf::ValueVector(tensorResults.begin(), tensorResults.end());
       },
       options.distribution);
 
@@ -422,7 +467,16 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
       loops.push_back(nullptr);
     }
   }
-  return TiledLinalgOp{res, loops};
+
+  // 5. Get the tensor results from the outermost loop if available. Otherwise
+  // use the previously captured `tensorResults`.
+  Operation *outermostLoop = nullptr;
+  for (Operation *loop : loops)
+    if ((outermostLoop = loop))
+      break;
+
+  return TiledLinalgOp{
+      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
 }
 
 template <typename LoopTy>
@@ -432,7 +486,6 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(
   b.setInsertionPoint(op);
   ScopedContext scope(b, op.getLoc());
 
-  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
   // Enforce the convention that "tiling by zero" skips tiling a particular
   // dimension. This convention is significantly simpler to handle instead of
   // adjusting affine maps to account for missing dimensions.
@@ -513,7 +566,9 @@ mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
   scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
   scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
   ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
+  SubTensorOp::getCanonicalizationPatterns(patterns, ctx);
   SubViewOp::getCanonicalizationPatterns(patterns, ctx);
+  TensorCastOp::getCanonicalizationPatterns(patterns, ctx);
   ViewOp::getCanonicalizationPatterns(patterns, ctx);
   CanonicalizationPatternList<
 #define GET_OP_LIST

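For intuition on the boundary clamping in makeTiledShapes above: when the tile size does not statically divide the shape size, the emitted affine_min computes min(size, dim - offset), so the last tile along a dimension is trimmed instead of reading out of bounds. A standalone arithmetic sketch with hypothetical sizes:

  // Plain C++ illustration (not MLIR) of the clamp used for partial tiles:
  // with dimSize = 10 and tileSize = 4 the offsets are 0, 4, 8 and the
  // clamped sizes are 4, 4, 2.
  #include <algorithm>
  #include <cstdio>

  int main() {
    const long dimSize = 10, tileSize = 4;
    for (long offset = 0; offset < dimSize; offset += tileSize) {
      long size = std::min(tileSize, dimSize - offset);
      std::printf("offset %ld -> tile size %ld\n", offset, size);
    }
    return 0;
  }
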
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 56652cbcb527..71e3108b2b58 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -111,19 +111,34 @@ mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
     : RewritePattern(opName, {}, benefit, context), marker(marker),
       options(options) {}
 
-LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
+LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
+    Operation *op, PatternRewriter &rewriter,
+    SmallVectorImpl<Value> &tensorResults) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
   if (!linalgOp)
     return failure();
   if (failed(marker.checkAndNotify(rewriter, linalgOp)))
     return failure();
 
+  // If LinalgOp has results, they must all be tied to init tensors.
+  // We enforce this to ensure all tiled ops have been rewritten in
+  // "init tensor" form. This ensures tiling has anchor values into which to
+  // subtensor / subtensor_insert. Otherwise tiling would need to allocate which
+  // is not acceptable.
+  // This would not be the case with a special terminator op that generates the
+  // whole tensor (instead of inserting a subtensor). But the generator-based
+  // abstraction has other issues.
+  if (linalgOp.getNumInitTensors() != linalgOp.getOperation()->getNumResults())
+    return failure();
+
   Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
 
   if (!res)
     return failure();
 
+  // Return relevant information to derived pattern.
+  tensorResults = res->tensorResults;
+
   // New marker if specified.
   marker.replaceLinalgMarker(rewriter, res->op.getOperation());
   return success();

diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index f9ea9092d55d..3f29949ffe63 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -85,26 +85,6 @@ SmallVector<Value, 4> mlir::linalg::applyMapToValues(OpBuilder &b, Location loc,
   return res;
 }
 
-/// Returns all the operands of `linalgOp` that are not views.
-/// Asserts that these operands are value types to allow transformations like
-/// tiling to just use the values when cloning `linalgOp`.
-SmallVector<Value, 4>
-mlir::linalg::getAssumedNonViewOperands(LinalgOp linalgOp) {
-  auto *op = linalgOp.getOperation();
-  unsigned numViews = linalgOp.getNumInputsAndOutputs();
-  unsigned nOperands = op->getNumOperands() - numViews;
-  SmallVector<Value, 4> res;
-  res.reserve(nOperands);
-  for (unsigned i = 0; i < nOperands; ++i) {
-    res.push_back(op->getOperand(numViews + i));
-    auto t = res.back().getType();
-    (void)t;
-    assert((t.isSignlessIntOrIndexOrFloat() || t.isa<VectorType>()) &&
-           "expected scalar or vector type");
-  }
-  return res;
-}
-
 bool mlir::linalg::isParallelIteratorType(Attribute attr) {
   if (auto strAttr = attr.dyn_cast<StringAttr>()) {
     return strAttr.getValue() == getParallelIteratorTypeName();
@@ -147,12 +127,12 @@ namespace mlir {
 namespace linalg {
 
 /// Return the linearized list of all view dimensions in a linalgOp.
-SmallVector<Value, 8> getViewSizes(OpBuilder &builder, LinalgOp linalgOp) {
+SmallVector<Value, 8> getShape(OpBuilder &builder, LinalgOp linalgOp) {
   auto loc = linalgOp.getLoc();
   SmallVector<Value, 8> res;
   SmallVector<unsigned, 4> ranks;
-  for (auto v : linalgOp.getInputsAndOutputBuffers()) {
-    MemRefType t = v.getType().template cast<MemRefType>();
+  for (Value v : linalgOp.getShapedOperands()) {
+    ShapedType t = v.getType().template cast<ShapedType>();
     ranks.push_back(t.getRank());
     for (unsigned i = 0; i < t.getRank(); ++i)
       res.push_back(builder.create<DimOp>(loc, v, i));
@@ -181,7 +161,7 @@ SmallVector<Value, 8> getViewSizes(OpBuilder &builder, LinalgOp linalgOp) {
 
 Optional<SmallVector<Value, 4>>
 getLoopRanges(OpBuilder &builder, LinalgOp linalgOp, OperationFolder *folder) {
-  SmallVector<Value, 8> viewSizes = getViewSizes(builder, linalgOp);
+  SmallVector<Value, 8> viewSizes = getShape(builder, linalgOp);
   AffineMap invertedMap =
       inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps()));
   if (!invertedMap)

diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
new file mode 100644
index 000000000000..b899cb3e0049
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" -mlir-disable-threading=true | FileCheck %s
+
+// CHECK-LABEL: func @matmul_tensors(
+// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func @matmul_tensors(
+  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+//      CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xf32>) {
+//      CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
+//      CHECK:     %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xf32>) {
+//      CHECK:       %[[sTA:.*]] = subtensor %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK:       %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK:       %[[sTC:.*]] = subtensor %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK:       %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME:                                  init(%[[sTC]] : tensor<?x?xf32>)  -> tensor<?x?xf32>
+//      CHECK:       %[[TD:.*]] = subtensor_insert %[[sTD]] into %[[TC2]][{{.*}}]  : tensor<?x?xf32> into tensor<?x?xf32>
+//      CHECK:       scf.yield %[[TD]] : tensor<?x?xf32>
+//      CHECK:     scf.yield %[[TD2]] : tensor<?x?xf32>
+//      CHECK:   scf.yield %[[TD1]] : tensor<?x?xf32>
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+                     init(%arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+
+//      CHECK: return %[[TD0]] : tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}

