[Mlir-commits] [mlir] e58597e - [mlir][linalg] Fuse producers with non-permutation indexing maps

Lei Zhang llvmlistbot at llvm.org
Wed Mar 24 15:19:33 PDT 2021


Author: Lei Zhang
Date: 2021-03-24T18:17:57-04:00
New Revision: e58597ee1c7dfe4fd2fdf6f5f0230f22b849c9be

URL: https://github.com/llvm/llvm-project/commit/e58597ee1c7dfe4fd2fdf6f5f0230f22b849c9be
DIFF: https://github.com/llvm/llvm-project/commit/e58597ee1c7dfe4fd2fdf6f5f0230f22b849c9be.diff

LOG: [mlir][linalg] Fuse producers with non-permutation indexing maps

Until now, Linalg fusion only allowed fusing producers whose operands
all use permutation indexing maps. That restriction makes it easier to
deduce the subtensor/subview, but it is unnecessary: tiling already has
more advanced logic to deduce subranges even when an operand's indexing
map is not a permutation, e.g., the input operand of convolution ops.

This patch reuses the tiling-side logic to deduce subranges for
fusion. This enables fusing a convolution with its consumer ops
when possible.

Along the way, we now generate proper affine.min ops to clamp
subview/subtensor sizes whenever we cannot prove statically that
they stay within bounds.

Differential Revision: https://reviews.llvm.org/D99014

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/test/Dialect/Linalg/fusion-pattern.mlir
    mlir/test/Dialect/Linalg/fusion-sequence.mlir
    mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
    mlir/test/Dialect/Linalg/fusion.mlir
    mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
    mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index ffa2811c37d2..ea7fd62baad4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -59,104 +59,6 @@ using llvm::dbgs;
 /// More advanced use cases, analyses as well as profitability heuristics are
 /// left for future work.
 
-// Fill `offset`, `sizes` and `strides` used to iterate over the shape indexed
-// by `permutationMap`.
-static void inferShapeComponents(AffineMap permutationMap,
-                                 ArrayRef<Range> loopRanges,
-                                 SmallVectorImpl<OpFoldResult> &offsets,
-                                 SmallVectorImpl<OpFoldResult> &sizes,
-                                 SmallVectorImpl<OpFoldResult> &strides) {
-  assert(permutationMap.isProjectedPermutation() &&
-         "expected some subset of a permutation map");
-  SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults());
-  unsigned idx = 0;
-  for (AffineExpr e : permutationMap.getResults()) {
-    // loopToOperandRangesMaps are permutations-only, just swap indices.
-    unsigned loopPos = e.cast<AffineDimExpr>().getPosition();
-    shapeRanges[idx++] = loopRanges[loopPos];
-  }
-  // Construct a new subshape for the tile.
-  unsigned rank = shapeRanges.size();
-  offsets.reserve(rank);
-  sizes.reserve(rank);
-  strides.reserve(rank);
-  for (auto r : shapeRanges) {
-    offsets.push_back(r.offset);
-    sizes.push_back(r.size);
-    strides.push_back(r.stride);
-  }
-}
-
-// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
-// a subset of the original loop ranges of `op`.
-// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
-// to the `loopRanges` in order to obtain view ranges.
-static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
-                                    ArrayRef<Range> loopRanges) {
-  SmallVector<Value, 8> clonedShapes;
-  clonedShapes.reserve(op.getNumShapedOperands());
-
-  // Iterate over the shape operands in order.
-  // Extract the subranges from the linearized ranges.
-  for (auto en : llvm::enumerate(op.getShapedOperands())) {
-    unsigned shapedOperandIdx = en.index();
-    AffineMap map = op.getIndexingMap(shapedOperandIdx);
-    LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx
-                            << " with indexingMap: " << map << "\n");
-    SmallVector<OpFoldResult, 4> offsets, sizes, strides;
-    inferShapeComponents(map, loopRanges, offsets, sizes, strides);
-    Value shape = en.value();
-    Value sub =
-        shape.getType().isa<MemRefType>()
-            ? b.create<memref::SubViewOp>(loc, shape, offsets, sizes, strides)
-                  .getResult()
-            : b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
-                  .getResult();
-    clonedShapes.push_back(sub);
-  }
-  // Append the other operands.
-  auto operands = op.getAssumedNonShapedOperands();
-  clonedShapes.append(operands.begin(), operands.end());
-
-  // Iterate over the results in order.
-  // Extract the subtensor type from the linearized range.
-  // Since we do not enforce any canonicalizations on the fly, this is always
-  // fully dynamic at construction time.
-  SmallVector<Type, 4> resultTypes;
-  resultTypes.reserve(op->getNumResults());
-  for (RankedTensorType t : op.getOutputTensorTypes()) {
-    unsigned rank = t.getRank();
-    SmallVector<int64_t, 4> staticOffsetsVector(
-        rank, ShapedType::kDynamicStrideOrOffset);
-    SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
-    SmallVector<int64_t, 4> staticStridesVector(
-        rank, ShapedType::kDynamicStrideOrOffset);
-    resultTypes.push_back(SubTensorOp::inferResultType(
-        t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
-        staticStridesVector));
-  }
-
-  Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes);
-  // When the producer is an IndexedGenericOp, we have to transform its block
-  // IV arguments according to the tiling of the consumer, i.e. offset them by
-  // the values computed in `loopRanges`.
-  if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
-    auto &block = indexedGenericOp.region().front();
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPointToStart(&block);
-    for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
-      Value oldIndex = block.getArgument(i);
-      // TODO: replace by an affine_apply.
-      AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
-                                         loopRanges[i].offset);
-      oldIndex.replaceAllUsesExcept(newIndex,
-                                    SmallPtrSet<Operation *, 1>{newIndex});
-    }
-  }
-
-  return clonedOp;
-}
-
 struct ShapeDimension {
   Value shape;
   unsigned dimension;
@@ -208,35 +110,86 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
   llvm_unreachable("Expect to be able to extract a shape defining loop range");
 }
 
-/// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
+/// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
 /// provides the loop range information for the fused loops. The rest are
 /// obtained from the producer itself, since they are not tiled + fused.
-static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
+static LinalgOp fuse(OpBuilder &builder, LinalgOp producer,
                      const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
-
-  unsigned nPar = producer.getNumParallelLoops();
-  unsigned nRed = producer.getNumReductionLoops();
-  unsigned nWin = producer.getNumWindowLoops();
-  SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
-  for (auto fusedLoops : fusedLoopsAndRanges)
-    loopRanges[fusedLoops.first] = fusedLoops.second;
-
-  // Iterate over all dimensions. For the dimensions not identified by the
-  // producer map for `producerIdx`, we need to explicitly compute the shape
-  // that defines the loop ranges using the `producer`.
-  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
-    if (loopRanges[i].offset)
-      LLVM_DEBUG(llvm::dbgs()
-                 << "existing LoopRange: " << loopRanges[i] << "\n");
-    else {
+  SmallVector<Value, 8> ivs, tileSizes, sizeBounds;
+  SmallVector<Range, 8> loopRanges;
+  auto zero = std_constant_index(0);
+  auto one = std_constant_index(1);
+  Location loc = producer.getLoc();
+
+  for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) {
+    auto it = fusedLoopsAndRanges.find(i);
+    if (it != fusedLoopsAndRanges.end()) {
+      ivs.push_back(it->second.offset);
+      tileSizes.push_back(it->second.size);
+      sizeBounds.push_back(nullptr);
+      loopRanges.push_back(it->second);
+      LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange "
+                              << loopRanges.back() << "\n");
+    } else {
       auto shapeDim = getShapeDefiningLoopRange(producer, i);
       Value dim = memref_dim(shapeDim.shape, shapeDim.dimension);
-      loopRanges[i] = Range{std_constant_index(0), dim, std_constant_index(1)};
-      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
+      tileSizes.push_back(zero);
+      sizeBounds.push_back(dim);
+      loopRanges.push_back(Range{zero, dim, one});
+      LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange "
+                              << loopRanges.back() << "\n");
+    }
+  }
+
+  SmallVector<Value, 8> clonedShapes;
+  clonedShapes.reserve(producer.getNumShapedOperands());
+
+  // Compute subranges for all tensor input/output operands.
+  auto tiledOperands = llvm::to_vector<4>(producer.getShapedOperands());
+  clonedShapes.append(makeTiledShapes(builder, loc, producer, tiledOperands,
+                                      ivs, tileSizes, sizeBounds));
+
+  // Append the other operands.
+  auto operands = producer.getAssumedNonShapedOperands();
+  clonedShapes.append(operands.begin(), operands.end());
+
+  // Iterate over the results in order.
+  // Extract the subtensor type from the linearized range.
+  // Since we do not enforce any canonicalizations on the fly, this is always
+  // fully dynamic at construction time.
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.reserve(producer->getNumResults());
+  for (RankedTensorType t : producer.getOutputTensorTypes()) {
+    unsigned rank = t.getRank();
+    SmallVector<int64_t, 4> staticOffsetsVector(
+        rank, ShapedType::kDynamicStrideOrOffset);
+    SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
+    SmallVector<int64_t, 4> staticStridesVector(
+        rank, ShapedType::kDynamicStrideOrOffset);
+    resultTypes.push_back(SubTensorOp::inferResultType(
+        t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
+        staticStridesVector));
+  }
+
+  Operation *clonedOp = producer.clone(builder, loc, resultTypes, clonedShapes);
+  // When the producer is an IndexedGenericOp, we have to transform its block
+  // IV arguments according to the tiling of the consumer, i.e. offset them by
+  // the values computed in `loopRanges`.
+  if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
+    auto &block = indexedGenericOp.region().front();
+    OpBuilder::InsertionGuard g(builder);
+    builder.setInsertionPointToStart(&block);
+    for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
+      Value oldIndex = block.getArgument(i);
+      // TODO: replace by an affine_apply.
+      AddIOp newIndex = builder.create<AddIOp>(indexedGenericOp.getLoc(),
+                                               oldIndex, loopRanges[i].offset);
+      oldIndex.replaceAllUsesExcept(newIndex,
+                                    SmallPtrSet<Operation *, 1>{newIndex});
     }
   }
 
-  return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges);
+  return clonedOp;
 }
 
 /// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 8fe3d8530c62..8c8b0cf1f7bf 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -27,6 +27,9 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/LoopUtils.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-utils"
 
 using namespace mlir;
 using namespace mlir::edsc;
@@ -447,11 +450,14 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
   // that define tile subshapes.
   SmallVector<Value, 8> lbs, subShapeSizes;
   for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
+    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
     bool isTiled = !isZero(tileSizes[idx]);
     lbs.push_back(isTiled ? ivs[idxIvs++] : (Value)std_constant_index(0));
     // Before composing, we need to make range a closed interval.
     Value size = isTiled ? tileSizes[idx] : sizeBounds[idx];
     subShapeSizes.push_back(size - std_constant_index(1));
+    LLVM_DEBUG(llvm::dbgs() << "lb: " << lbs.back() << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "size: " << subShapeSizes.back() << "\n");
   }
 
   MLIRContext *context = builder.getContext();
@@ -459,14 +465,18 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
   tiledShapes.reserve(tiledOperands.size());
   for (auto en : llvm::enumerate(tiledOperands)) {
     Value shapedOp = en.value();
+    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
     ShapedType shapedType = shapedOp.getType().cast<ShapedType>();
     unsigned rank = shapedType.getRank();
     AffineMap map = linalgOp.getIndexingMap(en.index());
     // If the shape is not tiled, we can use it as is.
     if (!isTiled(map, tileSizes)) {
       tiledShapes.push_back(shapedOp);
+      LLVM_DEBUG(llvm::dbgs()
+                 << ": not tiled: use shape: " << shapedType << "\n");
       continue;
     }
+    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
 
     // Construct a new subview / subtensor for the tile.
     SmallVector<OpFoldResult, 4> offsets, sizes, strides;
@@ -474,22 +484,28 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
     sizes.reserve(rank);
     strides.reserve(rank);
     for (unsigned r = 0; r < rank; ++r) {
+      LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for dim#" << r);
       if (!isTiled(map.getSubMap({r}), tileSizes)) {
         offsets.push_back(builder.getIndexAttr(0));
-        sizes.push_back(memref_dim(shapedOp, r).value);
+        Value dim = memref_dim(shapedOp, r).value;
+        sizes.push_back(dim);
         strides.push_back(builder.getIndexAttr(1));
+        LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
         continue;
       }
+      LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");
 
       // Tiling creates a new slice at the proper index, the slice step is 1
       // (i.e. the op does not subsample, stepping occurs in the loop).
       auto m = map.getSubMap({r});
+      LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: submap: " << map << "\n");
       auto offset = applyMapToValues(builder, loc, m, lbs).front();
       offsets.push_back(offset);
       auto closedIntSize =
           applyMapToValues(builder, loc, m, subShapeSizes).front();
       // Resulting size needs to be made half open interval again.
       auto size = closedIntSize + std_constant_index(1);
+      LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: raw size: " << size << "\n");
 
       // The size of the subview / subtensor should be trimmed to avoid
       // out-of-bounds accesses, unless we statically know the subshape size
@@ -498,6 +514,9 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
       auto sizeCst = size.getDefiningOp<ConstantIndexOp>();
       if (ShapedType::isDynamic(shapeSize) || !sizeCst ||
           (shapeSize % sizeCst.getValue()) != 0) {
+        LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: shapeSize=" << shapeSize
+                                << ", size: " << size
+                                << ": make sure in bound with affine.min\n");
         AffineExpr dim0, dim1, dim2;
         bindDims(context, dim0, dim1, dim2);
         // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
@@ -510,6 +529,9 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
       }
 
       sizes.push_back(size);
+      LLVM_DEBUG(llvm::dbgs()
+                 << "makeTiledShapes: new offset: " << offset << "\n");
+      LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: new size: " << size << "\n");
       strides.push_back(builder.getIndexAttr(1));
     }
 

diff  --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
index aefeeb5e3ada..b9ba18bbd05a 100644
--- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
@@ -16,6 +16,7 @@ module {
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
 //  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+//  CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
 //      CHECK: func @basic_fusion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
@@ -47,8 +48,10 @@ module {
 //      CHECK:     %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]]
 //      CHECK:     %[[SV3:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]]
 // CHECK-SAME:       [%[[TILE_M_2]], %[[TILE_N_2]]]
+//      CHECK:     %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_2]]]
+//      CHECK:     %[[TILE_N_3:.+]] = affine.min #[[MAP4]](%[[IV1]], %[[TILE_N]])[%[[N_2]]]
 //      CHECK:     %[[SV3_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]]
-// CHECK-SAME:       [%[[TILE_M]], %[[TILE_N]]]
+// CHECK-SAME:       [%[[TILE_M_3]], %[[TILE_N_3]]]
 //      CHECK:     linalg.fill(%[[SV3_2]], %[[CST]])
 // CHECK-SAME:       __internal_linalg_transform__ = "after_basic_fusion_producer"
 //      CHECK:     scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
@@ -86,6 +89,7 @@ module {
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)>
 //  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+//  CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
 //      CHECK: func @rhs_fusion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
@@ -112,10 +116,13 @@ module {
 //      CHECK:     %[[SV2:.+]] = memref.subview %[[ARG3]][0, %[[IV0]]]
 // CHECK-SAME:       [%[[M]], %[[TILE_N_2]]]
 //      CHECK:     %[[K_2:.+]] = memref.dim %[[ARG1]], %[[C0]]
+//      CHECK:     %[[N_3:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//      CHECK:     %[[TILE_N_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_N]])[%[[N_3]]]
 //      CHECK:     %[[SV3:.+]] = memref.subview %[[ARG1]][0, %[[IV0]]]
-// CHECK-SAME:       [%[[K_2]], %[[TILE_N]]]
+// CHECK-SAME:       [%[[K_2]], %[[TILE_N_3]]]
+//      CHECK:     %[[TILE_N_4:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_N]])[%[[N]]]
 //      CHECK:     %[[SV3_2:.+]] = memref.subview %[[ARG2]][0, %[[IV0]]]
-// CHECK-SAME:       [%[[K_2]], %[[TILE_N]]]
+// CHECK-SAME:       [%[[K]], %[[TILE_N_4]]]
 //      CHECK:     linalg.copy(%[[SV3]], %[[SV3_2]])
 // CHECK-SAME:       __internal_linalg_transform__ = "after_rhs_fusion_producer"
 //  CHECK-NOT:     linalg.fill
@@ -164,6 +171,7 @@ module {
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
 //  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
+//  CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
 //      CHECK: func @two_operand_fusion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
@@ -191,13 +199,17 @@ module {
 //      CHECK:     %[[N:.+]] = memref.dim %[[ARG3]], %[[C1]]
 //      CHECK:     %[[SV2:.+]] = memref.subview %[[ARG3]][%[[IV0]], 0]
 // CHECK-SAME:       [%[[TILE_M_2]], %[[N]]]
+//      CHECK:     %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_2]]]
 //      CHECK:     %[[SV2_2:.+]] = memref.subview %[[ARG3]][%[[IV0]], 0]
-// CHECK-SAME:       [%[[TILE_M]], %[[N]]]
+// CHECK-SAME:       [%[[TILE_M_3]], %[[N]]]
+//      CHECK:     %[[M_3:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//      CHECK:     %[[TILE_M_4:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_3]]]
 //      CHECK:     %[[K_2:.+]] = memref.dim %[[ARG0]], %[[C1]]
 //      CHECK:     %[[SV3:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME:       [%[[TILE_M]], %[[K_2]]]
+// CHECK-SAME:       [%[[TILE_M_4]], %[[K_2]]]
+//      CHECK:     %[[TILE_M_5:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M]]]
 //      CHECK:     %[[SV3_2:.+]] = memref.subview %[[ARG1]][%[[IV0]], 0]
-// CHECK-SAME:       [%[[TILE_M]], %[[K_2]]]
+// CHECK-SAME:       [%[[TILE_M_5]], %[[K]]]
 //      CHECK:     linalg.copy(%[[SV3]], %[[SV3_2]])
 // CHECK-SAME:       __internal_linalg_transform__ = "after_two_operand_fusion_producer"
 //      CHECK:     linalg.fill(%[[SV2_2]], %[[CST]])
@@ -271,23 +283,24 @@ module {
 //      CHECK:     %[[N:.+]] = memref.dim %[[ARG4]], %[[C1]]
 //      CHECK:     %[[SV2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0]
 // CHECK-SAME:       [%[[TILE_M_2]], %[[N]]]
-//      CHECK:     %[[K2_2:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//      CHECK:     %[[M_3:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//      CHECK:     %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_3]]]
 //      CHECK:     %[[K1:.+]] = memref.dim %[[ARG0]], %[[C1]]
 //      CHECK:     %[[SV3:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME:       [%[[TILE_M]], %[[K1]]]
-//      CHECK:     %[[SV4:.+]] = memref.subview %[[ARG1]][0, 0] [%[[K1]], %[[K2_2]]]
+// CHECK-SAME:       [%[[TILE_M_3]], %[[K1]]]
+//      CHECK:     %[[TILE_M_4:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M]]]
 //      CHECK:     %[[SV1_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], 0]
-// CHECK-SAME:       [%[[TILE_M]], %[[K2_2]]]
+// CHECK-SAME:       [%[[TILE_M_4]], %[[K2]]]
 //      CHECK:     linalg.matmul
 // CHECK-SAME:         __internal_linalg_transform__ = "after_lhs_fusion_producer"
-// CHECK-SAME:         ins(%[[SV3]], %[[SV4]]
-// CHECK-SAME:           : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+// CHECK-SAME:         ins(%[[SV3]], %[[ARG1]]
+// CHECK-SAME:           : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
 // CHECK-SAME:         outs(%[[SV1_2]] : memref<?x?xf32, #[[MAP1]]>)
-//  CHECK-DAG:     %[[N_2:.+]] = memref.dim %[[ARG3]], %[[C1]]
+//      CHECK:     %[[N_2:.+]] = memref.dim %[[ARG3]], %[[C1]]
 //      CHECK:     scf.parallel (%[[IV1:.+]]) =
 // CHECK-SAME:       (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) {
-// CHECK-NEXT:       scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
-//      CHECK:         %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K]]]
+// CHECK-NEXT:       scf.for %[[IV2:.+]] = %[[C0]] to %[[K2]] step %[[C16]] {
+//      CHECK:         %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K2]]]
 //      CHECK:         %[[SV6:.+]] = memref.subview %[[SV1]][0, %[[IV2]]]
 // CHECK-SAME:           [%[[TILE_M]], %[[TILE_K]]]
 //      CHECK:         %[[K_2:.+]] = memref.dim %[[ARG3]], %[[C0]]
@@ -348,10 +361,11 @@ module {
 //       CHECK:     %[[T6:.+]] = memref.subview %[[ARG2]][%[[ARG3]], %[[ARG4]]]
 //       CHECK:     %[[T8:.+]] = memref.subview %[[ARG0]][%[[ARG3]], 0]
 //       CHECK:     %[[T9:.+]] = memref.subview %[[ARG1]][0, %[[ARG4]]]
+//       CHECK:     %[[T10:.+]] = memref.subview %[[T2]][%[[ARG3]], %[[ARG4]]]
 //       CHECK:     linalg.matmul
 //  CHECK-SAME:       after_transpose_fusion_producer
 //  CHECK-SAME:       ins(%[[T8]], %[[T9]]
-//  CHECK-SAME:       outs(%[[T5]]
+//  CHECK-SAME:       outs(%[[T10]]
 //   CHECK-NOT:     linalg.matmul
 //       CHECK:     linalg.generic
 //  CHECK-SAME:       ins(%[[T5]], %[[T5]]

diff  --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir
index bec19b325a7b..981db2bfea7f 100644
--- a/mlir/test/Dialect/Linalg/fusion-sequence.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-sequence.mlir
@@ -36,18 +36,19 @@ module {
 //  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
 //       CHECK:   %[[TEMP:.+]] = memref.alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
 //       CHECK:   scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} {
-//   CHECK-DAG:     %[[SV_TEMP:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]]
+//       CHECK:     %[[SV_TEMP_1:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]]
 //   CHECK-DAG:     %[[SV_ARG2:.+]] = memref.subview %[[ARG2]][%[[IV1]]]
 //   CHECK-DAG:     %[[SV_ARG3:.+]] = memref.subview %[[ARG3]][%[[IV0]], %[[IV1]]]
 //   CHECK-DAG:     %[[SV_ARG0:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
 //   CHECK-DAG:     %[[SV_ARG1:.+]] = memref.subview %[[ARG1]][0, %[[IV1]]]
-//       CHECK:     linalg.fill(%[[SV_TEMP]], %{{.+}})
+//       CHECK:     %[[SV_TEMP_2:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]]
+//       CHECK:     linalg.fill(%[[SV_TEMP_2]], %{{.+}})
 //       CHECK:     linalg.matmul
 //  CHECK-SAME:       ins(%[[SV_ARG0]], %[[SV_ARG1]]
 //  CHECK-SAME:         : memref<?x?xf32, #[[MAP2]]>, memref<?x?xf32, #[[MAP2]]>)
-//  CHECK-SAME:       outs(%[[SV_TEMP]] : memref<?x?xf32, #[[MAP2]]>)
+//  CHECK-SAME:       outs(%[[SV_TEMP_2]] : memref<?x?xf32, #[[MAP2]]>)
 //       CHECK:     linalg.generic
-//  CHECK-SAME:       ins(%[[SV_TEMP]], %[[SV_ARG2]]
+//  CHECK-SAME:       ins(%[[SV_TEMP_1]], %[[SV_ARG2]]
 //  CHECK-SAME:         : memref<?x?xf32, #[[MAP2]]>, memref<?xf32, #[[MAP3]]>)
 //  CHECK-SAME:       outs(%[[SV_ARG3]] : memref<?x?xf32, #[[MAP2]]>)
 //       CHECK:     scf.yield
@@ -83,6 +84,8 @@ module {
 
 //   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
 //   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
+
 //       CHECK: func @sequence_of_matmul
 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
 //  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
@@ -100,37 +103,40 @@ module {
 //       CHECK:   scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]])
 //  CHECK-SAME:     step (%[[C16]]) {
 //       CHECK:     %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
-//       CHECK:     %[[SV_ALLOC2:.+]] = memref.subview %[[ALLOC2]][%[[IV0]], 0]
+//       CHECK:     %[[SV_ALLOC3:.+]] = memref.subview %[[ALLOC2]][%[[IV0]], 0]
 //  CHECK-SAME:       [%[[TILE_M]], %[[N2]]]
 //       CHECK:     %[[M_2:.+]] = memref.dim %[[ARG4]], %[[C0]]
 //       CHECK:     %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
 //       CHECK:     %[[N3:.+]] = memref.dim %[[ARG4]], %[[C1]]
 //       CHECK:     %[[SV_ARG4:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0]
 //  CHECK-SAME:       [%[[TILE_M_2]], %[[N3]]]
+//       CHECK:     %[[TILE_M_3:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M_2]]]
 //       CHECK:     %[[SV_ARG4_2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0]
-//  CHECK-SAME:       [%[[TILE_M]], %[[N3]]]
+//  CHECK-SAME:       [%[[TILE_M_3]], %[[N3]]]
+//       CHECK:     %[[TILE_M_4:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M]]]
 //       CHECK:     %[[SV_ALLOC1:.+]] = memref.subview %[[ALLOC1]][%[[IV0]], 0]
-//  CHECK-SAME:       [%[[TILE_M]], %[[N1]]]
-//       CHECK:     %[[SV_ARG2:.+]] = memref.subview %[[ARG2]][0, 0] [%[[N1]], %[[N2]]]
+//  CHECK-SAME:       [%[[TILE_M_4]], %[[N1]]]
+//       CHECK:     %[[SV_ALLOC2:.+]] = memref.subview %[[ALLOC2]][%[[IV0]], 0]
+//  CHECK-SAME:       [%[[TILE_M_4]], %[[N2]]]
 //       CHECK:     %[[N0:.+]] = memref.dim %[[ARG0]], %[[C1]]
 //       CHECK:     %[[SV_ARG0:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
-//  CHECK-SAME:       [%[[TILE_M:.+]], %[[N0]]]
-//       CHECK:     %[[SV_ARG1:.+]] = memref.subview %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
+//  CHECK-SAME:       [%[[TILE_M_4]], %[[N0]]]
 //       CHECK:     linalg.fill(%[[SV_ALLOC1]], %{{.+}})
-//       CHECK:     linalg.matmul ins(%[[SV_ARG0]], %[[SV_ARG1]]
-//  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+//       CHECK:     linalg.matmul ins(%[[SV_ARG0]], %[[ARG1]]
+//  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
 //  CHECK-SAME:        outs(%[[SV_ALLOC1]] : memref<?x?xf32, #[[MAP1]]>)
 //       CHECK:     linalg.fill(%[[SV_ALLOC2]], %{{.+}})
-//       CHECK:     linalg.matmul ins(%[[SV_ALLOC1]], %[[SV_ARG2]]
-//  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+//       CHECK:     linalg.matmul ins(%[[SV_ALLOC1]], %[[ARG2]]
+//  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
 //  CHECK-SAME:        outs(%[[SV_ALLOC2]] : memref<?x?xf32, #[[MAP1]]>)
 //       CHECK:     linalg.fill(%[[SV_ARG4_2]], %{{.+}})
-//       CHECK:     linalg.matmul ins(%[[SV_ALLOC2]], %[[ARG3]]
+//       CHECK:     linalg.matmul ins(%[[SV_ALLOC3]], %[[ARG3]]
 //  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
 //  CHECK-SAME:        outs(%[[SV_ARG4]] : memref<?x?xf32, #[[MAP1]]>)
 //       CHECK:     scf.yield
 //       CHECK:   }
 
+
 // -----
 
 module {
@@ -189,8 +195,8 @@ module {
 module {
   func @tensor_matmul_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
                              %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>,
-			     %arg4: tensor<?x?xf32>, %arg5: tensor<?x?xf32>,
-			     %arg6: tensor<?x?xf32>) -> tensor<?x?xf32> {
+           %arg4: tensor<?x?xf32>, %arg5: tensor<?x?xf32>,
+           %arg6: tensor<?x?xf32>) -> tensor<?x?xf32> {
     %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N0] * [N0, N1]
     %1 = linalg.matmul ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
@@ -200,7 +206,12 @@ module {
     return %2 : tensor<?x?xf32>
   }
 }
-// CHECK-LABEL: func @tensor_matmul_fusion(
+
+//       CHECK: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+//       CHECK: #[[MAP1:.+]] = affine_map<(d0, d1) -> (16, d0 - d1)>
+//       CHECK: #[[MAP2:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
+
+//       CHECK: func @tensor_matmul_fusion(
 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 //  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 //  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -210,36 +221,39 @@ module {
 //  CHECK-SAME:   %[[ARG6:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
 //   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
 //   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//       CHECK:   %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 //       CHECK:   %[[R0:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] =
 //  CHECK-SAME:     iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
-//       CHECK:       %[[N3:.+]] = memref.dim %[[ARG8]], %[[C1]]
-//       CHECK:       %[[STARG6:.+]] = subtensor %[[ARG8]][%[[IV0]], 0]
-//  CHECK-SAME:         [%{{[a-zA-Z0-9_]+}}, %[[N3]]]
-//       CHECK:       %[[N2:.+]] = memref.dim %[[ARG3]], %[[C1]]
-//       CHECK:       %[[N1:.+]] = memref.dim %[[ARG1]], %[[C1]]
-//       CHECK:       %[[STARG3:.+]] = subtensor %[[ARG3]][0, 0]
-//  CHECK-SAME:         [%[[N1]], %[[N2]]]
-//       CHECK:       %[[STARG4:.+]] = subtensor %[[ARG4]][%[[IV0]], 0]
-//  CHECK-SAME:         [%{{[a-zA-Z0-9_]+}}, %[[N2]]]
-//       CHECK:       %[[N0:.+]] = memref.dim %[[ARG0]], %[[C1]]
-//       CHECK:       %[[STARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0]
-//  CHECK-SAME:         [%{{[a-zA-Z0-9_]+}}, %[[N0]]]
-//       CHECK:       %[[STARG1:.+]] = subtensor %[[ARG1]][0, 0]
-//  CHECK-SAME:         [%[[N0]], %[[N1]]]
-//       CHECK:       %[[STARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0]
-//  CHECK-SAME:         [%{{[a-zA-Z0-9_]+}}, %[[N1]]]
-//       CHECK:       %[[T0:.+]] = linalg.matmul
-//  CHECK-SAME:         ins(%[[STARG0]], %[[STARG1]]
-//  CHECK-SAME:         ) outs(%[[STARG2]] : tensor<?x?xf32>)
-//       CHECK:       %[[T1:.+]] = linalg.matmul
-//  CHECK-SAME:         ins(%[[T0]], %[[STARG3]]
-//  CHECK-SAME:         ) outs(%[[STARG4]] : tensor<?x?xf32>)
-//       CHECK:       %[[T2:.+]] = linalg.matmul
-//  CHECK-SAME:         ins(%[[T1]], %[[ARG5]]
-//  CHECK-SAME:         ) outs(%[[STARG6]] : tensor<?x?xf32>)
-//       CHECK:       %[[R1:.+]] = subtensor_insert %[[T2]]
-//  CHECK-SAME:         into %[[ARG8]][%[[IV0]], 0]
-//       CHECK:       scf.yield %[[R1]]
-//       CHECK:     }
-//       CHECK:     return %[[R0]]
+//       CHECK:     %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
+//       CHECK:     %[[M_1:.+]] = memref.dim %[[ARG8]], %[[C0]]
+//       CHECK:     %[[TILE_M_1:.+]] = affine.min #[[MAP1]](%[[M_1]], %[[IV0]])
+//       CHECK:     %[[N3:.+]] = memref.dim %[[ARG8]], %[[C1]]
+//       CHECK:     %[[STARG6:.+]] = subtensor %[[ARG8]][%[[IV0]], 0]
+//  CHECK-SAME:       [%[[TILE_M_1]], %[[N3]]]
+//       CHECK:     %[[M_2:.+]] = memref.dim %[[ARG4]], %[[C0]]
+//       CHECK:     %[[TILE_M_2:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M_2]]]
+//       CHECK:     %[[N2:.+]] = memref.dim %[[ARG4]], %[[C1]]
+//       CHECK:     %[[STARG4:.+]] = subtensor %[[ARG4]][%[[IV0]], 0]
+//  CHECK-SAME:       [%[[TILE_M_2]], %[[N2]]]
+//       CHECK:     %[[TILE_M_3:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M]]]
+//       CHECK:     %[[N0:.+]] = memref.dim %[[ARG0]], %[[C1]]
+//       CHECK:     %[[STARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0]
+//  CHECK-SAME:       [%[[TILE_M_3]], %[[N0]]]
+//       CHECK:     %[[M_3:.+]] = memref.dim %[[ARG2]], %[[C0]]
+//       CHECK:     %[[TILE_M_4:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M_3]]]
+//       CHECK:     %[[N1:.+]] = memref.dim %[[ARG2]], %[[C1]]
+//       CHECK:     %[[STARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0]
+//  CHECK-SAME:       [%[[TILE_M_4]], %[[N1]]]
+//       CHECK:     %[[T0:.+]] = linalg.matmul
+//  CHECK-SAME:       ins(%[[STARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>
+//  CHECK-SAME:       ) outs(%[[STARG2]] : tensor<?x?xf32>)
+//       CHECK:     %[[T1:.+]] = linalg.matmul
+//  CHECK-SAME:       ins(%[[T0]], %[[ARG3]] : tensor<?x?xf32>, tensor<?x?xf32>
+//  CHECK-SAME:       ) outs(%[[STARG4]] : tensor<?x?xf32>)
+//       CHECK:     %[[T2:.+]] = linalg.matmul
+//  CHECK-SAME:       ins(%[[T1]], %[[ARG5]] : tensor<?x?xf32>, tensor<?x?xf32>
+//  CHECK-SAME:       ) outs(%[[STARG6]] : tensor<?x?xf32>)
+//       CHECK:     %[[R1:.+]] = subtensor_insert %[[T2]]
+//  CHECK-SAME:       into %[[ARG8]][%[[IV0]], 0] [%[[TILE_M_1]], %[[N3]]]
+//       CHECK:     scf.yield %[[R1]] : tensor<?x?xf32>
 //       CHECK:   }

diff  --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
index bd0d61c8580e..7f1131815d7c 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -17,12 +17,15 @@ module {
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
 //  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
 //  CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (64, d0 - d1)>
+//  CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
+
 //      CHECK: func @matmul_fusion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+
 //  CHECK-DAG:   %[[C0:.+]] = constant 0 : index
 //  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
 //  CHECK-DAG:   %[[C32:.+]] = constant 32 : index
@@ -38,18 +41,20 @@ module {
 //      CHECK:     %[[N3:.+]] = memref.dim %[[ARG6]], %[[C1]]
 //      CHECK:     %[[ST_ARG6:.+]] = subtensor %[[ARG6]][%[[IV0]], 0]
 // CHECK-SAME:       [%[[TILE_M_2]], %[[N3]]]
-//      CHECK:     %[[N2:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//      CHECK:     %[[TILE_M_3:.+]] = affine.min #[[MAP5]](%[[IV0]], %[[TILE_M]])[%[[M]]]
 //      CHECK:     %[[N1:.+]] = memref.dim %[[ARG0]], %[[C1]]
 //      CHECK:     %[[ST_ARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME:       [%[[TILE_M]], %[[N1]]]
-//      CHECK:     %[[ST_ARG1:.+]] = subtensor %[[ARG1]][0, 0]
-// CHECK-SAME:       [%[[N1]], %[[N2]]]
+// CHECK-SAME:       [%[[TILE_M_3]], %[[N1]]]
+//      CHECK:     %[[M_3:.+]] = memref.dim %[[ARG2]], %[[C0]]
+//      CHECK:     %[[TILE_M_4:.+]] = affine.min #[[MAP5]](%[[IV0]], %[[TILE_M]])[%[[M_3]]]
+//      CHECK:     %[[N2_2:.+]] = memref.dim %[[ARG2]], %[[C1]]
 //      CHECK:     %[[ST_ARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0]
-// CHECK-SAME:       [%[[TILE_M]], %[[N2]]]
+// CHECK-SAME:       [%[[TILE_M_4]], %[[N2_2]]]
 //      CHECK:     %[[LHS:.+]] = linalg.matmul
 // CHECK-SAME:       __internal_linalg_transform__ = "after_lhs_fusion_producer"
-// CHECK-SAME:       ins(%[[ST_ARG0]], %[[ST_ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME:       ins(%[[ST_ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
 // CHECK-SAME:       outs(%[[ST_ARG2]] : tensor<?x?xf32>)
+//      CHECK:     %[[N2:.+]] = memref.dim %[[ARG1]], %[[C1]]
 //      CHECK:     %[[N3_2:.+]] = memref.dim %[[ARG3]], %[[C1]]
 //      CHECK:     %[[YIELD0:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] =
 // CHECK-SAME:       %[[C0]] to %[[N3_2]] step %[[C64]]
@@ -59,7 +64,7 @@ module {
 // CHECK-SAME:         iter_args(%[[ARG10:.+]] = %[[ARG8]]) -> (tensor<?x?xf32>) {
 //      CHECK:         %[[TILE_N2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2]]]
 //      CHECK:         %[[ST_LHS:.+]] = subtensor %[[LHS]][0, %[[IV2]]]
-// CHECK-SAME:           [%[[TILE_M]], %[[TILE_N2]]]
+// CHECK-SAME:           [%[[TILE_M_3]], %[[TILE_N2]]]
 //      CHECK:         %[[N2_3:.+]] = memref.dim %[[ARG3]], %[[C0]]
 //      CHECK:         %[[TILE_N2_2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2_3]]]
 //      CHECK:         %[[TILE_N3:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N3_2]]]

diff  --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir
index 14fb995a6cfc..8bbecc091c45 100644
--- a/mlir/test/Dialect/Linalg/fusion.mlir
+++ b/mlir/test/Dialect/Linalg/fusion.mlir
@@ -252,25 +252,36 @@ func @f5(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
   }
   return %E : memref<?x?xf32, offset: 0, strides: [?, ?]>
 }
-// CHECK-LABEL: func @f5
-// CHECK:  (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
+//     CHECK: #[[BOUND_2_MAP:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+//     CHECK: #[[BOUND_ID_MAP:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
+//     CHECK: #[[BOUND_4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+//     CHECK: func @f5
+// CHECK-SAME:  (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
 // CHECK-DAG:  %[[C0:.*]] = constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = constant 1 : index
-// CHECK-DAG:  %[[B_1:.*]] = memref.dim %[[B]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
-// CHECK-DAG:  %[[D_0:.*]] = memref.dim %[[D]], %[[C0:.*]] : memref<?x?xf32, #[[$strided2D]]>
-// CHECK-DAG:  %[[D_1:.*]] = memref.dim %[[D]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG:  %[[B_1:.*]] = memref.dim %[[B]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG:  %[[C_0:.*]] = memref.dim %[[C]], %[[C0]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG:  %[[D_0:.*]] = memref.dim %[[D]], %[[C0]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG:  %[[D_1:.*]] = memref.dim %[[D]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
 // CHECK-DAG:  %[[B_00:.*]] = memref.subview %[[B]][0, 0]{{.*}}
 //     CHECK:  scf.for %[[I:.*]] = %{{.*}} to %[[D_0]] step %{{.*}} {
-// CHECK-DAG:    %[[A_I0:.*]] = memref.subview %[[A]][%[[I]], 0]
-// CHECK-DAG:    %[[C_I0:.*]] = memref.subview %[[C]][%[[I]], 0]
+//     CHECK:    %[[BOUND_2_C0:.+]] = affine.min #[[BOUND_2_MAP]](%[[I]])[%[[C_0]]]
+//     CHECK:    %[[C_I0:.*]] = memref.subview %[[C]][%[[I]], 0] [%[[BOUND_2_C0]]
+//     CHECK:    %[[BOUND_2_D0:.+]] = affine.min #[[BOUND_2_MAP]](%[[I]])[%[[D_0]]]
+//     CHECK:    %[[A_I0:.*]] = memref.subview %[[A]][%[[I]], 0]
+//               Note that %[[BOUND_ID_C0]] is essentially %[[BOUND_2_C0]].
+//     CHECK:    %[[BOUND_ID_C0:.+]] = affine.min #[[BOUND_ID_MAP]](%[[I]], %[[BOUND_2_C0]])[%[[C_0]]]
+//     CHECK:    %[[C_I0_OUT:.*]] = memref.subview %[[C]][%[[I]], 0] [%[[BOUND_ID_C0]]
 //     CHECK:    scf.for %[[J:.*]] = %{{.*}} to %[[B_1]] step %{{.*}} {
 //     CHECK:      %[[E_IJ:.*]] = memref.subview %[[E]][%[[I]], %[[J]]]
 //     CHECK:      scf.for %[[K:.*]] = %{{.*}} to %[[D_1]] step %{{.*}} {
-// CHECK-DAG:        %[[D_IK:.*]] = memref.subview %[[D]][%[[I]], %[[K]]]
-// CHECK-DAG:        %[[B_0K:.*]] = memref.subview %[[B]][0, %[[K]]]
-// CHECK-DAG:        %[[B_KJ:.*]] = memref.subview %[[B]][%[[K]], %[[J]]]
-//     CHECK:        linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0]]
-//     CHECK:        linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK]]
+//     CHECK:        %[[D_IK:.*]] = memref.subview %[[D]][%[[I]], %[[K]]] [2, 4]
+//     CHECK:        %[[B_KJ:.*]] = memref.subview %[[B]][%[[K]], %[[J]]]
+//     CHECK:        %[[B_0K:.*]] = memref.subview %[[B]][0, %[[K]]]
+//     CHECK:        %[[BOUND_4_D1:.+]] = affine.min #[[BOUND_4_MAP]](%[[K]])[%[[D_1]]]
+//     CHECK:        %[[D_IK_OUT:.+]] = memref.subview %[[D]][%[[I]], %[[K]]] [%[[BOUND_2_D0]], %[[BOUND_4_D1]]]
+//     CHECK:        linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0_OUT]]
+//     CHECK:        linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK_OUT]]
 //     CHECK:        linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]]
 
 // -----

diff  --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index b742f8148dac..0c1aa43f4412 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -1,11 +1,5 @@
 // RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
 
-#map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)>
-#map1 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
-#map2 = affine_map<(d0)[s0] -> (3, -d0 + s0)>
-#map3 = affine_map<(d0, d1) -> (2, d0 - d1)>
-#map4 = affine_map<(d0, d1) -> (3, d0 - d1)>
-
 func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %t0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
                      outs(%arg2: tensor<?x?xf32>)
@@ -36,23 +30,250 @@ func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tens
   return %3 : tensor<?x?xf32>
 }
 
-// CHECK-LABEL: func @matmul_tensors(
+//       CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+//       CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+
+//       CHECK: func @matmul_tensors(
 //  CHECK-SAME: %[[A:[0-9a-z]*]]: tensor<?x?xf32>
 //  CHECK-SAME: %[[B:[0-9a-z]*]]: tensor<?x?xf32>
 //  CHECK-SAME: %[[C:[0-9a-z]*]]: tensor<?x?xf32>
+
 //   CHECK-DAG: %[[C0:.*]] = constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = constant 1 : index
+//   CHECK-DAG: %[[dA0:.*]] = memref.dim %[[A]], %[[C0]] : tensor<?x?xf32>
 //   CHECK-DAG: %[[dA1:.*]] = memref.dim %[[A]], %[[C1]] : tensor<?x?xf32>
+//   CHECK-DAG: %[[dB0:.*]] = memref.dim %[[B]], %[[C0]] : tensor<?x?xf32>
+//   CHECK-DAG: %[[dB1:.*]] = memref.dim %[[B]], %[[C1]] : tensor<?x?xf32>
+//   CHECK-DAG: %[[dC0:.*]] = memref.dim %[[C]], %[[C0]] : tensor<?x?xf32>
+//   CHECK-DAG: %[[dC1:.*]] = memref.dim %[[C]], %[[C1]] : tensor<?x?xf32>
 //       CHECK: scf.for %[[I:[0-9a-z]*]]
-//       CHECK:     %[[stA:.*]] = subtensor %[[A]][%[[I]], 0] [2, %[[dA1]]] [1, 1]  : tensor<?x?xf32> to tensor<2x?xf32>
+//       CHECK:   %[[sizeA0:.*]] = affine.min #[[BOUND2_MAP]](%[[I]])[%[[dA0]]]
+//       CHECK:   %[[stA:.*]] = subtensor %[[A]][%[[I]], 0] [%[[sizeA0]], %[[dA1]]] [1, 1]  : tensor<?x?xf32> to tensor<?x?xf32>
+//       CHECK:   %[[sizeC0:.*]] = affine.min #[[BOUND2_MAP]](%[[I]])[%[[dC0]]]
 //  CHECK-NEXT:   scf.for %[[J:[0-9a-z]*]]
 //  CHECK-NEXT:     scf.for %[[K:[0-9a-z]*]] {{.*}} iter_args(%[[RES:[0-9a-z]*]]
 //   CHECK-DAG:       %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1]  : tensor<?x?xf32> to tensor<4x3xf32>
 //   CHECK-DAG:       %[[stF:.*]] = subtensor %[[RES]][%[[I]], %[[J]]] [2, 3] [1, 1]  : tensor<?x?xf32> to tensor<2x3xf32>
 //
 // subtensors of the producing matmul.
-//   CHECK-DAG:       %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dA1]], 4] [1, 1]  : tensor<?x?xf32> to tensor<?x4xf32>
-//   CHECK-DAG:       %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [2, 4] [1, 1]  : tensor<?x?xf32> to tensor<2x4xf32>
-//       CHECK:       %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor<?x4xf32>) outs(%[[stC]] : tensor<2x4xf32>)  -> tensor<2x4xf32>
-//  CHECK-NEXT:       %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>)  -> tensor<2x3xf32>
+//       CHECK:       %[[sizeB1:.*]] = affine.min #[[BOUND4_MAP]](%[[K]])[%[[dB1]]]
+//       CHECK:       %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dB0]], %[[sizeB1]]] [1, 1]  : tensor<?x?xf32> to tensor<?x?xf32>
+//       CHECK:       %[[sizeC1:.*]] = affine.min #[[BOUND4_MAP]](%[[K]])[%[[dC1]]]
+//       CHECK:       %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [%[[sizeC0]], %[[sizeC1]]] [1, 1]  : tensor<?x?xf32> to tensor<?x?xf32>
+//       CHECK:       %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[stC]] : tensor<?x?xf32>)  -> tensor<?x?xf32>
+//       CHECK:       %[[CAST:.*]] = tensor.cast %[[stD]] : tensor<?x?xf32> to tensor<?x4xf32>
+//  CHECK-NEXT:       %[[stG:.*]] = linalg.matmul ins(%[[CAST]], %[[stB1]] : tensor<?x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>)  -> tensor<2x3xf32>
 //  CHECK-NEXT:       subtensor_insert %[[stG]] into %[[RES]][%[[I]], %[[J]]]
+
+// -----
+
+func @conv_tensors_static(%input: tensor<1x225x225x32xf32>, %filter: tensor<3x3x3x32xf32>, %elementwise: tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32> {
+  %c112 = constant 112 : index
+  %c32 = constant 32 : index
+  %c16 = constant 16 : index
+  %c8 = constant 8 : index
+  %c4 = constant 4 : index
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+
+  %init = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
+  %fill = linalg.fill(%init, %cst) : tensor<1x112x112x32xf32>, f32 -> tensor<1x112x112x32xf32>
+
+  %conv = linalg.conv_2d_input_nhwc_filter_hwcf
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+    ins(%input, %filter : tensor<1x225x225x32xf32>, tensor<3x3x3x32xf32>)
+    outs(%fill : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
+
+  %for0 = scf.for %iv0 = %c0 to %c112 step %c8 iter_args(%arg0 = %fill) -> tensor<1x112x112x32xf32> {
+    %for1 = scf.for %iv1 = %c0 to %c112 step %c16 iter_args(%arg1 = %arg0) -> tensor<1x112x112x32xf32> {
+      %for2 = scf.for %iv2 = %c0 to %c32 step %c4 iter_args(%arg2 = %arg1) -> tensor<1x112x112x32xf32> {
+        %0 = subtensor %conv[0, %iv0, %iv1, %iv2][1, 8, 16, 4][1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
+        %1 = subtensor %elementwise[0, %iv0, %iv1, %iv2][1, 8, 16, 4][1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
+        %2 = subtensor %arg2[0, %iv0, %iv1, %iv2][1, 8, 16, 4][1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
+        %add = linalg.generic
+          {
+            indexing_maps = [
+              affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+              affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+              affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+            iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+          }
+          ins(%0, %1 : tensor<1x8x16x4xf32>, tensor<1x8x16x4xf32>) outs(%2 : tensor<1x8x16x4xf32>) {
+        ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+          %result = addf %arg3, %arg4 : f32
+          linalg.yield %result : f32
+        } -> tensor<1x8x16x4xf32>
+
+        %insert = subtensor_insert %add into %arg2[0, %iv0, %iv1, %iv2] [1, 8, 16, 4] [1, 1, 1, 1]  : tensor<1x8x16x4xf32> into tensor<1x112x112x32xf32>
+        scf.yield %insert : tensor<1x112x112x32xf32>
+      }
+      scf.yield %for2 : tensor<1x112x112x32xf32>
+    }
+    scf.yield %for1 : tensor<1x112x112x32xf32>
+  }
+  return %for0 : tensor<1x112x112x32xf32>
+}
+
+//      CHECK: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 2)>
+//      CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+//      CHECK: func @conv_tensors_static
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x225x225x32xf32>, %[[FILTER:.+]]: tensor<3x3x3x32xf32>, %[[ELEM:.+]]: tensor<1x112x112x32xf32>)
+
+//      CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
+// CHECK-NEXT: %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor<1x112x112x32xf32>, f32 -> tensor<1x112x112x32xf32>
+
+// CHECK-NEXT: scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG0:.+]] = %[[FILL]])
+// CHECK-NEXT:   %[[OFFSET_H:.+]] = affine.apply #[[MAP0]](%[[IV0]])
+// CHECK-NEXT:   scf.for %[[IV1:.+]] = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG1:.+]] = %[[ARG0]])
+// CHECK-NEXT:     %[[OFFSET_W:.+]] = affine.apply #[[MAP0]](%[[IV1]])
+// CHECK-NEXT:     %[[ST_INPUT:.+]] = subtensor %[[INPUT]][0, %[[OFFSET_H]], %[[OFFSET_W]], 0] [1, 17, 33, 32] [1, 1, 1, 1] : tensor<1x225x225x32xf32> to tensor<1x17x33x32xf32>
+// CHECK-NEXT:     scf.for %[[IV2:.+]] = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG2:.+]] = %[[ARG1]])
+// CHECK-NEXT:       %[[ST_ELEM:.+]] = subtensor %[[ELEM]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
+// CHECK-NEXT:       %[[ST_ARG2:.+]] = subtensor %[[ARG2]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
+// CHECK-NEXT:       %[[ST_FILTER:.+]] = subtensor %[[FILTER]][0, 0, 0, %[[IV2]]] [3, 3, 3, 4] [1, 1, 1, 1] : tensor<3x3x3x32xf32> to tensor<3x3x3x4xf32>
+// CHECK-NEXT:       %[[ST_FILL:.+]] = subtensor %[[FILL]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
+// CHECK-NEXT:       %[[ST_CONV:.+]] = linalg.conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME:         ins(%[[ST_INPUT]], %[[ST_FILTER]] : tensor<1x17x33x32xf32>, tensor<3x3x3x4xf32>)
+// CHECK-SAME:         outs(%[[ST_FILL]] : tensor<1x8x16x4xf32>)
+// CHECK-NEXT:       %[[ADD:.+]] = linalg.generic
+// CHECK-SAME:         ins(%[[ST_CONV]], %[[ST_ELEM]] : tensor<1x8x16x4xf32>, tensor<1x8x16x4xf32>)
+// CHECK-SAME:         outs(%[[ST_ARG2]] : tensor<1x8x16x4xf32>)
+//      CHECK:       subtensor_insert %[[ADD]] into %[[ARG2]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4]
+
+// -----
+
+#bound4_map = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+#bound8_map = affine_map<(d0)[s0] -> (8, -d0 + s0)>
+#bound16_map = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+
+func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %elementwise: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %cst = constant 0.0 : f32
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %c8 = constant 8 : index
+  %c16 = constant 16 : index
+
+  %n = memref.dim %elementwise, %c0 : tensor<?x?x?x?xf32>
+  %oh = memref.dim %elementwise, %c1 : tensor<?x?x?x?xf32>
+  %ow = memref.dim %elementwise, %c2 : tensor<?x?x?x?xf32>
+  %oc = memref.dim %elementwise, %c3 : tensor<?x?x?x?xf32>
+
+  %init = linalg.init_tensor [%n, %oh, %ow, %oc] : tensor<?x?x?x?xf32>
+  %fill = linalg.fill(%init, %cst) : tensor<?x?x?x?xf32>, f32 -> tensor<?x?x?x?xf32>
+
+  %conv = linalg.conv_2d_input_nhwc_filter_hwcf
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+    ins(%input, %filter : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs(%fill : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+  %for0 = scf.for %iv0 = %c0 to %oh step %c8 iter_args(%arg0 = %fill) -> tensor<?x?x?x?xf32> {
+    %for1 = scf.for %iv1 = %c0 to %ow step %c16 iter_args(%arg1 = %arg0) -> tensor<?x?x?x?xf32> {
+      %for2 = scf.for %iv2 = %c0 to %oc step %c4 iter_args(%arg2 = %arg1) -> tensor<?x?x?x?xf32> {
+        %for3 = scf.for %iv3 = %c0 to %oc step %c2 iter_args(%arg3 = %arg2) -> tensor<?x?x?x?xf32> {
+          %n_size = affine.min #bound8_map(%iv0)[%n]
+          %oh_size = affine.min #bound16_map(%iv1)[%oh]
+          %ow_size = affine.min #bound4_map(%iv2)[%ow]
+          %oc_size = affine.min #bound4_map(%iv2)[%oc]
+          %0 = subtensor %conv[%iv0, %iv1, %iv2, %iv3][%n_size, %oh_size, %ow_size, %oc_size][1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+          %1 = subtensor %elementwise[%iv0, %iv1, %iv2, %iv3][%n_size, %oh_size, %ow_size, %oc_size][1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+          %2 = subtensor %arg3[%iv0, %iv1, %iv2, %iv3][%n_size, %oh_size, %ow_size, %oc_size][1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+          %add = linalg.generic
+            {
+              indexing_maps = [
+                affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+              iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+            }
+            ins(%0, %1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%2 : tensor<?x?x?x?xf32>) {
+          ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
+            %result = addf %arg4, %arg5 : f32
+            linalg.yield %result : f32
+          } -> tensor<?x?x?x?xf32>
+
+          %insert = subtensor_insert %add into %arg3[%iv0, %iv1, %iv2, %iv3] [%n_size, %oh_size, %ow_size, %oc_size] [1, 1, 1, 1]  : tensor<?x?x?x?xf32> into tensor<?x?x?x?xf32>
+          scf.yield %insert : tensor<?x?x?x?xf32>
+        }
+        scf.yield %for3 : tensor<?x?x?x?xf32>
+      }
+      scf.yield %for2 : tensor<?x?x?x?xf32>
+    }
+    scf.yield %for1 : tensor<?x?x?x?xf32>
+  }
+  return %for0 : tensor<?x?x?x?xf32>
+}
+// NOTE(review): the loop nest above does not line up with the dims it offsets (e.g. %iv0 runs to %oh but offsets dim N; %oc_size is computed from %iv2 while dim 3 is offset by %iv3); the checks below encode this exact shape.
+// Expected output for @conv_tensors_dynamic, kept in the same split chunk as its input.
+
+// CHECK: #[[BOUND8_MAP:.+]] = affine_map<(d0)[s0] -> (8, -d0 + s0)>
+// CHECK: #[[BOUND_MAP:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
+// CHECK: #[[BOUND16_MAP:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+// CHECK: #[[X2_MAP:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK: #[[INPUT_BOUND:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * 2 + s0 - 2, d1 * -2 + s1)>
+// CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+
+//      CHECK: func @conv_tensors_dynamic
+// CHECK-SAME: (%[[INPUT:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>, %[[FILTER:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>, %[[ELEM:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>)
+
+//  CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = constant 2 : index
+//  CHECK-DAG:   %[[C3:.+]] = constant 3 : index
+
+//  CHECK-DAG:   %[[ELEM_N:.+]] = memref.dim %[[ELEM]], %[[C0]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[ELEM_OH:.+]] = memref.dim %[[ELEM]], %[[C1]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[ELEM_OW:.+]] = memref.dim %[[ELEM]], %[[C2]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[ELEM_OC:.+]] = memref.dim %[[ELEM]], %[[C3]] : tensor<?x?x?x?xf32>
+
+//      CHECK:   %[[INIT:.+]] = linalg.init_tensor [%[[ELEM_N]], %[[ELEM_OH]], %[[ELEM_OW]], %[[ELEM_OC]]] : tensor<?x?x?x?xf32>
+//      CHECK:   %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor<?x?x?x?xf32>, f32 -> tensor<?x?x?x?xf32>
+
+//  CHECK-DAG:   %[[FILTER_H:.+]] = memref.dim %[[FILTER]], %[[C0]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[FILTER_W:.+]] = memref.dim %[[FILTER]], %[[C1]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[INPUT_N:.+]] = memref.dim %[[INPUT]], %[[C0]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[INPUT_H:.+]] = memref.dim %[[INPUT]], %[[C1]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[INPUT_W:.+]] = memref.dim %[[INPUT]], %[[C2]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[INPUT_C:.+]] = memref.dim %[[INPUT]], %[[C3]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[FILTER_IC:.+]] = memref.dim %[[FILTER]], %[[C2]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[FILTER_OC:.+]] = memref.dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
+
+//      CHECK:   scf.for %[[IV0:.+]] = %{{.+}} to %[[ELEM_OH]] step %{{.+}} iter_args(%{{.+}} = %[[FILL]])
+// CHECK-NEXT:     %[[SIZE_ELEM_N:.+]] = affine.min #[[BOUND8_MAP]](%[[IV0]])[%[[ELEM_N]]]
+// CHECK-NEXT:     %[[SIZE_INPUT_N:.+]] = affine.min #[[BOUND_MAP]](%[[IV0]], %[[SIZE_ELEM_N]])[%[[INPUT_N]]]
+// CHECK-NEXT:     %[[SIZE_ELEM_N_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV0]], %[[SIZE_ELEM_N]])[%[[ELEM_N]]]
+// CHECK-NEXT:     scf.for %[[IV1:.+]] = %{{.+}} to %[[ELEM_OW]]
+// CHECK-NEXT:       %[[SIZE_ELEM_OH:.+]] = affine.min #[[BOUND16_MAP]](%[[IV1]])[%[[ELEM_OH]]]
+// CHECK-NEXT:       %[[OFFSET_OH:.+]] = affine.apply #[[X2_MAP]](%[[IV1]])
+// CHECK-NEXT:       %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OH]], %[[IV1]])[%[[FILTER_H]], %[[INPUT_H]]]
+// CHECK-NEXT:       %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV1]], %[[SIZE_ELEM_OH]])[%[[ELEM_OH]]]
+// CHECK-NEXT:       scf.for %[[IV2:.+]] = %{{.+}} to %[[ELEM_OC]]
+// CHECK-NEXT:         %[[SIZE_ELEM_OW:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OW]]]
+// CHECK-NEXT:         %[[SIZE_ELEM_OC:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OC]]]
+// CHECK-NEXT:         %[[OFFSET_OW:.+]] = affine.apply #[[X2_MAP]](%[[IV2]])
+// CHECK-NEXT:         %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OW]], %[[IV2]])[%[[FILTER_W]], %[[INPUT_W]]]
+// CHECK-NEXT:         %[[ST_INPUT:.+]] = subtensor %[[INPUT]][%[[IV0]], %[[OFFSET_OH]], %[[OFFSET_OW]], 0]
+// CHECK-SAME:               [%[[SIZE_INPUT_N]], %[[SIZE_INPUT_H]], %[[SIZE_INPUT_W]], %[[INPUT_C]]]
+// CHECK-NEXT:         %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV2]], %[[SIZE_ELEM_OW]])[%[[ELEM_OW]]]
+// CHECK-NEXT:         scf.for %[[IV3:.+]] = %{{.+}} to %[[ELEM_OC]] step %{{.+}} iter_args(%[[ARG:[a-z0-9]+]]
+// CHECK-NEXT:           %[[ST_ELEM:.+]] = subtensor %[[ELEM]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+// CHECK-SAME:                 [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]]
+// CHECK-NEXT:           %[[ST_ARG:.+]] = subtensor %[[ARG]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+// CHECK-SAME:                 [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]]
+// CHECK-NEXT:           %[[SIZE_ELEM_OC_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV3]], %[[SIZE_ELEM_OC]])[%[[FILTER_OC]]]
+// CHECK-NEXT:           %[[ST_FILTER:.+]] = subtensor %[[FILTER]][0, 0, 0, %[[IV3]]]
+// CHECK-SAME:                 [%[[FILTER_H]], %[[FILTER_W]], %[[FILTER_IC]], %[[SIZE_ELEM_OC_2]]]
+// CHECK-NEXT:           %[[SIZE_ELEM_OC_3:.+]] = affine.min #[[BOUND_MAP]](%[[IV3]], %[[SIZE_ELEM_OC]])[%[[ELEM_OC]]]
+// CHECK-NEXT:           %[[ST_FILL:.+]] = subtensor %[[FILL]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+// CHECK-SAME:                 [%[[SIZE_ELEM_N_2]], %[[SIZE_ELEM_OH_2]], %[[SIZE_ELEM_OW_2]], %[[SIZE_ELEM_OC_3]]]
+// CHECK-NEXT:           %[[ST_CONV:.+]] = linalg.conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME:                 ins(%[[ST_INPUT]], %[[ST_FILTER]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+// CHECK-SAME:                 outs(%[[ST_FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK-NEXT:           %[[ST_ADD:.+]] = linalg.generic
+// CHECK-SAME:                 ins(%[[ST_CONV]], %[[ST_ELEM]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+// CHECK-SAME:                 outs(%[[ST_ARG]] : tensor<?x?x?x?xf32>)
+//      CHECK:           subtensor_insert %[[ST_ADD]] into %[[ARG]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+// CHECK-SAME:                 [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]]

diff  --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index e752c46ecea9..3ef6ed5e4b4b 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -179,6 +179,10 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
 namespace {
 struct TestLinalgGreedyFusion
     : public PassWrapper<TestLinalgGreedyFusion, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
+                    scf::SCFDialect>();
+  }
   void runOnFunction() override {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns =


        


More information about the Mlir-commits mailing list