[Mlir-commits] [mlir] e58597e - [mlir][linalg] Fuse producers with non-permutation indexing maps
Lei Zhang
llvmlistbot at llvm.org
Wed Mar 24 15:19:33 PDT 2021
Author: Lei Zhang
Date: 2021-03-24T18:17:57-04:00
New Revision: e58597ee1c7dfe4fd2fdf6f5f0230f22b849c9be
URL: https://github.com/llvm/llvm-project/commit/e58597ee1c7dfe4fd2fdf6f5f0230f22b849c9be
DIFF: https://github.com/llvm/llvm-project/commit/e58597ee1c7dfe4fd2fdf6f5f0230f22b849c9be.diff
LOG: [mlir][linalg] Fuse producers with non-permutation indexing maps
Until now, Linalg fusion only allowed fusing producers whose operands
all have permutation indexing maps. That restriction makes it easy to
deduce the subtensor/subview, but it is unnecessary: the tiling side
already has more advanced logic to deduce subranges even when an
operand does not use a permutation indexing map, e.g., the input
operand of convolution ops.
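
For illustration (this sketch is not taken from the patch, and the
loop order is simplified), the input operand of a stride-2 NHWC
convolution is indexed by a map along the lines of

  affine_map<(n, oh, ow, oc, kh, kw, ic) -> (n, oh * 2 + kh, ow * 2 + kw, ic)>

whose results combine several loop dimensions, so it is not a
(projected) permutation and the old fusion logic rejected it.
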
This patch reuses that tiling-side logic to deduce subranges for
fusion, which enables fusing convolutions with their consumer ops
when possible.
Along the way, we now generate proper affine.min ops to guard the
subview/subtensor sizes whenever we cannot statically prove they stay
in bounds.
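
Concretely, the guard takes the form below (a minimal sketch assembled
from the updated test expectations; the value names %iv, %tile_size,
and %dim are hypothetical). It clamps the size to
min(tile_size, dim - iv):

  #bound = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
  %size  = affine.min #bound(%iv, %tile_size)[%dim]

%size is then used as the corresponding memref.subview / subtensor
size; the same map appears as MAP4 / BOUND_MAP in the updated tests
below.
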
Differential Revision: https://reviews.llvm.org/D99014
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/test/Dialect/Linalg/fusion-pattern.mlir
mlir/test/Dialect/Linalg/fusion-sequence.mlir
mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
mlir/test/Dialect/Linalg/fusion.mlir
mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index ffa2811c37d2..ea7fd62baad4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -59,104 +59,6 @@ using llvm::dbgs;
/// More advanced use cases, analyses as well as profitability heuristics are
/// left for future work.
-// Fill `offset`, `sizes` and `strides` used to iterate over the shape indexed
-// by `permutationMap`.
-static void inferShapeComponents(AffineMap permutationMap,
- ArrayRef<Range> loopRanges,
- SmallVectorImpl<OpFoldResult> &offsets,
- SmallVectorImpl<OpFoldResult> &sizes,
- SmallVectorImpl<OpFoldResult> &strides) {
- assert(permutationMap.isProjectedPermutation() &&
- "expected some subset of a permutation map");
- SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults());
- unsigned idx = 0;
- for (AffineExpr e : permutationMap.getResults()) {
- // loopToOperandRangesMaps are permutations-only, just swap indices.
- unsigned loopPos = e.cast<AffineDimExpr>().getPosition();
- shapeRanges[idx++] = loopRanges[loopPos];
- }
- // Construct a new subshape for the tile.
- unsigned rank = shapeRanges.size();
- offsets.reserve(rank);
- sizes.reserve(rank);
- strides.reserve(rank);
- for (auto r : shapeRanges) {
- offsets.push_back(r.offset);
- sizes.push_back(r.size);
- strides.push_back(r.stride);
- }
-}
-
-// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
-// a subset of the original loop ranges of `op`.
-// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
-// to the `loopRanges` in order to obtain view ranges.
-static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
- ArrayRef<Range> loopRanges) {
- SmallVector<Value, 8> clonedShapes;
- clonedShapes.reserve(op.getNumShapedOperands());
-
- // Iterate over the shape operands in order.
- // Extract the subranges from the linearized ranges.
- for (auto en : llvm::enumerate(op.getShapedOperands())) {
- unsigned shapedOperandIdx = en.index();
- AffineMap map = op.getIndexingMap(shapedOperandIdx);
- LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx
- << " with indexingMap: " << map << "\n");
- SmallVector<OpFoldResult, 4> offsets, sizes, strides;
- inferShapeComponents(map, loopRanges, offsets, sizes, strides);
- Value shape = en.value();
- Value sub =
- shape.getType().isa<MemRefType>()
- ? b.create<memref::SubViewOp>(loc, shape, offsets, sizes, strides)
- .getResult()
- : b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
- .getResult();
- clonedShapes.push_back(sub);
- }
- // Append the other operands.
- auto operands = op.getAssumedNonShapedOperands();
- clonedShapes.append(operands.begin(), operands.end());
-
- // Iterate over the results in order.
- // Extract the subtensor type from the linearized range.
- // Since we do not enforce any canonicalizations on the fly, this is always
- // fully dynamic at construction time.
- SmallVector<Type, 4> resultTypes;
- resultTypes.reserve(op->getNumResults());
- for (RankedTensorType t : op.getOutputTensorTypes()) {
- unsigned rank = t.getRank();
- SmallVector<int64_t, 4> staticOffsetsVector(
- rank, ShapedType::kDynamicStrideOrOffset);
- SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
- SmallVector<int64_t, 4> staticStridesVector(
- rank, ShapedType::kDynamicStrideOrOffset);
- resultTypes.push_back(SubTensorOp::inferResultType(
- t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
- staticStridesVector));
- }
-
- Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes);
- // When the producer is an IndexedGenericOp, we have to transform its block
- // IV arguments according to the tiling of the consumer, i.e. offset them by
- // the values computed in `loopRanges`.
- if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
- auto &block = indexedGenericOp.region().front();
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPointToStart(&block);
- for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
- Value oldIndex = block.getArgument(i);
- // TODO: replace by an affine_apply.
- AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
- loopRanges[i].offset);
- oldIndex.replaceAllUsesExcept(newIndex,
- SmallPtrSet<Operation *, 1>{newIndex});
- }
- }
-
- return clonedOp;
-}
-
struct ShapeDimension {
Value shape;
unsigned dimension;
@@ -208,35 +110,86 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
llvm_unreachable("Expect to be able to extract a shape defining loop range");
}
-/// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
+/// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
/// provides the loop range information for the fused loops. The rest are
/// obtained from the producer itself, since they are not tiled + fused.
-static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
+static LinalgOp fuse(OpBuilder &builder, LinalgOp producer,
const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
-
- unsigned nPar = producer.getNumParallelLoops();
- unsigned nRed = producer.getNumReductionLoops();
- unsigned nWin = producer.getNumWindowLoops();
- SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
- for (auto fusedLoops : fusedLoopsAndRanges)
- loopRanges[fusedLoops.first] = fusedLoops.second;
-
- // Iterate over all dimensions. For the dimensions not identified by the
- // producer map for `producerIdx`, we need to explicitly compute the shape
- // that defines the loop ranges using the `producer`.
- for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
- if (loopRanges[i].offset)
- LLVM_DEBUG(llvm::dbgs()
- << "existing LoopRange: " << loopRanges[i] << "\n");
- else {
+ SmallVector<Value, 8> ivs, tileSizes, sizeBounds;
+ SmallVector<Range, 8> loopRanges;
+ auto zero = std_constant_index(0);
+ auto one = std_constant_index(1);
+ Location loc = producer.getLoc();
+
+ for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) {
+ auto it = fusedLoopsAndRanges.find(i);
+ if (it != fusedLoopsAndRanges.end()) {
+ ivs.push_back(it->second.offset);
+ tileSizes.push_back(it->second.size);
+ sizeBounds.push_back(nullptr);
+ loopRanges.push_back(it->second);
+ LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange "
+ << loopRanges.back() << "\n");
+ } else {
auto shapeDim = getShapeDefiningLoopRange(producer, i);
Value dim = memref_dim(shapeDim.shape, shapeDim.dimension);
- loopRanges[i] = Range{std_constant_index(0), dim, std_constant_index(1)};
- LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
+ tileSizes.push_back(zero);
+ sizeBounds.push_back(dim);
+ loopRanges.push_back(Range{zero, dim, one});
+ LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange "
+ << loopRanges.back() << "\n");
+ }
+ }
+
+ SmallVector<Value, 8> clonedShapes;
+ clonedShapes.reserve(producer.getNumShapedOperands());
+
+ // Compute subranges for all tensor input/output operands.
+ auto tiledOperands = llvm::to_vector<4>(producer.getShapedOperands());
+ clonedShapes.append(makeTiledShapes(builder, loc, producer, tiledOperands,
+ ivs, tileSizes, sizeBounds));
+
+ // Append the other operands.
+ auto operands = producer.getAssumedNonShapedOperands();
+ clonedShapes.append(operands.begin(), operands.end());
+
+ // Iterate over the results in order.
+ // Extract the subtensor type from the linearized range.
+ // Since we do not enforce any canonicalizations on the fly, this is always
+ // fully dynamic at construction time.
+ SmallVector<Type, 4> resultTypes;
+ resultTypes.reserve(producer->getNumResults());
+ for (RankedTensorType t : producer.getOutputTensorTypes()) {
+ unsigned rank = t.getRank();
+ SmallVector<int64_t, 4> staticOffsetsVector(
+ rank, ShapedType::kDynamicStrideOrOffset);
+ SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
+ SmallVector<int64_t, 4> staticStridesVector(
+ rank, ShapedType::kDynamicStrideOrOffset);
+ resultTypes.push_back(SubTensorOp::inferResultType(
+ t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
+ staticStridesVector));
+ }
+
+ Operation *clonedOp = producer.clone(builder, loc, resultTypes, clonedShapes);
+ // When the producer is an IndexedGenericOp, we have to transform its block
+ // IV arguments according to the tiling of the consumer, i.e. offset them by
+ // the values computed in `loopRanges`.
+ if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
+ auto &block = indexedGenericOp.region().front();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToStart(&block);
+ for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
+ Value oldIndex = block.getArgument(i);
+ // TODO: replace by an affine_apply.
+ AddIOp newIndex = builder.create<AddIOp>(indexedGenericOp.getLoc(),
+ oldIndex, loopRanges[i].offset);
+ oldIndex.replaceAllUsesExcept(newIndex,
+ SmallPtrSet<Operation *, 1>{newIndex});
}
}
- return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges);
+ return clonedOp;
}
/// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 8fe3d8530c62..8c8b0cf1f7bf 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -27,6 +27,9 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-utils"
using namespace mlir;
using namespace mlir::edsc;
@@ -447,11 +450,14 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
// that define tile subshapes.
SmallVector<Value, 8> lbs, subShapeSizes;
for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
+ LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
bool isTiled = !isZero(tileSizes[idx]);
lbs.push_back(isTiled ? ivs[idxIvs++] : (Value)std_constant_index(0));
// Before composing, we need to make range a closed interval.
Value size = isTiled ? tileSizes[idx] : sizeBounds[idx];
subShapeSizes.push_back(size - std_constant_index(1));
+ LLVM_DEBUG(llvm::dbgs() << "lb: " << lbs.back() << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "size: " << subShapeSizes.back() << "\n");
}
MLIRContext *context = builder.getContext();
@@ -459,14 +465,18 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
tiledShapes.reserve(tiledOperands.size());
for (auto en : llvm::enumerate(tiledOperands)) {
Value shapedOp = en.value();
+ LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
ShapedType shapedType = shapedOp.getType().cast<ShapedType>();
unsigned rank = shapedType.getRank();
AffineMap map = linalgOp.getIndexingMap(en.index());
// If the shape is not tiled, we can use it as is.
if (!isTiled(map, tileSizes)) {
tiledShapes.push_back(shapedOp);
+ LLVM_DEBUG(llvm::dbgs()
+ << ": not tiled: use shape: " << shapedType << "\n");
continue;
}
+ LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
// Construct a new subview / subtensor for the tile.
SmallVector<OpFoldResult, 4> offsets, sizes, strides;
@@ -474,22 +484,28 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
sizes.reserve(rank);
strides.reserve(rank);
for (unsigned r = 0; r < rank; ++r) {
+ LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for dim#" << r);
if (!isTiled(map.getSubMap({r}), tileSizes)) {
offsets.push_back(builder.getIndexAttr(0));
- sizes.push_back(memref_dim(shapedOp, r).value);
+ Value dim = memref_dim(shapedOp, r).value;
+ sizes.push_back(dim);
strides.push_back(builder.getIndexAttr(1));
+ LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
continue;
}
+ LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");
// Tiling creates a new slice at the proper index, the slice step is 1
// (i.e. the op does not subsample, stepping occurs in the loop).
auto m = map.getSubMap({r});
+ LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: submap: " << map << "\n");
auto offset = applyMapToValues(builder, loc, m, lbs).front();
offsets.push_back(offset);
auto closedIntSize =
applyMapToValues(builder, loc, m, subShapeSizes).front();
// Resulting size needs to be made half open interval again.
auto size = closedIntSize + std_constant_index(1);
+ LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: raw size: " << size << "\n");
// The size of the subview / subtensor should be trimmed to avoid
// out-of-bounds accesses, unless we statically know the subshape size
@@ -498,6 +514,9 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
auto sizeCst = size.getDefiningOp<ConstantIndexOp>();
if (ShapedType::isDynamic(shapeSize) || !sizeCst ||
(shapeSize % sizeCst.getValue()) != 0) {
+ LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: shapeSize=" << shapeSize
+ << ", size: " << size
+ << ": make sure in bound with affine.min\n");
AffineExpr dim0, dim1, dim2;
bindDims(context, dim0, dim1, dim2);
// Compute min(size, dim - offset) to avoid out-of-bounds accesses.
@@ -510,6 +529,9 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
}
sizes.push_back(size);
+ LLVM_DEBUG(llvm::dbgs()
+ << "makeTiledShapes: new offset: " << offset << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: new size: " << size << "\n");
strides.push_back(builder.getIndexAttr(1));
}
diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
index aefeeb5e3ada..b9ba18bbd05a 100644
--- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
@@ -16,6 +16,7 @@ module {
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
// CHECK: func @basic_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
@@ -47,8 +48,10 @@ module {
// CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]]
// CHECK: %[[SV3:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]]
// CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_2]]]
+// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_2]]]
+// CHECK: %[[TILE_N_3:.+]] = affine.min #[[MAP4]](%[[IV1]], %[[TILE_N]])[%[[N_2]]]
// CHECK: %[[SV3_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]]
-// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]]
+// CHECK-SAME: [%[[TILE_M_3]], %[[TILE_N_3]]]
// CHECK: linalg.fill(%[[SV3_2]], %[[CST]])
// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer"
// CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
@@ -86,6 +89,7 @@ module {
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
// CHECK: func @rhs_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
@@ -112,10 +116,13 @@ module {
// CHECK: %[[SV2:.+]] = memref.subview %[[ARG3]][0, %[[IV0]]]
// CHECK-SAME: [%[[M]], %[[TILE_N_2]]]
// CHECK: %[[K_2:.+]] = memref.dim %[[ARG1]], %[[C0]]
+// CHECK: %[[N_3:.+]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK: %[[TILE_N_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_N]])[%[[N_3]]]
// CHECK: %[[SV3:.+]] = memref.subview %[[ARG1]][0, %[[IV0]]]
-// CHECK-SAME: [%[[K_2]], %[[TILE_N]]]
+// CHECK-SAME: [%[[K_2]], %[[TILE_N_3]]]
+// CHECK: %[[TILE_N_4:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_N]])[%[[N]]]
// CHECK: %[[SV3_2:.+]] = memref.subview %[[ARG2]][0, %[[IV0]]]
-// CHECK-SAME: [%[[K_2]], %[[TILE_N]]]
+// CHECK-SAME: [%[[K]], %[[TILE_N_4]]]
// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]])
// CHECK-SAME: __internal_linalg_transform__ = "after_rhs_fusion_producer"
// CHECK-NOT: linalg.fill
@@ -164,6 +171,7 @@ module {
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
// CHECK: func @two_operand_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
@@ -191,13 +199,17 @@ module {
// CHECK: %[[N:.+]] = memref.dim %[[ARG3]], %[[C1]]
// CHECK: %[[SV2:.+]] = memref.subview %[[ARG3]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_2]], %[[N]]]
+// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_2]]]
// CHECK: %[[SV2_2:.+]] = memref.subview %[[ARG3]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N]]]
+// CHECK-SAME: [%[[TILE_M_3]], %[[N]]]
+// CHECK: %[[M_3:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_3]]]
// CHECK: %[[K_2:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[SV3:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[K_2]]]
+// CHECK-SAME: [%[[TILE_M_4]], %[[K_2]]]
+// CHECK: %[[TILE_M_5:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M]]]
// CHECK: %[[SV3_2:.+]] = memref.subview %[[ARG1]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[K_2]]]
+// CHECK-SAME: [%[[TILE_M_5]], %[[K]]]
// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]])
// CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer"
// CHECK: linalg.fill(%[[SV2_2]], %[[CST]])
@@ -271,23 +283,24 @@ module {
// CHECK: %[[N:.+]] = memref.dim %[[ARG4]], %[[C1]]
// CHECK: %[[SV2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_2]], %[[N]]]
-// CHECK: %[[K2_2:.+]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK: %[[M_3:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_3]]]
// CHECK: %[[K1:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[SV3:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[K1]]]
-// CHECK: %[[SV4:.+]] = memref.subview %[[ARG1]][0, 0] [%[[K1]], %[[K2_2]]]
+// CHECK-SAME: [%[[TILE_M_3]], %[[K1]]]
+// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M]]]
// CHECK: %[[SV1_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[K2_2]]]
+// CHECK-SAME: [%[[TILE_M_4]], %[[K2]]]
// CHECK: linalg.matmul
// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer"
-// CHECK-SAME: ins(%[[SV3]], %[[SV4]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+// CHECK-SAME: ins(%[[SV3]], %[[ARG1]]
+// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
// CHECK-SAME: outs(%[[SV1_2]] : memref<?x?xf32, #[[MAP1]]>)
-// CHECK-DAG: %[[N_2:.+]] = memref.dim %[[ARG3]], %[[C1]]
+// CHECK: %[[N_2:.+]] = memref.dim %[[ARG3]], %[[C1]]
// CHECK: scf.parallel (%[[IV1:.+]]) =
// CHECK-SAME: (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) {
-// CHECK-NEXT: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
-// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K]]]
+// CHECK-NEXT: scf.for %[[IV2:.+]] = %[[C0]] to %[[K2]] step %[[C16]] {
+// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K2]]]
// CHECK: %[[SV6:.+]] = memref.subview %[[SV1]][0, %[[IV2]]]
// CHECK-SAME: [%[[TILE_M]], %[[TILE_K]]]
// CHECK: %[[K_2:.+]] = memref.dim %[[ARG3]], %[[C0]]
@@ -348,10 +361,11 @@ module {
// CHECK: %[[T6:.+]] = memref.subview %[[ARG2]][%[[ARG3]], %[[ARG4]]]
// CHECK: %[[T8:.+]] = memref.subview %[[ARG0]][%[[ARG3]], 0]
// CHECK: %[[T9:.+]] = memref.subview %[[ARG1]][0, %[[ARG4]]]
+// CHECK: %[[T10:.+]] = memref.subview %[[T2]][%[[ARG3]], %[[ARG4]]]
// CHECK: linalg.matmul
// CHECK-SAME: after_transpose_fusion_producer
// CHECK-SAME: ins(%[[T8]], %[[T9]]
-// CHECK-SAME: outs(%[[T5]]
+// CHECK-SAME: outs(%[[T10]]
// CHECK-NOT: linalg.matmul
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[T5]], %[[T5]]
diff --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir
index bec19b325a7b..981db2bfea7f 100644
--- a/mlir/test/Dialect/Linalg/fusion-sequence.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-sequence.mlir
@@ -36,18 +36,19 @@ module {
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK: %[[TEMP:.+]] = memref.alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} {
-// CHECK-DAG: %[[SV_TEMP:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[SV_TEMP_1:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]]
// CHECK-DAG: %[[SV_ARG2:.+]] = memref.subview %[[ARG2]][%[[IV1]]]
// CHECK-DAG: %[[SV_ARG3:.+]] = memref.subview %[[ARG3]][%[[IV0]], %[[IV1]]]
// CHECK-DAG: %[[SV_ARG0:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
// CHECK-DAG: %[[SV_ARG1:.+]] = memref.subview %[[ARG1]][0, %[[IV1]]]
-// CHECK: linalg.fill(%[[SV_TEMP]], %{{.+}})
+// CHECK: %[[SV_TEMP_2:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]]
+// CHECK: linalg.fill(%[[SV_TEMP_2]], %{{.+}})
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP2]]>, memref<?x?xf32, #[[MAP2]]>)
-// CHECK-SAME: outs(%[[SV_TEMP]] : memref<?x?xf32, #[[MAP2]]>)
+// CHECK-SAME: outs(%[[SV_TEMP_2]] : memref<?x?xf32, #[[MAP2]]>)
// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[SV_TEMP]], %[[SV_ARG2]]
+// CHECK-SAME: ins(%[[SV_TEMP_1]], %[[SV_ARG2]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP2]]>, memref<?xf32, #[[MAP3]]>)
// CHECK-SAME: outs(%[[SV_ARG3]] : memref<?x?xf32, #[[MAP2]]>)
// CHECK: scf.yield
@@ -83,6 +84,8 @@ module {
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
+
// CHECK: func @sequence_of_matmul
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
@@ -100,37 +103,40 @@ module {
// CHECK: scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]])
// CHECK-SAME: step (%[[C16]]) {
// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
-// CHECK: %[[SV_ALLOC2:.+]] = memref.subview %[[ALLOC2]][%[[IV0]], 0]
+// CHECK: %[[SV_ALLOC3:.+]] = memref.subview %[[ALLOC2]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[N2]]]
// CHECK: %[[M_2:.+]] = memref.dim %[[ARG4]], %[[C0]]
// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
// CHECK: %[[N3:.+]] = memref.dim %[[ARG4]], %[[C1]]
// CHECK: %[[SV_ARG4:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]]
+// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M_2]]]
// CHECK: %[[SV_ARG4_2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N3]]]
+// CHECK-SAME: [%[[TILE_M_3]], %[[N3]]]
+// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M]]]
// CHECK: %[[SV_ALLOC1:.+]] = memref.subview %[[ALLOC1]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N1]]]
-// CHECK: %[[SV_ARG2:.+]] = memref.subview %[[ARG2]][0, 0] [%[[N1]], %[[N2]]]
+// CHECK-SAME: [%[[TILE_M_4]], %[[N1]]]
+// CHECK: %[[SV_ALLOC2:.+]] = memref.subview %[[ALLOC2]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M_4]], %[[N2]]]
// CHECK: %[[N0:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[SV_ARG0:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M:.+]], %[[N0]]]
-// CHECK: %[[SV_ARG1:.+]] = memref.subview %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
+// CHECK-SAME: [%[[TILE_M_4]], %[[N0]]]
// CHECK: linalg.fill(%[[SV_ALLOC1]], %{{.+}})
-// CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[SV_ARG1]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+// CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[ARG1]]
+// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
// CHECK-SAME: outs(%[[SV_ALLOC1]] : memref<?x?xf32, #[[MAP1]]>)
// CHECK: linalg.fill(%[[SV_ALLOC2]], %{{.+}})
-// CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[SV_ARG2]]
-// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+// CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[ARG2]]
+// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
// CHECK-SAME: outs(%[[SV_ALLOC2]] : memref<?x?xf32, #[[MAP1]]>)
// CHECK: linalg.fill(%[[SV_ARG4_2]], %{{.+}})
-// CHECK: linalg.matmul ins(%[[SV_ALLOC2]], %[[ARG3]]
+// CHECK: linalg.matmul ins(%[[SV_ALLOC3]], %[[ARG3]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
// CHECK-SAME: outs(%[[SV_ARG4]] : memref<?x?xf32, #[[MAP1]]>)
// CHECK: scf.yield
// CHECK: }
+
// -----
module {
@@ -189,8 +195,8 @@ module {
module {
func @tensor_matmul_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
%arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>,
- %arg4: tensor<?x?xf32>, %arg5: tensor<?x?xf32>,
- %arg6: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %arg4: tensor<?x?xf32>, %arg5: tensor<?x?xf32>,
+ %arg6: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N0] * [N0, N1]
%1 = linalg.matmul ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
@@ -200,7 +206,12 @@ module {
return %2 : tensor<?x?xf32>
}
}
-// CHECK-LABEL: func @tensor_matmul_fusion(
+
+// CHECK: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1) -> (16, d0 - d1)>
+// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
+
+// CHECK: func @tensor_matmul_fusion(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -210,36 +221,39 @@ module {
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[M:.+]] = memref.dim %[[ARG0]], %c0 : tensor<?x?xf32>
// CHECK: %[[R0:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] =
// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
-// CHECK: %[[N3:.+]] = memref.dim %[[ARG8]], %[[C1]]
-// CHECK: %[[STARG6:.+]] = subtensor %[[ARG8]][%[[IV0]], 0]
-// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N3]]]
-// CHECK: %[[N2:.+]] = memref.dim %[[ARG3]], %[[C1]]
-// CHECK: %[[N1:.+]] = memref.dim %[[ARG1]], %[[C1]]
-// CHECK: %[[STARG3:.+]] = subtensor %[[ARG3]][0, 0]
-// CHECK-SAME: [%[[N1]], %[[N2]]]
-// CHECK: %[[STARG4:.+]] = subtensor %[[ARG4]][%[[IV0]], 0]
-// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N2]]]
-// CHECK: %[[N0:.+]] = memref.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[STARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N0]]]
-// CHECK: %[[STARG1:.+]] = subtensor %[[ARG1]][0, 0]
-// CHECK-SAME: [%[[N0]], %[[N1]]]
-// CHECK: %[[STARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0]
-// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N1]]]
-// CHECK: %[[T0:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[STARG0]], %[[STARG1]]
-// CHECK-SAME: ) outs(%[[STARG2]] : tensor<?x?xf32>)
-// CHECK: %[[T1:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[T0]], %[[STARG3]]
-// CHECK-SAME: ) outs(%[[STARG4]] : tensor<?x?xf32>)
-// CHECK: %[[T2:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[T1]], %[[ARG5]]
-// CHECK-SAME: ) outs(%[[STARG6]] : tensor<?x?xf32>)
-// CHECK: %[[R1:.+]] = subtensor_insert %[[T2]]
-// CHECK-SAME: into %[[ARG8]][%[[IV0]], 0]
-// CHECK: scf.yield %[[R1]]
-// CHECK: }
-// CHECK: return %[[R0]]
+// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
+// CHECK: %[[M_1:.+]] = memref.dim %[[ARG8]], %[[C0]]
+// CHECK: %[[TILE_M_1:.+]] = affine.min #[[MAP1]](%[[M_1]], %[[IV0]])
+// CHECK: %[[N3:.+]] = memref.dim %[[ARG8]], %[[C1]]
+// CHECK: %[[STARG6:.+]] = subtensor %[[ARG8]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M_1]], %[[N3]]]
+// CHECK: %[[M_2:.+]] = memref.dim %[[ARG4]], %[[C0]]
+// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M_2]]]
+// CHECK: %[[N2:.+]] = memref.dim %[[ARG4]], %[[C1]]
+// CHECK: %[[STARG4:.+]] = subtensor %[[ARG4]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M_2]], %[[N2]]]
+// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M]]]
+// CHECK: %[[N0:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[STARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M_3]], %[[N0]]]
+// CHECK: %[[M_3:.+]] = memref.dim %[[ARG2]], %[[C0]]
+// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M_3]]]
+// CHECK: %[[N1:.+]] = memref.dim %[[ARG2]], %[[C1]]
+// CHECK: %[[STARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0]
+// CHECK-SAME: [%[[TILE_M_4]], %[[N1]]]
+// CHECK: %[[T0:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[STARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>
+// CHECK-SAME: ) outs(%[[STARG2]] : tensor<?x?xf32>)
+// CHECK: %[[T1:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[T0]], %arg3 : tensor<?x?xf32>, tensor<?x?xf32>
+// CHECK-SAME: ) outs(%[[STARG4]] : tensor<?x?xf32>)
+// CHECK: %[[T2:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[T1]], %arg5 : tensor<?x?xf32>, tensor<?x?xf32>
+// CHECK-SAME: ) outs(%[[STARG6]] : tensor<?x?xf32>)
+// CHECK: %[[R1:.+]] = subtensor_insert %[[T2]]
+// CHECK-SAME: into %[[ARG8]][%[[IV0]], 0] [%[[TILE_M_1]], %[[N3]]]
+// CHECK: scf.yield %[[R1]] : tensor<?x?xf32>
// CHECK: }
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
index bd0d61c8580e..7f1131815d7c 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -17,12 +17,15 @@ module {
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (64, d0 - d1)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
+
// CHECK: func @matmul_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C32:.+]] = constant 32 : index
@@ -38,18 +41,20 @@ module {
// CHECK: %[[N3:.+]] = memref.dim %[[ARG6]], %[[C1]]
// CHECK: %[[ST_ARG6:.+]] = subtensor %[[ARG6]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]]
-// CHECK: %[[N2:.+]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP5]](%[[IV0]], %[[TILE_M]])[%[[M]]]
// CHECK: %[[N1:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[ST_ARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N1]]]
-// CHECK: %[[ST_ARG1:.+]] = subtensor %[[ARG1]][0, 0]
-// CHECK-SAME: [%[[N1]], %[[N2]]]
+// CHECK-SAME: [%[[TILE_M_3]], %[[N1]]]
+// CHECK: %[[M_3:.+]] = memref.dim %[[ARG2]], %[[C0]]
+// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP5]](%[[IV0]], %[[TILE_M]])[%[[M_3]]]
+// CHECK: %[[N2_2:.+]] = memref.dim %[[ARG2]], %[[C1]]
// CHECK: %[[ST_ARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N2]]]
+// CHECK-SAME: [%[[TILE_M_4]], %[[N2_2]]]
// CHECK: %[[LHS:.+]] = linalg.matmul
// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer"
-// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: ins(%[[ST_ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
// CHECK-SAME: outs(%[[ST_ARG2]] : tensor<?x?xf32>)
+// CHECK: %[[N2:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[N3_2:.+]] = memref.dim %[[ARG3]], %[[C1]]
// CHECK: %[[YIELD0:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME: %[[C0]] to %[[N3_2]] step %[[C64]]
@@ -59,7 +64,7 @@ module {
// CHECK-SAME: iter_args(%[[ARG10:.+]] = %[[ARG8]]) -> (tensor<?x?xf32>) {
// CHECK: %[[TILE_N2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2]]]
// CHECK: %[[ST_LHS:.+]] = subtensor %[[LHS]][0, %[[IV2]]]
-// CHECK-SAME: [%[[TILE_M]], %[[TILE_N2]]]
+// CHECK-SAME: [%[[TILE_M_3]], %[[TILE_N2]]]
// CHECK: %[[N2_3:.+]] = memref.dim %[[ARG3]], %[[C0]]
// CHECK: %[[TILE_N2_2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2_3]]]
// CHECK: %[[TILE_N3:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N3_2]]]
diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir
index 14fb995a6cfc..8bbecc091c45 100644
--- a/mlir/test/Dialect/Linalg/fusion.mlir
+++ b/mlir/test/Dialect/Linalg/fusion.mlir
@@ -252,25 +252,36 @@ func @f5(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
}
return %E : memref<?x?xf32, offset: 0, strides: [?, ?]>
}
-// CHECK-LABEL: func @f5
-// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
+// CHECK: #[[BOUND_2_MAP:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+// CHECK: #[[BOUND_ID_MAP:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
+// CHECK: #[[BOUND_4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+// CHECK: func @f5
+// HECK-SAME: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
-// CHECK-DAG: %[[B_1:.*]] = memref.dim %[[B]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
-// CHECK-DAG: %[[D_0:.*]] = memref.dim %[[D]], %[[C0:.*]] : memref<?x?xf32, #[[$strided2D]]>
-// CHECK-DAG: %[[D_1:.*]] = memref.dim %[[D]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG: %[[B_1:.*]] = memref.dim %[[B]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG: %[[C_0:.*]] = memref.dim %[[C]], %[[C0]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG: %[[D_0:.*]] = memref.dim %[[D]], %[[C0]] : memref<?x?xf32, #[[$strided2D]]>
+// CHECK-DAG: %[[D_1:.*]] = memref.dim %[[D]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
// CHECK-DAG: %[[B_00:.*]] = memref.subview %[[B]][0, 0]{{.*}}
// CHECK: scf.for %[[I:.*]] = %{{.*}} to %[[D_0]] step %{{.*}} {
-// CHECK-DAG: %[[A_I0:.*]] = memref.subview %[[A]][%[[I]], 0]
-// CHECK-DAG: %[[C_I0:.*]] = memref.subview %[[C]][%[[I]], 0]
+// CHECK: %[[BOUND_2_C0:.+]] = affine.min #[[BOUND_2_MAP]](%[[I]])[%[[C_0]]]
+// CHECK: %[[C_I0:.*]] = memref.subview %[[C]][%[[I]], 0] [%[[BOUND_2_C0]]
+// CHECK: %[[BOUND_2_D0:.+]] = affine.min #[[BOUND_2_MAP]](%[[I]])[%[[D_0]]]
+// CHECK: %[[A_I0:.*]] = memref.subview %[[A]][%[[I]], 0]
+// Note that %[[BOUND_ID_C0]] is essentially %[[BOUND_2_C0]].
+// CHECK: %[[BOUND_ID_C0:.+]] = affine.min #[[BOUND_ID_MAP]](%[[I]], %[[BOUND_2_C0]])[%[[C_0]]]
+// CHECK: %[[C_I0_OUT:.*]] = memref.subview %[[C]][%[[I]], 0] [%[[BOUND_ID_C0]]
// CHECK: scf.for %[[J:.*]] = %{{.*}} to %[[B_1]] step %{{.*}} {
// CHECK: %[[E_IJ:.*]] = memref.subview %[[E]][%[[I]], %[[J]]]
// CHECK: scf.for %[[K:.*]] = %{{.*}} to %[[D_1]] step %{{.*}} {
-// CHECK-DAG: %[[D_IK:.*]] = memref.subview %[[D]][%[[I]], %[[K]]]
-// CHECK-DAG: %[[B_0K:.*]] = memref.subview %[[B]][0, %[[K]]]
-// CHECK-DAG: %[[B_KJ:.*]] = memref.subview %[[B]][%[[K]], %[[J]]]
-// CHECK: linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0]]
-// CHECK: linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK]]
+// CHECK: %[[D_IK:.*]] = memref.subview %[[D]][%[[I]], %[[K]]] [2, 4]
+// CHECK: %[[B_KJ:.*]] = memref.subview %[[B]][%[[K]], %[[J]]]
+// CHECK: %[[B_0K:.*]] = memref.subview %[[B]][0, %[[K]]]
+// CHECK: %[[BOUND_4_D1:.+]] = affine.min #[[BOUND_4_MAP]](%[[K]])[%[[D_1]]]
+// CHECK: %[[D_IK_OUT:.+]] = memref.subview %[[D]][%[[I]], %[[K]]] [%[[BOUND_2_D0]], %[[BOUND_4_D1]]]
+// CHECK: linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0_OUT]]
+// CHECK: linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK_OUT]]
// CHECK: linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]]
// -----
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index b742f8148dac..0c1aa43f4412 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -1,11 +1,5 @@
// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
-#map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)>
-#map1 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
-#map2 = affine_map<(d0)[s0] -> (3, -d0 + s0)>
-#map3 = affine_map<(d0, d1) -> (2, d0 - d1)>
-#map4 = affine_map<(d0, d1) -> (3, d0 - d1)>
-
func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
%t0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2: tensor<?x?xf32>)
@@ -36,23 +30,250 @@ func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tens
return %3 : tensor<?x?xf32>
}
-// CHECK-LABEL: func @matmul_tensors(
+// CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+// CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+
+// CHECK: func @matmul_tensors(
// CHECK-SAME: %[[A:[0-9a-z]*]]: tensor<?x?xf32>
// CHECK-SAME: %[[B:[0-9a-z]*]]: tensor<?x?xf32>
// CHECK-SAME: %[[C:[0-9a-z]*]]: tensor<?x?xf32>
+
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[dA0:.*]] = memref.dim %[[A]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[dA1:.*]] = memref.dim %[[A]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[dB0:.*]] = memref.dim %[[B]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[dB1:.*]] = memref.dim %[[B]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[dC0:.*]] = memref.dim %[[C]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[dC1:.*]] = memref.dim %[[C]], %[[C1]] : tensor<?x?xf32>
// CHECK: scf.for %[[I:[0-9a-z]*]]
-// CHECK: %[[stA:.*]] = subtensor %[[A]][%[[I]], 0] [2, %[[dA1]]] [1, 1] : tensor<?x?xf32> to tensor<2x?xf32>
+// CHECK: %[[sizeA0:.*]] = affine.min #[[BOUND2_MAP]](%[[I]])[%[[dA0]]]
+// CHECK: %[[stA:.*]] = subtensor %[[A]][%[[I]], 0] [%[[sizeA0]], %[[dA1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[sizeC0:.*]] = affine.min #[[BOUND2_MAP]](%[[I]])[%[[dC0]]]
// CHECK-NEXT: scf.for %[[J:[0-9a-z]*]]
// CHECK-NEXT: scf.for %[[K:[0-9a-z]*]] {{.*}} iter_args(%[[RES:[0-9a-z]*]]
// CHECK-DAG: %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1] : tensor<?x?xf32> to tensor<4x3xf32>
// CHECK-DAG: %[[stF:.*]] = subtensor %[[RES]][%[[I]], %[[J]]] [2, 3] [1, 1] : tensor<?x?xf32> to tensor<2x3xf32>
//
// subtensors of the producing matmul.
-// CHECK-DAG: %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dA1]], 4] [1, 1] : tensor<?x?xf32> to tensor<?x4xf32>
-// CHECK-DAG: %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [2, 4] [1, 1] : tensor<?x?xf32> to tensor<2x4xf32>
-// CHECK: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor<?x4xf32>) outs(%[[stC]] : tensor<2x4xf32>) -> tensor<2x4xf32>
-// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
+// CHECK: %[[sizeB1:.*]] = affine.min #[[BOUND4_MAP]](%[[K]])[%[[dB1]]]
+// CHECK: %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dB0]], %[[sizeB1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[sizeC1:.*]] = affine.min #[[BOUND4_MAP]](%[[K]])[%[[dC1]]]
+// CHECK: %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [%[[sizeC0]], %[[sizeC1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[stC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[CAST:.*]] = tensor.cast %[[stD]] : tensor<?x?xf32> to tensor<?x4xf32>
+// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[CAST]], %[[stB1]] : tensor<?x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: subtensor_insert %[[stG]] into %[[RES]][%[[I]], %[[J]]]
+
+// -----
+
+func @conv_tensors_static(%input: tensor<1x225x225x32xf32>, %filter: tensor<3x3x3x32xf32>, %elementwise: tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32> {
+ %c112 = constant 112 : index
+ %c32 = constant 32 : index
+ %c16 = constant 16 : index
+ %c8 = constant 8 : index
+ %c4 = constant 4 : index
+ %c0 = constant 0 : index
+ %cst = constant 0.0 : f32
+
+ %init = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
+ %fill = linalg.fill(%init, %cst) : tensor<1x112x112x32xf32>, f32 -> tensor<1x112x112x32xf32>
+
+ %conv = linalg.conv_2d_input_nhwc_filter_hwcf
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+ ins(%input, %filter : tensor<1x225x225x32xf32>, tensor<3x3x3x32xf32>)
+ outs(%fill : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
+
+ %for0 = scf.for %iv0 = %c0 to %c112 step %c8 iter_args(%arg0 = %fill) -> tensor<1x112x112x32xf32> {
+ %for1 = scf.for %iv1 = %c0 to %c112 step %c16 iter_args(%arg1 = %arg0) -> tensor<1x112x112x32xf32> {
+ %for2 = scf.for %iv2 = %c0 to %c32 step %c4 iter_args(%arg2 = %arg1) -> tensor<1x112x112x32xf32> {
+ %0 = subtensor %conv[0, %iv0, %iv1, %iv2][1, 8, 16, 4][1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
+ %1 = subtensor %elementwise[0, %iv0, %iv1, %iv2][1, 8, 16, 4][1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
+ %2 = subtensor %arg2[0, %iv0, %iv1, %iv2][1, 8, 16, 4][1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
+ %add = linalg.generic
+ {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%0, %1 : tensor<1x8x16x4xf32>, tensor<1x8x16x4xf32>) outs(%2 : tensor<1x8x16x4xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %result = addf %arg3, %arg4 : f32
+ linalg.yield %result : f32
+ } -> tensor<1x8x16x4xf32>
+
+ %insert = subtensor_insert %add into %arg2[0, %iv0, %iv1, %iv2] [1, 8, 16, 4] [1, 1, 1, 1] : tensor<1x8x16x4xf32> into tensor<1x112x112x32xf32>
+ scf.yield %insert : tensor<1x112x112x32xf32>
+ }
+ scf.yield %for2 : tensor<1x112x112x32xf32>
+ }
+ scf.yield %for1 : tensor<1x112x112x32xf32>
+ }
+ return %for0 : tensor<1x112x112x32xf32>
+}
+
+// CHECK: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK: func @conv_tensors_static
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x225x225x32xf32>, %[[FILTER:.+]]: tensor<3x3x3x32xf32>, %[[ELEM:.+]]: tensor<1x112x112x32xf32>)
+
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
+// CHECK-NEXT: %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor<1x112x112x32xf32>, f32 -> tensor<1x112x112x32xf32>
+
+// CHECK-NEXT: scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG0:.+]] = %[[FILL]])
+// CHECK-NEXT: %[[OFFSET_H:.+]] = affine.apply #[[MAP0]](%[[IV0]])
+// CHECK-NEXT: scf.for %[[IV1:.+]] = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG1:.+]] = %[[ARG0]])
+// CHECK-NEXT: %[[OFFSET_W:.+]] = affine.apply #[[MAP0]](%[[IV1]])
+// CHECK-NEXT: %[[ST_INPUT:.+]] = subtensor %arg0[0, %[[OFFSET_H]], %[[OFFSET_W]], 0] [1, 17, 33, 32] [1, 1, 1, 1] : tensor<1x225x225x32xf32> to tensor<1x17x33x32xf32>
+// CHECK-NEXT: scf.for %[[IV2:.+]] = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG2:.+]] = %[[ARG1]])
+// CHECK-NEXT: %[[ST_ELEM:.+]] = subtensor %[[ELEM]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
+// CHECK-NEXT: %[[ST_ARG2:.+]] = subtensor %[[ARG2]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
+// CHECK-NEXT: %[[ST_FILTER:.+]] = subtensor %[[FILTER]][0, 0, 0, %[[IV2]]] [3, 3, 3, 4] [1, 1, 1, 1] : tensor<3x3x3x32xf32> to tensor<3x3x3x4xf32>
+// CHECK-NEXT: %[[ST_FILL:.+]] = subtensor %[[FILL]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
+// CHECK-NEXT: %[[ST_CONV:.+]] = linalg.conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME: ins(%[[ST_INPUT]], %[[ST_FILTER]] : tensor<1x17x33x32xf32>, tensor<3x3x3x4xf32>)
+// CHECK-SAME: outs(%[[ST_FILL]] : tensor<1x8x16x4xf32>)
+// CHECK-NEXT: %[[ADD:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ST_CONV]], %[[ST_ELEM]] : tensor<1x8x16x4xf32>, tensor<1x8x16x4xf32>)
+// CHECK-SAME: outs(%[[ST_ARG2]] : tensor<1x8x16x4xf32>)
+// CHECK: subtensor_insert %[[ADD]] into %[[ARG2]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4]
+
+// -----
+
+#bound4_map = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+#bound8_map = affine_map<(d0)[s0] -> (8, -d0 + s0)>
+#bound16_map = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+
+func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %elementwise: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %cst = constant 0.0 : f32
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %c4 = constant 4 : index
+ %c8 = constant 8 : index
+ %c16 = constant 16 : index
+
+ %n = memref.dim %elementwise, %c0 : tensor<?x?x?x?xf32>
+ %oh = memref.dim %elementwise, %c1 : tensor<?x?x?x?xf32>
+ %ow = memref.dim %elementwise, %c2 : tensor<?x?x?x?xf32>
+ %oc = memref.dim %elementwise, %c3 : tensor<?x?x?x?xf32>
+
+ %init = linalg.init_tensor [%n, %oh, %ow, %oc] : tensor<?x?x?x?xf32>
+ %fill = linalg.fill(%init, %cst) : tensor<?x?x?x?xf32>, f32 -> tensor<?x?x?x?xf32>
+
+ %conv = linalg.conv_2d_input_nhwc_filter_hwcf
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+ ins(%input, %filter : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+ outs(%fill : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+
+ %for0 = scf.for %iv0 = %c0 to %oh step %c8 iter_args(%arg0 = %fill) -> tensor<?x?x?x?xf32> {
+ %for1 = scf.for %iv1 = %c0 to %ow step %c16 iter_args(%arg1 = %arg0) -> tensor<?x?x?x?xf32> {
+ %for2 = scf.for %iv2 = %c0 to %oc step %c4 iter_args(%arg2 = %arg1) -> tensor<?x?x?x?xf32> {
+ %for3 = scf.for %iv3 = %c0 to %oc step %c2 iter_args(%arg3 = %arg2) -> tensor<?x?x?x?xf32> {
+ %n_size = affine.min #bound8_map(%iv0)[%n]
+ %oh_size = affine.min #bound16_map(%iv1)[%oh]
+ %ow_size = affine.min #bound4_map(%iv2)[%ow]
+ %oc_size = affine.min #bound4_map(%iv2)[%oc]
+ %0 = subtensor %conv[%iv0, %iv1, %iv2, %iv3][%n_size, %oh_size, %ow_size, %oc_size][1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+ %1 = subtensor %elementwise[%iv0, %iv1, %iv2, %iv3][%n_size, %oh_size, %ow_size, %oc_size][1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+ %2 = subtensor %arg3[%iv0, %iv1, %iv2, %iv3][%n_size, %oh_size, %ow_size, %oc_size][1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+ %add = linalg.generic
+ {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ }
+ ins(%0, %1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%2 : tensor<?x?x?x?xf32>) {
+ ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
+ %result = addf %arg4, %arg5 : f32
+ linalg.yield %result : f32
+ } -> tensor<?x?x?x?xf32>
+
+ %insert = subtensor_insert %add into %arg3[%iv0, %iv1, %iv2, %iv3] [%n_size, %oh_size, %ow_size, %oc_size] [1, 1, 1, 1] : tensor<?x?x?x?xf32> into tensor<?x?x?x?xf32>
+ scf.yield %insert : tensor<?x?x?x?xf32>
+ }
+ scf.yield %for3 : tensor<?x?x?x?xf32>
+ }
+ scf.yield %for2 : tensor<?x?x?x?xf32>
+ }
+ scf.yield %for1 : tensor<?x?x?x?xf32>
+ }
+ return %for0 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK: #[[BOUND8_MAP:.+]] = affine_map<(d0)[s0] -> (8, -d0 + s0)>
+// CHECK: #[[BOUND_MAP:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
+// CHECK: #[[BOUND16_MAP:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+// CHECK: #[[X2_MAP:.+]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK: #[[INPUT_BOUND:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * 2 + s0 - 2, d1 * -2 + s1)>
+// CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
+
+// CHECK: func @conv_tensors_dynamic
+// CHECK-SAME: (%[[INPUT]]: tensor<?x?x?x?xf32>, %[[FILTER]]: tensor<?x?x?x?xf32>, %[[ELEM]]: tensor<?x?x?x?xf32>)
+
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+
+// CHECK-DAG: %[[ELEM_N:.+]] = memref.dim %[[ELEM]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[ELEM_OH:.+]] = memref.dim %[[ELEM]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[ELEM_OW:.+]] = memref.dim %[[ELEM]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[ELEM_OC:.+]] = memref.dim %[[ELEM]], %[[C3]] : tensor<?x?x?x?xf32>
+
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[ELEM_N]], %[[ELEM_OH]], %[[ELEM_OW]], %[[ELEM_OC]]] : tensor<?x?x?x?xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor<?x?x?x?xf32>, f32 -> tensor<?x?x?x?xf32>
+
+// CHECK-DAG: %[[FILTER_H:.+]] = memref.dim %[[FILTER]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[FILTER_W:.+]] = memref.dim %[[FILTER]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[INPUT_N:.+]] = memref.dim %[[INPUT]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[INPUT_H:.+]] = memref.dim %[[INPUT]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[INPUT_W:.+]] = memref.dim %[[INPUT]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[INPUT_C:.+]] = memref.dim %[[INPUT]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[FILTER_IC:.+]] = memref.dim %[[FILTER]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[FILTER_OC:.+]] = memref.dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
+
+// CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[ELEM_OH]] step %{{.+}} iter_args(%{{.+}} = %[[FILL]])
+// CHECK-NEXT: %[[SIZE_ELEM_N:.+]] = affine.min #[[BOUND8_MAP]](%[[IV0]])[%[[ELEM_N]]]
+// CHECK-NEXT: %[[SIZE_INPUT_N:.+]] = affine.min #[[BOUND_MAP]](%[[IV0]], %[[SIZE_ELEM_N]])[%[[INPUT_N]]]
+// CHECK-NEXT: %[[SIZE_ELEM_N_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV0]], %[[SIZE_ELEM_N]])[%[[ELEM_N]]]
+// CHECK-NEXT: scf.for %[[IV1:.+]] = %{{.+}} to %[[ELEM_OW]]
+// CHECK-NEXT: %[[SIZE_ELEM_OH:.+]] = affine.min #[[BOUND16_MAP]](%[[IV1]])[%[[ELEM_OH]]]
+// CHECK-NEXT: %[[OFFSET_OH:.+]] = affine.apply #[[X2_MAP]](%[[IV1]])
+// CHECK-NEXT: %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OH]], %[[IV1]])[%[[FILTER_H]], %[[INPUT_H]]]
+// CHECK-NEXT: %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV1]], %[[SIZE_ELEM_OH]])[%[[ELEM_OH]]]
+// CHECK-NEXT: scf.for %[[IV2:.+]] = %{{.+}} to %[[ELEM_OC]]
+// CHECK-NEXT: %[[SIZE_ELEM_OW:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OW]]]
+// CHECK-NEXT: %[[SIZE_ELEM_OC:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OC]]]
+// CHECK-NEXT: %[[OFFSET_OW:.+]] = affine.apply #[[X2_MAP]](%[[IV2]])
+// CHECK-NEXT: %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OW]], %[[IV2]])[%[[FILTER_W]], %[[INPUT_W]]]
+// CHECK-NEXT: %[[ST_INPUT:.+]] = subtensor %[[INPUT]][%[[IV0]], %[[OFFSET_OH]], %[[OFFSET_OW]], 0]
+// CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_INPUT_H]], %[[SIZE_INPUT_W]], %[[INPUT_C]]]
+// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV2]], %[[SIZE_ELEM_OW]])[%[[ELEM_OW]]]
+// CHECK-NEXT: scf.for %[[IV3:.+]] = %{{.+}} to %[[ELEM_OC]] step %{{.+}} iter_args(%[[ARG:[a-z0-9]+]]
+// CHECK-NEXT: %[[ST_ELEM:.+]] = subtensor %[[ELEM]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+// CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]]
+// CHECK-NEXT: %[[ST_ARG:.+]] = subtensor %[[ARG]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+// CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]]
+// CHECK-NEXT: %[[SIZE_ELEM_OC_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV3]], %[[SIZE_ELEM_OC]])[%[[FILTER_OC]]]
+// CHECK-NEXT: %[[ST_FILTER:.+]] = subtensor %[[FILTER]][0, 0, 0, %[[IV3]]]
+// CHECK-SAME: [%[[FILTER_H]], %[[FILTER_W]], %[[FILTER_IC]], %[[SIZE_ELEM_OC_2]]]
+// CHECK-NEXT: %[[SIZE_ELEM_OC_3:.+]] = affine.min #[[BOUND_MAP]](%[[IV3]], %[[SIZE_ELEM_OC]])[%[[ELEM_OC]]]
+// CHECK-NEXT: %[[ST_FILL:.+]] = subtensor %[[FILL]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+// CHECK-SAME: [%[[SIZE_ELEM_N_2]], %[[SIZE_ELEM_OH_2]], %[[SIZE_ELEM_OW_2]], %[[SIZE_ELEM_OC_3]]]
+// CHECK-NEXT: %[[ST_CONV:.+]] = linalg.conv_2d_input_nhwc_filter_hwcf
+// CHECK-SAME: ins(%[[ST_INPUT]], %[[ST_FILTER]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+// CHECK-SAME: outs(%[[ST_FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK-NEXT: %[[ST_ADD:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ST_CONV]], %[[ST_ELEM]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+// CHECK-SAME: outs(%[[ST_ARG]] : tensor<?x?x?x?xf32>)
+// CHECK: subtensor_insert %[[ST_ADD]] into %[[ARG]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+// CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]]
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index e752c46ecea9..3ef6ed5e4b4b 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -179,6 +179,10 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
namespace {
struct TestLinalgGreedyFusion
: public PassWrapper<TestLinalgGreedyFusion, FunctionPass> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
+ scf::SCFDialect>();
+ }
void runOnFunction() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns =