[Mlir-commits] [mlir] 37e0fdd - [mlir][Linalg] Add basic support for TileAndFuse on Linalg on tensors.

Nicolas Vasilache llvmlistbot at llvm.org
Mon Oct 26 10:19:56 PDT 2020


Author: Nicolas Vasilache
Date: 2020-10-26T17:19:08Z
New Revision: 37e0fdd072a95b51bcd0eb6b08d2762aa304e766

URL: https://github.com/llvm/llvm-project/commit/37e0fdd072a95b51bcd0eb6b08d2762aa304e766
DIFF: https://github.com/llvm/llvm-project/commit/37e0fdd072a95b51bcd0eb6b08d2762aa304e766.diff

LOG: [mlir][Linalg] Add basic support for TileAndFuse on Linalg on tensors.

This revision allows the fusion of the producer of input tensors into the consumer under a tiling transformation (which produces subtensors).
Many pieces are still missing (e.g. support for init_tensors, a better refactoring of the LinalgStructuredOp interface, merging implementations and reusing code), but this still allows getting started.

The greedy pass itself is just for testing purposes and will be extracted into a separate test pass.

Differential revision: https://reviews.llvm.org/D89491

Added: 
    mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index a35964c5eab4..389e5cc6d1fb 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -27,6 +27,8 @@
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "mlir/Support/LLVM.h"
 
+#include "llvm/ADT/STLExtras.h"
+
 namespace mlir {
 namespace linalg {
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index 845873ff83df..1e1546407a56 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -213,6 +213,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return {range.begin(), range.begin() + $_op.getNumInputs()};
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the range over the input operands that are of buffer type.
+      }],
+      /*retTy=*/"SmallVector<Value, 4>",
+      /*methodName=*/"getInputBuffers",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return llvm::to_vector<4>(llvm::make_filter_range(
+          getInputs(), [](Value in){ return in.getType().isa<MemRefType>(); }));
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the subset of input operands that are of ranked tensor type.
@@ -337,6 +350,18 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return this->getOperation()->getOperand(i);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the number of output buffers
+      }],
+      /*retTy=*/"unsigned",
+      /*methodName=*/"getNumOutputBuffers",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.getNumOutputs() - this->getOperation()->getNumResults();
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the number of inputs and outputs, irrespective of their buffer or
@@ -404,6 +429,49 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return getInitTensors()[i];
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return true if the shaped operand index `i` is the index of an init
+        tensor.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isIndexOfAnInitTensor",
+      /*args=*/(ins "unsigned":$i),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(i < $_op.getNumShapedOperands() && "overflowing shaped operand index");
+        return i >= $_op.getNumInputs() + getNumOutputBuffers();
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the relative init tensor index of the shaped operand index.
+      }],
+      /*retTy=*/"unsigned",
+      /*methodName=*/"getInitTensorIndexFromShapedIndex",
+      /*args=*/(ins "unsigned":$i),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(isIndexOfAnInitTensor(i) && "expected an init tensor index");
+        return i - $_op.getNumInputs() - getNumOutputBuffers();
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the index of the given init tensor value, or `None` if the value 
+        is not part of the init tensors.
+      }],
+      /*retTy=*/"llvm::Optional<unsigned>",
+      /*methodName=*/"getIndexOfInitTensor",
+      /*args=*/(ins "Value":$value),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto it = llvm::find(getInitTensors(), value);
+        if (it != getInitTensors().end())
+          return it - getInitTensors().begin();
+        return llvm::None;
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the number of inputs, output buffers and init tensors operands.
@@ -416,6 +484,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return getNumInputsAndOutputBuffers() + $_op.getNumInitTensors();
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the `i`-th shaped operand value, which can be an arbitrary input
+        tensor/buffer, init tensor or output buffer.
+      }],
+      /*retTy=*/"Value",
+      /*methodName=*/"getShapedOperand",
+      /*args=*/(ins "unsigned":$i),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(i < $_op.getNumShapedOperands());
+        return this->getOperation()->getOperand(i);
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the range over inputs, output buffers and init tensors.
@@ -473,19 +555,21 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return the position of buffer in inputs + outputs list
+        Return the position of the shaped operand in the operand list.
       }],
       /*retTy=*/"Optional<unsigned>",
-      /*methodName=*/"getIndexOfInputAndOutputBuffer",
+      /*methodName=*/"getIndexOfShapedOperand",
       /*args=*/(ins "Value":$value),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         Optional<unsigned> inputIndex = getIndexOfInput(value);
         if (inputIndex.hasValue()) return inputIndex.getValue();
         Optional<unsigned> outputIndex = getIndexOfOutputBuffer(value);
-        if (outputIndex.hasValue()) {
+        if (outputIndex.hasValue())
           return $_op.getNumInputs() + outputIndex.getValue();
-        }
+        Optional<unsigned> initTensorIndex = getIndexOfInitTensor(value);
+        if (initTensorIndex.hasValue())
+          return $_op.getNumInputs() + $_op.getNumOutputBuffers() + initTensorIndex.getValue();
         return llvm::None;
       }]
     >,
@@ -628,8 +712,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     InterfaceMethod<
       /*desc=*/[{
         Clone the current operation with the given location and operands. This
-        is used to abstract away the optional underlying region creation. This 
-        does not change the balance between input, output_buffer and 
+        is used to abstract away the optional underlying region creation. This
+        does not change the balance between input, output_buffer and
         init_tensors operands.
       }],
       /*retTy=*/"Operation *",

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 61367dd79548..9b343b3b04ab 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -32,6 +32,9 @@ class PatternRewriter;
 namespace linalg {
 class LinalgDependenceGraph;
 
+/// A struct containing the Linalg producer before and after fusion.
+/// When operating on tensors, `fusedProducer` may feed into a `tensor_cast` op
+/// before the consumer Linalg op, until enough canonicalizations have applied.
 struct FusionInfo {
   LinalgOp originalProducer;
   LinalgOp fusedProducer;
@@ -81,13 +84,25 @@ bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
 
 /// Fuses producer into consumer if the producer is structurally feasible and
 /// the fusion would not violate dependencies.
+/// Implements the fusion part of the "tileAndFuse on buffers"
+/// transformation and thus requires the `consumerIdx`-th operand of `consumer`
+/// to be a `subview` op (generally obtained by applying the tiling
+/// transformation).
 /// When non-null, the optional pointer `folder` is used to call into the
 /// `createAndFold` builder method. If `folder` is null, the regular `create`
 /// method is called.
-Optional<FusionInfo> fuseProducerOf(OpBuilder &b, LinalgOp consumer,
-                                    unsigned consumerIdx,
-                                    const LinalgDependenceGraph &graph,
-                                    OperationFolder *folder = nullptr);
+Optional<FusionInfo> fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
+                                          unsigned consumerIdx,
+                                          const LinalgDependenceGraph &graph,
+                                          OperationFolder *folder = nullptr);
+/// Tensor counterpart of `fuseProducerOfBuffer`.
+/// This implements the fusion part of the "tileAndFuse on tensors"
+/// transformation and thus requires the `consumerIdx`-th operand of `consumer`
+/// to be the result of a `subtensor` op (generally obtained by applying the
+/// tiling transformation).
+Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
+                                          unsigned consumerIdx,
+                                          OperationFolder *folder);
 
 /// Fuse linalg operation on tensors, with the producer of the operand at
 /// position `consumerIdx` of the consumer.

diff  --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
index bffd9bd1bd0c..58a8b3eddc3b 100644
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -147,13 +147,9 @@ LinalgDependenceGraph::getDependencesInto(
 }
 
 void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
-  assert(src.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  assert(dst.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
   for (auto srcView : src.getOutputBuffers()) { // W
     // RAW graph
-    for (auto dstView : dst.getInputs()) {   // R
+    for (auto dstView : dst.getInputBuffers()) { // R
       if (aliases.alias(srcView, dstView)) { // if alias, fill RAW
         addDependenceElem(DependenceType::RAW,
                           LinalgOpView{src.getOperation(), srcView},
@@ -169,9 +165,9 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
       }
     }
   }
-  for (auto srcView : src.getInputs()) { // R
+  for (auto srcView : src.getInputBuffers()) { // R
     // RAR graph
-    for (auto dstView : dst.getInputs()) {   // R
+    for (auto dstView : dst.getInputBuffers()) { // R
       if (aliases.alias(srcView, dstView)) { // if alias, fill RAR
         addDependenceElem(DependenceType::RAR,
                           LinalgOpView{src.getOperation(), srcView},

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 585b8810fdc2..8542c2afb086 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -41,97 +41,131 @@ using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;
 
 using llvm::dbgs;
 
-/// Implements a simple high-level fusion pass of linalg library operations.
+/// Implements a simple high-level fusion pass on linalg structured operations.
 ///
 /// In each block, linalg ops are processed in reverse textual order.
 /// Given a linalg op `O`, fusion occurs by:
-///   1. inspecting the linalg ops that write into the views read by `O`. This
-///      uses the SSA value of the views and a simple subview/slice analysis to
-///      determine producer-consumer dependences;
-///   2. greedily fuse the linalg ops that produce subview
+///   1. inspecting the linalg ops that write into the views read by `O`. There
+///      are 2 cases:
+///      a) buffer case: use the SSA value of the views and a simple alias
+///         analysis on subview ops to determine producer-consumer dependences;
+///      b) tensor case: use SSA use-def chains on subtensor ops;
+///   2. greedily fuse the linalg ops that produce the subview/subtensor.
 ///   3. inspect the fused ops and determine whether they have other remaining
 ///      LinalgOp uses. If not, then erase the original producing linalg op.
 ///
 /// More advanced use cases, analyses as well as profitability heuristics are
 /// left for future work.
 
+// Fill `offsets`, `sizes` and `strides` used to iterate over the shape indexed
+// by `permutationMap`.
+static void inferShapeComponents(AffineMap permutationMap,
+                                 ArrayRef<Range> loopRanges,
+                                 SmallVectorImpl<Value> &offsets,
+                                 SmallVectorImpl<Value> &sizes,
+                                 SmallVectorImpl<Value> &strides) {
+  assert(permutationMap.isProjectedPermutation() &&
+         "expected some subset of a permutation map");
+  SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults());
+  unsigned idx = 0;
+  for (AffineExpr e : permutationMap.getResults()) {
+    // loopToOperandRangesMaps are permutations-only, just swap indices.
+    unsigned loopPos = e.cast<AffineDimExpr>().getPosition();
+    shapeRanges[idx++] = loopRanges[loopPos];
+  }
+  // Construct a new subshape for the tile.
+  unsigned rank = shapeRanges.size();
+  offsets.reserve(rank);
+  sizes.reserve(rank);
+  strides.reserve(rank);
+  for (auto r : shapeRanges) {
+    offsets.push_back(r.offset);
+    sizes.push_back(r.size);
+    strides.push_back(r.stride);
+  }
+}
+
 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be
 // a subset of the original loop ranges of `op`.
 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps
 // to the `loopRanges` in order to obtain view ranges.
 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                                     ArrayRef<Range> loopRanges) {
-  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
-  auto maps = op.indexing_maps();
-  SmallVector<Value, 8> clonedViews;
-  clonedViews.reserve(op.getNumInputsAndOutputs());
-  // Iterate over the inputs and outputs in order.
+  SmallVector<Value, 8> clonedShapes;
+  clonedShapes.reserve(op.getNumShapedOperands());
+
+  // Iterate over the shape operands in order.
   // Extract the subranges from the linearized ranges.
-  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
-  for (auto en : llvm::enumerate(ios)) {
-    unsigned idx = en.index();
-    auto map = maps[idx].cast<AffineMapAttr>().getValue();
-    LLVM_DEBUG(dbgs() << "map: " << map << "\n");
-    Value view = en.value();
-    SmallVector<Range, 4> viewRanges(map.getNumResults());
-    for (auto en2 : llvm::enumerate(map.getResults())) {
-      unsigned d = en2.index();
-      // loopToOperandRangesMaps are permutations-only.
-      unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
-      viewRanges[d] = loopRanges[loopPos];
-      LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
-                        << "\t"
-                        << "loopPos: " << loopPos << "\t" << viewRanges[d]);
-    }
-    // Construct a new subview for the tile.
-    unsigned rank = viewRanges.size();
+  for (auto en : llvm::enumerate(op.getShapedOperands())) {
+    unsigned shapedOperandIdx = en.index();
+    AffineMap map = op.getIndexingMap(shapedOperandIdx);
+    LLVM_DEBUG(dbgs() << "shapedOperandIdx: " << shapedOperandIdx
+                      << " with indexingMap: " << map << "\n");
     SmallVector<Value, 4> offsets, sizes, strides;
-    offsets.reserve(rank);
-    sizes.reserve(rank);
-    strides.reserve(rank);
-    for (auto r : viewRanges) {
-      offsets.push_back(r.offset);
-      sizes.push_back(r.size);
-      strides.push_back(r.stride);
-    }
-    clonedViews.push_back(
-        b.create<SubViewOp>(loc, view, offsets, sizes, strides));
+    inferShapeComponents(map, loopRanges, offsets, sizes, strides);
+    Value shape = en.value();
+    Value sub = shape.getType().isa<MemRefType>()
+                    ? b.create<SubViewOp>(loc, shape, offsets, sizes, strides)
+                          .getResult()
+                    : b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
+                          .getResult();
+    clonedShapes.push_back(sub);
   }
+  // Append the other operands.
   auto operands = op.getAssumedNonShapedOperands();
-  clonedViews.append(operands.begin(), operands.end());
+  clonedShapes.append(operands.begin(), operands.end());
+
+  // Iterate over the results in order.
+  // Extract the subtensor type from the linearized range.
+  // Since we do not enforce any canonicalizations on the fly, this is always
+  // fully dynamic at construction time.
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.reserve(op.getOperation()->getNumResults());
+  for (RankedTensorType t : op.getOutputTensorTypes()) {
+    unsigned rank = t.getRank();
+    SmallVector<int64_t, 4> staticOffsetsVector(
+        rank, ShapedType::kDynamicStrideOrOffset);
+    SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
+    SmallVector<int64_t, 4> staticStridesVector(
+        rank, ShapedType::kDynamicStrideOrOffset);
+    resultTypes.push_back(SubTensorOp::inferResultType(
+        t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
+        staticStridesVector));
+  }
 
-  Operation *clonedOp = op.clone(b, loc, /*resultTypes*/ {}, clonedViews);
-  // When the producer is an IndexedGenercOp, we have to transform its block
+  Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes);
+  // When the producer is an IndexedGenericOp, we have to transform its block
   // IV arguments according to the tiling of the consumer, i.e. offset them by
   // the values computed in `loopRanges`.
   if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
     auto &block = indexedGenericOp.region().front();
-
     OpBuilder::InsertionGuard g(b);
     b.setInsertionPointToStart(&block);
     for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
       Value oldIndex = block.getArgument(i);
+      // TODO: replace by an affine_apply.
       AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
                                          loopRanges[i].offset);
       oldIndex.replaceAllUsesExcept(newIndex,
                                     SmallPtrSet<Operation *, 1>{newIndex});
     }
   }
+
   return clonedOp;
 }
 
-struct ViewDimension {
-  Value view;
+struct ShapeDimension {
+  Value shape;
   unsigned dimension;
 };
 
-// Given an `op`, returns the first (`view`, `dimension`) pair that identifies
+// Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
 // guarantees at least one such dimension is found. If multiple candidates exist
 // they must agree by construction (i.e. have the same size) and we just return
 // the first one.
-static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
-  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
+static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
+                                                unsigned loopDepth) {
   auto maps = op.indexing_maps();
   // Iterate over the inputs and outputs in order.
   // Extract the subranges from the linearized ranges.
@@ -139,43 +173,47 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
   for (auto en : llvm::enumerate(ios)) {
     unsigned idx = en.index();
     auto map = maps[idx].cast<AffineMapAttr>().getValue();
-    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
-    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
-    Value view = en.value();
-    SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
+    LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
+    LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange map: " << map << "\n");
+    Value shape = en.value();
+    SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
     for (auto en2 : llvm::enumerate(map.getResults())) {
       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
-        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
+        LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange loopDepth: "
+                          << loopDepth << "\n");
+        LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange shape: " << shape
                           << "\n");
-        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
-        return ViewDimension{view, static_cast<unsigned>(en2.index())};
+        return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
       }
     }
   }
-  llvm_unreachable("Expect to be able to extract a view defining loop range");
+  llvm_unreachable("Expect to be able to extract a shape defining loop range");
 }
 
+/// Fuses the producer of `producerIdx` into the loop immediately enclosing
+/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
+/// is needed just before the `consumer`.
+///
+/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
+/// 2 cases:
+///   1. Buffer case: `producerIdx` is the index of the buffer in
+///      `producer.getOutputBuffers()`.
+///   2. Tensor case: `producerIdx` is the index of the tensor in
+///      `producer.getResults()`.
 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
                      LinalgOp consumer, unsigned consumerIdx,
                      OperationFolder *folder = nullptr) {
-  assert(producer.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  assert(consumer.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-
-  auto subView = dyn_cast_or_null<SubViewOp>(
-      consumer.getBuffer(consumerIdx).getDefiningOp());
-  auto slice = dyn_cast_or_null<SliceOp>(
-      consumer.getBuffer(consumerIdx).getDefiningOp());
-  assert(subView || slice);
-  (void)subView;
-  (void)slice;
+  Operation *shapeProducingOp =
+      consumer.getShapedOperand(consumerIdx).getDefiningOp();
+  assert((isa<SubViewOp>(shapeProducingOp) ||
+          isa<SubTensorOp>(shapeProducingOp)) &&
+         "SubviewOp or SubTensorOp expected");
 
   // loopToOperandRangesMaps are permutations-only by construction:
   //   we can always identify a data dimension with a (at least one) loop
   //   dimension.
-  AffineMap producerMap =
-      producer.indexing_maps()[producerIdx].cast<AffineMapAttr>().getValue();
+  // TODO: extend this with range inference.
+  AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
   LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                     << ", producer map: " << producerMap << "\n");
 
@@ -190,20 +228,24 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
   for (auto en : llvm::enumerate(producerMap.getResults())) {
     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
     loopRanges[posInProducerLoop] =
-        subView.getOrCreateRanges(b, loc)[en.index()];
+        isa<SubViewOp>(shapeProducingOp)
+            ? cast<SubViewOp>(shapeProducingOp)
+                  .getOrCreateRanges(b, loc)[en.index()]
+            : cast<SubTensorOp>(shapeProducingOp)
+                  .getOrCreateRanges(b, loc)[en.index()];
   }
 
   // Iterate over all dimensions. For the dimensions not identified by the
-  // producer map for `producerIdx`, we need to explicitly compute the view that
-  // defines the loop ranges using the `producer`.
+  // producer map for `producerIdx`, we need to explicitly compute the shape
+  // that defines the loop ranges using the `producer`.
   for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
     if (loopRanges[i].offset)
       LLVM_DEBUG(llvm::dbgs()
                  << "existing LoopRange: " << loopRanges[i] << "\n");
     else {
-      auto viewDim = getViewDefiningLoopRange(producer, i);
+      auto shapeDim = getShapeDefiningLoopRange(producer, i);
       loopRanges[i] = Range{folded_std_constant_index(folder, 0),
-                            std_dim(viewDim.view, viewDim.dimension),
+                            std_dim(shapeDim.shape, shapeDim.dimension),
                             folded_std_constant_index(folder, 1)};
       LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
     }
@@ -269,7 +311,7 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
          "expected linalg op with buffer semantics");
   if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
     return false;
-  // Check for any fusion-preventing dependence to any view read/written that
+  // Check for any fusion-preventing dependence to any shape read/written that
   // would violate dependences.
   if (!graph.findCoveringDependences(producer, consumer).empty()) {
     LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
@@ -308,7 +350,7 @@ static bool isSameSubView(Value a, Value b) {
     return false;
   if (sva.static_strides() != svb.static_strides())
     return false;
-  /// Skip the "viewSource" operand.
+  /// Skip the "source" operand.
   for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx)
     if (sva.getOperand(idx) != svb.getOperand(idx))
       return false;
@@ -354,7 +396,7 @@ findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
   return {};
 }
 
-Optional<FusionInfo> mlir::linalg::fuseProducerOf(
+Optional<FusionInfo> mlir::linalg::fuseProducerOfBuffer(
     OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
     const LinalgDependenceGraph &graph, OperationFolder *folder) {
   Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
@@ -381,7 +423,7 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf(
   ScopedContext scope(b, consumer.getLoc());
   LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
   Optional<unsigned> producerIdxOpt =
-      producerOp.getIndexOfInputAndOutputBuffer(producerView);
+      producerOp.getIndexOfOutputBuffer(producerView);
   assert(producerIdxOpt.hasValue() && "incorrect operand index");
   unsigned producerIdx = producerIdxOpt.getValue();
 
@@ -390,10 +432,75 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf(
   return FusionInfo{producerOp, fusedProducer};
 }
 
+/// Walk back use-def chain through scf::For yields.
+/// Sets `producer` and `outputIndex` if it finds a producer LinalgOp.
+static void getProducerOfTensor(Value tensor, LinalgOp &producer,
+                                unsigned &outputIndex) {
+  if (!tensor.getType().isa<RankedTensorType>())
+    return;
+
+  while (true) {
+    if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
+      producer = linalgOp;
+      outputIndex = tensor.cast<OpResult>().getResultNumber();
+      return;
+    }
+    if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
+      tensor = subTensorOp.source();
+      continue;
+    }
+    if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
+      if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
+        tensor = forOp.getResult(blockArg.getArgNumber());
+        continue;
+      }
+    }
+    return;
+  }
+}
+
+Optional<FusionInfo>
+mlir::linalg::fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
+                                   unsigned consumerIdx,
+                                   OperationFolder *folder) {
+  Value inputTensor = consumer.getInput(consumerIdx);
+  LinalgOp producerOp;
+  unsigned producerIdx;
+  getProducerOfTensor(inputTensor, producerOp, producerIdx);
+
+  // Must be a subtensor to guarantee there are loops we can fuse into.
+  auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
+  if (!subTensor || !producerOp) {
+    LLVM_DEBUG(dbgs() << "\nNot fusable (not a subtensor)");
+    return {};
+  }
+
+  // Insert fused `producer` just before `consumer`.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(consumer.getOperation());
+  ScopedContext scope(b, consumer.getLoc());
+  LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
+  LinalgOp fusedProducer =
+      fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);
+
+  // Replace use.
+  // Canonicalizations are not guaranteed to have happened before constructing
+  // `fusedProducer`. In the tensor case this can result in temporary type
+  // mismatches. Insert a `tensor_cast` op to propagate the transformation
+  // invariant that types are compatible.
+  Value def = fusedProducer.getOperation()->getResult(producerIdx);
+  OpOperand &use = consumer.getOperation()->getOpOperand(consumerIdx);
+  Type consumerType = use.get().getType();
+  if (consumerType != def.getType())
+    def = b.create<TensorCastOp>(fusedProducer.getLoc(), consumerType, def);
+  use.set(def);
+  return FusionInfo{producerOp, fusedProducer};
+}
+
 /// Returns the positions of the loop in `op` that can be tiled based on the
 /// operations that are to be fused with it. For example, in a
 ///
-///   linalg. matmul ins(%a, %b : ...) outs(%c : ...)
+///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
 ///
 /// if the producer of %a needs to be fused with this op, only the `i` loop of
 /// the matmul can be tiled while fusing. If producer of %a, and %b are to be
@@ -475,7 +582,7 @@ static DenseSet<unsigned> collectTileAndFuseLoops(
   SmallVector<DenseSet<unsigned>, 1> commonTilableLoops;
   for (auto dependence : fusableDependences) {
     unsigned consumerIdx =
-        op.getIndexOfInputAndOutputBuffer(dependence.indexingView).getValue();
+        op.getIndexOfShapedOperand(dependence.indexingView).getValue();
     AffineMap consumerAccess = op.getIndexingMap(consumerIdx);
     // Previously asserted that the consumerAccess map is a projected
     // permutation, so all results are known to be AffineDimExprs. To remove
@@ -522,8 +629,8 @@ findAllFusableDependences(LinalgOp op,
     LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
     Value producerView = fusableDependence->dependentOpView.view;
     unsigned producerIdx =
-        producerOp.getIndexOfInputAndOutputBuffer(producerView).getValue();
-    AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
+        producerOp.getIndexOfOutputBuffer(producerView).getValue();
+    AffineMap producerMap = producerOp.getOutputIndexingMap(producerIdx);
     if (!producerMap.isProjectedPermutation()) {
       op.emitError("unhandled non permutation indexing map for fused view in "
                    "producer for operand at index ")
@@ -531,8 +638,7 @@ findAllFusableDependences(LinalgOp op,
       return llvm::None;
     }
     Value consumerView = fusableDependence->indexingView;
-    unsigned consumerIdx =
-        op.getIndexOfInputAndOutputBuffer(consumerView).getValue();
+    unsigned consumerIdx = op.getIndexOfShapedOperand(consumerView).getValue();
     if (!op.getIndexingMap(consumerIdx).isProjectedPermutation()) {
       op.emitError(
           "unhandled case where indexing map for fused view in the consumer is "
@@ -644,13 +750,11 @@ tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
   // Fuse the operands.
   for (auto producer : enumerate(fusableDependences)) {
     LinalgOp producerOp = cast<LinalgOp>(producer.value().dependentOpView.op);
-    unsigned producerIdx = producerOp
-                               .getIndexOfInputAndOutputBuffer(
-                                   producer.value().dependentOpView.view)
-                               .getValue();
-    unsigned consumerIdx =
-        op.getIndexOfInputAndOutputBuffer(producer.value().indexingView)
+    unsigned producerIdx =
+        producerOp.getIndexOfOutputBuffer(producer.value().dependentOpView.view)
             .getValue();
+    unsigned consumerIdx =
+        op.getIndexOfShapedOperand(producer.value().indexingView).getValue();
     LinalgOp fusedOp =
         fuse(rewriter, producerOp, producerIdx, ret.op, consumerIdx);
     ret.fusedProducers.push_back(fusedOp);
@@ -703,34 +807,52 @@ static void fuseLinalgOpsGreedily(FuncOp f) {
   // Save original Linalg ops, we only want to make a pass over those.
   SmallVector<Operation *, 8> linalgOps;
   f.walk([&](LinalgOp op) {
-    if (op.hasBufferSemantics())
+    // TODO: support multi-results.
+    if (op.getOperation()->getNumResults() <= 1)
       linalgOps.push_back(op);
   });
 
-  // TODO: LinalgDependenceGraph should be able to update itself.
-  // The current naive and expensive reconstruction of the graph should be
-  // removed.
+  // Tile and Fuse for tensor inputs (TODO: all tensor operands).
   for (auto *op : llvm::reverse(linalgOps)) {
-    for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers();
-         id < e; ++id) {
-      linalg::Aliases aliases;
-      linalg::LinalgDependenceGraph graph(aliases, linalgOps);
-      if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
-        auto *originalOp = info->originalProducer.getOperation();
-        eraseSet.insert(originalOp);
-        auto *originalOpInLinalgOpsVector =
-            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
-        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+    for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
+      if (en.value().getType().isa<MemRefType>()) {
+        // TODO: LinalgDependenceGraph should be able to update itself.
+        // The current naive and expensive reconstruction of the graph should be
+        // removed.
+        linalg::Aliases aliases;
+        linalg::LinalgDependenceGraph graph(aliases, linalgOps);
+        if (auto info =
+                fuseProducerOfBuffer(b, op, en.index(), graph, &folder)) {
+          auto *originalOp = info->originalProducer.getOperation();
+          eraseSet.insert(originalOp);
+          auto *originalOpInLinalgOpsVector =
+              std::find(linalgOps.begin(), linalgOps.end(), originalOp);
+          *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
+        }
+      } else {
+        assert(en.value().getType().isa<RankedTensorType>());
+        // Tile and Fuse tensor input (TODO: init_tensors too).
+        if (en.index() >= linalgOp.getNumInputs())
+          continue;
+        if (auto info = fuseProducerOfTensor(b, op, en.index(), &folder)) {
+          auto *originalOp = info->originalProducer.getOperation();
+          auto *originalOpInLinalgOpsVector =
+              std::find(linalgOps.begin(), linalgOps.end(), originalOp);
+          *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
+          // Don't mark for erasure in the tensor case, let DCE handle this.
+        }
       }
     }
   }
-  // The `fuseProducerOf` function performs structural checks and in particular
-  // that no covering read or write exist between the consumer and the producer.
-  // As a consequence, the only fusions that may occur preserve subsequent
-  // dependences and are guaranteed by construction to produce the whole view.
-  // We may thus erase the producer once it is fused.
+  // The `fuseProducerOfBuffer` function performs structural checks and in
+  // particular that no covering read or write exist between the consumer and
+  // the producer. As a consequence, the only fusions that may occur preserve
+  // subsequent dependences and are guaranteed by construction to produce the
+  // whole view. We may thus erase the producer once it is fused.
   for (auto *e : eraseSet)
     e->erase();
+
   LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
 }
 

diff  --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
new file mode 100644
index 000000000000..e43f261632e9
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -0,0 +1,84 @@
+// RUN: mlir-opt %s -linalg-fusion -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-fusion -canonicalize -cse -split-input-file | FileCheck %s --check-prefix=CANONICALIZED
+
+#map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+#map1 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+#map2 = affine_map<(d0)[s0] -> (3, -d0 + s0)>
+#map3 = affine_map<(d0, d1) -> (2, d0 - d1)>
+#map4 = affine_map<(d0, d1) -> (3, d0 - d1)>
+
+func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %t0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)  // Producer expected to be fused into the loop nest below.
+                     init(%arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+  // Consumer: a hand-tiled matmul (steps 2, 3, 4) that reads tiles of %t0.
+  %c4 = constant 4 : index
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c3 = constant 3 : index
+  %c1 = constant 1 : index
+  %0 = dim %t0, %c0 : tensor<?x?xf32>
+  %1 = dim %t0, %c1 : tensor<?x?xf32>
+  %2 = dim %arg1, %c1 : tensor<?x?xf32>
+  %3 = scf.for %arg3 = %c0 to %0 step %c2 iter_args(%arg4 = %arg2) -> (tensor<?x?xf32>) {  // Step 2 over dim 0 of %t0.
+    %4 = scf.for %arg5 = %c0 to %2 step %c3 iter_args(%arg6 = %arg4) -> (tensor<?x?xf32>) {  // Step 3 over dim 1 of %arg1.
+      %5 = scf.for %arg7 = %c0 to %1 step %c4 iter_args(%arg8 = %arg6) -> (tensor<?x?xf32>) {  // Step 4 over dim 1 of %t0 (dim 0 of %arg1).
+        %6 = subtensor %t0[%arg3, %arg7][%c2, 4][1, 1] : tensor<?x?xf32> to tensor<?x4xf32>  // Tile of the producer result: the read fusion replaces.
+        %7 = subtensor %arg1[%arg7, %arg5][4, %c3][1, 1] : tensor<?x?xf32> to tensor<4x?xf32>
+        %8 = subtensor %arg8[%arg3, %arg5][%c2, %c3][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+        %9 = linalg.matmul ins(%6, %7 : tensor<?x4xf32>, tensor<4x?xf32>) init(%8 : tensor<?x?xf32>) -> tensor<?x?xf32>
+        %10 = subtensor_insert %9 into %arg8[%arg3, %arg5] [%c2, %c3] [1, 1]  : tensor<?x?xf32> into tensor<?x?xf32>  // Write the result tile back into the loop-carried tensor.
+        scf.yield %10 : tensor<?x?xf32>
+      }
+      scf.yield %5 : tensor<?x?xf32>
+    }
+    scf.yield %4 : tensor<?x?xf32>
+  }
+  return %3 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @matmul_tensors(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]: tensor<?x?xf32>
+//  CHECK-SAME: %[[B:[0-9a-z]*]]: tensor<?x?xf32>
+//  CHECK-SAME: %[[C:[0-9a-z]*]]: tensor<?x?xf32>
+//       CHECK: %[[C0:.*]] = constant 0 : index
+//       CHECK: scf.for %[[I:[0-9a-z]*]]
+//  CHECK-NEXT:   scf.for %[[J:[0-9a-z]*]]
+//  CHECK-NEXT:     scf.for %[[K:[0-9a-z]*]]
+//
+// subtensors of the original program; the first one refers to the unfused matmul and becomes a dead SSA value.
+//       CHECK:     subtensor %{{.*}}[%[[I]], %[[K]]] {{.*}} : tensor<?x?xf32> to tensor<?x4xf32>
+//       CHECK:     %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] {{.*}} : tensor<?x?xf32> to tensor<4x?xf32>
+//       CHECK:     %[[stF:.*]] = subtensor %{{.*}}[%[[I]], %[[J]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
+//
+// subtensors of the producing matmul.
+//       CHECK:     %[[stA:.*]] = subtensor %[[A]][%[[I]], %[[C0]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
+//  CHECK-NEXT:     %[[stB2:.*]] = subtensor %[[B]][%[[C0]], %[[K]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
+//  CHECK-NEXT:     %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] {{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
+//  CHECK-NEXT:     %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<?x?xf32>, tensor<?x?xf32>) init(%[[stC]] : tensor<?x?xf32>)  -> tensor<?x?xf32>
+//  CHECK-NEXT:     %[[stD2:.*]] = tensor_cast %[[stD]] : tensor<?x?xf32> to tensor<?x4xf32>
+//  CHECK-NEXT:     %[[stG:.*]] = linalg.matmul ins(%[[stD2]], %[[stB1]] : tensor<?x4xf32>, tensor<4x?xf32>) init(%[[stF]] : tensor<?x?xf32>)  -> tensor<?x?xf32>
+//  CHECK-NEXT:     subtensor_insert %[[stG]]
+
+
+// CANONICALIZED-LABEL: func @matmul_tensors(
+//  CANONICALIZED-SAME: %[[A:[0-9a-z]*]]: tensor<?x?xf32>
+//  CANONICALIZED-SAME: %[[B:[0-9a-z]*]]: tensor<?x?xf32>
+//  CANONICALIZED-SAME: %[[C:[0-9a-z]*]]: tensor<?x?xf32>
+//       CANONICALIZED: %[[C0:.*]] = constant 0 : index
+//       CANONICALIZED: %[[C1:.*]] = constant 1 : index
+//       CANONICALIZED: scf.for %[[I:[0-9a-z]*]]
+//  CANONICALIZED-NEXT:   scf.for %[[J:[0-9a-z]*]]
+//  CANONICALIZED-NEXT:     scf.for %[[K:[0-9a-z]*]]
+//
+//       CANONICALIZED:     %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1]  : tensor<?x?xf32> to tensor<4x3xf32>
+//       CANONICALIZED:     %[[stF:.*]] = subtensor %{{.*}}[%[[I]], %[[J]]] [2, 3] [1, 1]  : tensor<?x?xf32> to tensor<2x3xf32>
+//
+// subtensors of the producing matmul.
+//       CANONICALIZED:     %[[dA1:.*]] = dim %[[A]], %[[C1]] : tensor<?x?xf32>
+//       CANONICALIZED:     %[[stA:.*]] = subtensor %[[A]][%[[I]], 0] [2, %[[dA1]]] [1, 1]  : tensor<?x?xf32> to tensor<2x?xf32>
+//  CANONICALIZED-NEXT:     %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dA1]], 4] [1, 1]  : tensor<?x?xf32> to tensor<?x4xf32>
+//  CANONICALIZED-NEXT:     %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [2, 4] [1, 1]  : tensor<?x?xf32> to tensor<2x4xf32>
+//  CANONICALIZED-NEXT:     %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor<?x4xf32>) init(%[[stC]] : tensor<2x4xf32>)  -> tensor<2x4xf32>
+//  CANONICALIZED-NEXT:     %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) init(%[[stF]] : tensor<2x3xf32>)  -> tensor<2x3xf32>
+//  CANONICALIZED-NEXT:     subtensor_insert %[[stG]]


        


More information about the Mlir-commits mailing list