[Mlir-commits] [mlir] [mlir][Linalg] Allow expand shape propagation across linalg ops with dynamic shapes. (PR #127943)
llvmlistbot at llvm.org
Wed Mar 12 22:11:29 PDT 2025
https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/127943
From 1b675e9199d107ec091ec725fd0ef820d40807b1 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mravisha at amd.com>
Date: Mon, 17 Feb 2025 21:03:56 -0600
Subject: [PATCH 1/2] [mlir][Linalg] Allow expand shape propagation across
linalg ops with dynamic shapes.
With `tensor.expand_shape` now allowing a dynamic dimension to be expanded
into multiple dynamic dimensions, adapt the reshape propagation through
expansion to handle cases where one dynamic dimension is expanded into
multiple dynamic dimensions.
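To illustrate the kind of IR this enables (a minimal hand-written sketch, not
taken from the patch; the function and value names are hypothetical), a
`tensor.expand_shape` that splits one dynamic dimension into two dynamic
dimensions can now be propagated across the `linalg.generic` producing its
source:

  func.func @propagate_expand(%arg0: tensor<?xf32>, %d0: index, %d1: index)
      -> tensor<?x?xf32> {
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
    %empty = tensor.empty(%dim) : tensor<?xf32>
    %0 = linalg.generic {
        indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
        iterator_types = ["parallel"]}
        ins(%arg0 : tensor<?xf32>) outs(%empty : tensor<?xf32>) {
      ^bb0(%b0: f32, %b1: f32):
        %1 = arith.addf %b0, %b0 : f32
        linalg.yield %1 : f32
    } -> tensor<?xf32>
    // One dynamic dimension expanded into two dynamic dimensions.
    %expanded = tensor.expand_shape %0 [[0, 1]] output_shape [%d0, %d1]
        : tensor<?xf32> into tensor<?x?xf32>
    return %expanded : tensor<?x?xf32>
  }

After propagation, the expand_shape is instead applied to the operand and init
of the generic (rebuilt from the mixed `output_shape` values), and the generic
itself is rewritten over the expanded two-dimensional iteration space.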
Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 186 +++++------
mlir/test/Dialect/Linalg/reshape_fusion.mlir | 296 ++++++------------
2 files changed, 177 insertions(+), 305 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 33667e7ab0c5c..cfc5b25fa87a1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -25,6 +25,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
#include <optional>
#include <utility>
@@ -590,18 +591,17 @@ class ExpansionInfo {
// the expanded op.
LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
- ArrayRef<int64_t> expandedShape,
- ArrayRef<int64_t> collapsedShape,
+ ArrayRef<OpFoldResult> expandedShape,
PatternRewriter &rewriter);
unsigned getOrigOpNumDims() const { return reassociation.size(); }
unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
ReassociationIndicesRef getExpandedDims(unsigned i) const {
return reassociation[i];
}
- ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
+ ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
return expandedShapeMap[i];
}
- ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
+ ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }
private:
/// Reassociation from the dimensions in the original operation to the
@@ -609,9 +609,9 @@ class ExpansionInfo {
SmallVector<ReassociationIndices> reassociation;
/// Mapping from extent of loops in the original operation, to the extent of
/// loops in the expanded operation.
- SmallVector<SmallVector<int64_t>> expandedShapeMap;
+ SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
/// Extent of the loop in the original operation.
- SmallVector<int64_t> originalLoopExtent;
+ SmallVector<OpFoldResult> originalLoopExtent;
unsigned expandedOpNumDims;
};
} // namespace
@@ -619,15 +619,17 @@ class ExpansionInfo {
LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
- ArrayRef<int64_t> expandedShape,
- ArrayRef<int64_t> collapsedShape,
+ ArrayRef<OpFoldResult> expandedShape,
PatternRewriter &rewriter) {
if (reassociationMaps.empty())
return failure();
AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
- SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
- originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(linalgOp);
+ originalLoopExtent = llvm::map_to_vector(
+ linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
+ [](Range r) { return r.size; });
reassociation.clear();
expandedShapeMap.clear();
@@ -639,7 +641,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
AffineMap foldedDims = reassociationMaps[resultExpr.index()];
numExpandedDims[pos] = foldedDims.getNumResults();
- ArrayRef<int64_t> shape =
+ ArrayRef<OpFoldResult> shape =
expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
expandedShapeMap[pos].assign(shape.begin(), shape.end());
}
@@ -660,33 +662,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
return success();
}
-/// Expanding the body of a linalg operation requires adaptations of the
-/// accessed loop indices. Specifically, access of indices in the original
-/// operation need to be replaced with linearizations of indices in the expanded
-/// op. That requires the shape of the expanded dimensions to be static (at
-/// least all but the most significant). For now check that these are all
-/// statically sized. Note that this could be extended to handle dynamic case,
-/// but the implementation below uses `affine.apply` which seems to have issues
-/// when the shapes are not static.
-static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
- const ExpansionInfo &expansionInfo,
- PatternRewriter &rewriter) {
- if (!linalgOp.hasIndexSemantics())
- return success();
- for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
- ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
- if (expandedShape.size() == 1)
- continue;
- for (int64_t shape : expandedShape.drop_front()) {
- if (ShapedType::isDynamic(shape)) {
- return rewriter.notifyMatchFailure(
- linalgOp, "cannot expand due to index semantics and dynamic dims");
- }
- }
- }
- return success();
-}
-
/// Return the indexing map to use in the expanded op for a given the
/// `indexingMap` of the original operation.
static AffineMap
@@ -708,16 +683,28 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
-static RankedTensorType getExpandedType(RankedTensorType originalType,
- AffineMap indexingMap,
- const ExpansionInfo &expansionInfo) {
- SmallVector<int64_t> expandedShape;
+static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
+getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
+ const ExpansionInfo &expansionInfo) {
+ SmallVector<int64_t> expandedStaticShape;
+ SmallVector<OpFoldResult> expandedShape;
for (AffineExpr expr : indexingMap.getResults()) {
unsigned dim = cast<AffineDimExpr>(expr).getPosition();
- auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
+ ArrayRef<OpFoldResult> dimExpansion =
+ expansionInfo.getExpandedShapeOfDim(dim);
+ llvm::append_range(expandedStaticShape,
+ llvm::map_range(dimExpansion, [](OpFoldResult ofr) {
+ std::optional<int64_t> staticShape =
+ getConstantIntValue(ofr);
+ if (staticShape) {
+ return staticShape.value();
+ }
+ return ShapedType::kDynamic;
+ }));
expandedShape.append(dimExpansion.begin(), dimExpansion.end());
}
- return RankedTensorType::get(expandedShape, originalType.getElementType());
+ return {expandedShape, RankedTensorType::get(expandedStaticShape,
+ originalType.getElementType())};
}
/// Returns the reassociation maps to use in the `tensor.expand_shape`
@@ -765,49 +752,27 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
// Linearize the expanded indices of the original index dimension.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(indexOp);
- ArrayRef<int64_t> expandedDimsShape =
+ ArrayRef<OpFoldResult> expandedDimsShape =
expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
SmallVector<Value> expandedIndices;
expandedIndices.reserve(expandedDims.size() - 1);
llvm::transform(
expandedDims.drop_front(), std::back_inserter(expandedIndices),
[&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
- Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
+ OpFoldResult newIndex =
+ rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
- assert(!ShapedType::isDynamic(std::get<0>(it)));
- AffineExpr idx, acc;
+ AffineExpr idx, acc, shape;
bindDims(rewriter.getContext(), idx, acc);
- newIndex = rewriter.create<affine::AffineApplyOp>(
- indexOp.getLoc(), idx + acc * std::get<0>(it),
- ValueRange{std::get<1>(it), newIndex});
- }
- rewriter.replaceOp(indexOp, newIndex);
- }
-}
-
-/// Checks if a single dynamic dimension expanded into multiple dynamic
-/// dimensions.
-static LogicalResult
-validateDynamicDimExpansion(LinalgOp linalgOp,
- const ExpansionInfo &expansionInfo,
- PatternRewriter &rewriter) {
- for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
- ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
- if (expandedShape.size() == 1)
- continue;
- bool foundDynamic = false;
- for (int64_t shape : expandedShape) {
- if (!ShapedType::isDynamic(shape))
- continue;
- if (foundDynamic) {
- return rewriter.notifyMatchFailure(
- linalgOp, "cannot infer expanded shape with multiple dynamic "
- "dims in the same reassociation group");
- }
- foundDynamic = true;
+ bindSymbols(rewriter.getContext(), shape);
+ newIndex = affine::makeComposedFoldedAffineApply(
+ rewriter, indexOp.getLoc(), idx + acc * shape,
+ ArrayRef<OpFoldResult>{std::get<1>(it), newIndex, std::get<0>(it)});
}
+ Value newIndexVal =
+ getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
+ rewriter.replaceOp(indexOp, newIndexVal);
}
- return success();
}
// Create an expanded transpose op.
@@ -910,31 +875,31 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
"preconditions for fuse operation failed");
Location loc = linalgOp.getLoc();
- // Check if reshape is expanding or collapsing.
- auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
- auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
- bool isExpanding = (expandingReshapeOp != nullptr);
- RankedTensorType expandedType = isExpanding
- ? expandingReshapeOp.getResultType()
- : collapsingReshapeOp.getSrcType();
- RankedTensorType collapsedType = isExpanding
- ? expandingReshapeOp.getSrcType()
- : collapsingReshapeOp.getResultType();
+ SmallVector<OpFoldResult> expandedShape, collapsedShape;
+ SmallVector<AffineMap, 4> reassociationIndices;
+ Value src;
+ if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
+ // Try to move the dynamic dimensions in output shape before the `linalgOp`
+ // to maintain SSA validity
+ if (failed(moveValueDefinitions(
+ rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
+ return std::nullopt;
+
+ expandedShape = expandingReshapeOp.getMixedOutputShape();
+ reassociationIndices = expandingReshapeOp.getReassociationMaps();
+ src = expandingReshapeOp.getSrc();
+ } else {
+ auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
+ expandedShape = tensor::getMixedSizes(
+ rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
+ reassociationIndices = collapsingReshapeOp.getReassociationMaps();
+ src = collapsingReshapeOp.getSrc();
+ }
ExpansionInfo expansionInfo;
- if (failed(expansionInfo.compute(
- linalgOp, fusableOpOperand,
- isExpanding ? expandingReshapeOp.getReassociationMaps()
- : collapsingReshapeOp.getReassociationMaps(),
- expandedType.getShape(), collapsedType.getShape(), rewriter)))
- return std::nullopt;
-
- // TODO: With the support of multiple dynamic dims expansion in
- // tensor.expand_shape op, this case can be handled.
- if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
- return std::nullopt;
-
- if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
+ if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
+ reassociationIndices, expandedShape,
+ rewriter)))
return std::nullopt;
SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
@@ -950,15 +915,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
if (opOperand == fusableOpOperand) {
- expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
- : collapsingReshapeOp.getSrc());
+ expandedOpOperands.push_back(src);
continue;
}
if (auto opOperandType =
dyn_cast<RankedTensorType>(opOperand->get().getType())) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
- RankedTensorType expandedOperandType =
- getExpandedType(opOperandType, indexingMap, expansionInfo);
+ SmallVector<OpFoldResult> expandedOperandShape;
+ RankedTensorType expandedOperandType;
+ std::tie(expandedOperandShape, expandedOperandType) =
+ getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
if (expandedOperandType != opOperand->get().getType()) {
// Reshape the operand to get the right type.
SmallVector<ReassociationIndices> reassociation =
@@ -972,7 +938,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
/*isExpandingReshape=*/true)))
return std::nullopt;
expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
- loc, expandedOperandType, opOperand->get(), reassociation));
+ loc, expandedOperandType, opOperand->get(), reassociation,
+ expandedOperandShape));
continue;
}
}
@@ -983,8 +950,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
- RankedTensorType expandedOutputType =
- getExpandedType(opOperandType, indexingMap, expansionInfo);
+ SmallVector<OpFoldResult> expandedOutputShape;
+ RankedTensorType expandedOutputType;
+ std::tie(expandedOutputShape, expandedOutputType) =
+ getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
if (expandedOutputType != opOperand.get().getType()) {
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
@@ -997,7 +966,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
/*isExpandingReshape=*/true)))
return std::nullopt;
outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
- loc, expandedOutputType, opOperand.get(), reassociation));
+ loc, expandedOutputType, opOperand.get(), reassociation,
+ expandedOutputShape));
} else {
outputs.push_back(opOperand.get());
}
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 3244418d445b7..67b4f2b32bad5 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -30,20 +30,14 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32
-// CHECK: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[C3:.+]] = arith.constant 3 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM_1]], %[[C4]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM]], %[[DIM_0]], %[[VAL_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_4]], %[[C4]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_2]], %[[DIM_3]], %[[VAL_1]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x4x?xf32>
+// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x4x?xf32>
+// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C3]] : tensor<?x?x4x?xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_1]], %[[DIM]], %[[DIM_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_1]], %[[DIM]], %[[DIM_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -88,21 +82,9 @@ func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-// CHECK: %[[C20:.+]] = arith.constant 20 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C20]] : index
-// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], 4, %[[VAL_0]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], 4, %[[VAL_1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
-// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], 4, %[[VAL_2]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -137,26 +119,9 @@ func.func @reshape_as_consumer_permutation
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index
-// CHECK: %[[C12:.+]] = arith.constant 12 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C12]] : index
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C2]] : index
-// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[VAL_0]], %[[VAL_1]], 2, %[[DIM_1]]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_2]], %[[C12]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[VAL_2]], %[[DIM_3]]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
-// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_6:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_7:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_5]], %[[C2]] : index
-// CHECK: %[[VAL_4:.+]] = arith.divsi %[[DIM_7]], %[[C12]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[VAL_3]], 2, %[[DIM_6]], 3, 4, %[[VAL_4]]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
+// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[SZ2]], %[[SZ0]], 2, %[[SZ1]]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[SZ2]], %[[SZ1]]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[SZ0]], 2, %[[SZ1]], 3, 4, %[[SZ2]]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
@@ -258,7 +223,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
}
// Only check the body in the indexed version of the test.
-// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 4)>
// CHECK: func @indexed_consumer_reshape_producer_fusion
// CHECK: linalg.generic
// CHECK: ^{{.*}}(
@@ -268,7 +233,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
-// CHECK-DAG: %[[T3:.+]] = affine.apply #[[MAP]](%[[IDX1]], %[[IDX0]])
+// CHECK-DAG: %[[T3:.+]] = affine.apply #[[MAP]]()[%[[IDX1]], %[[IDX0]]]
// CHECK: %[[T4:.+]] = arith.muli %[[ARG3]], %[[ARG4]]
// CHECK: %[[T5:.+]] = arith.index_cast %[[T3]]
// CHECK: %[[T6:.+]] = arith.addi %[[T4]], %[[T5]]
@@ -307,8 +272,7 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
}
// Only check the body in the indexed version of the test.
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 5 + s1 * 20 + s2)>
// CHECK: func @indexed_producer_reshape_consumer_fusion
// CHECK: linalg.generic
// CHECK: ^{{.*}}(
@@ -318,12 +282,11 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
-// CHECK: %[[T1:.+]] = affine.apply #[[MAP1]](%[[IDX2]], %[[IDX1]])
-// CHECK: %[[T2:.+]] = affine.apply #[[MAP2]](%[[IDX3]], %[[T1]])
+// CHECK: %[[T1:.+]] = affine.apply #[[MAP1]]()[%[[IDX2]], %[[IDX1]], %[[IDX3]]]
// CHECK: %[[T4:.+]] = arith.muli %[[ARG3]], %[[ARG4]]
// CHECK: %[[T5:.+]] = arith.index_cast %[[IDX0]]
// CHECK: %[[T6:.+]] = arith.addi %[[T4]], %[[T5]]
-// CHECK: %[[T7:.+]] = arith.index_cast %[[T2]]
+// CHECK: %[[T7:.+]] = arith.index_cast %[[T1]]
// CHECK: %[[T8:.+]] = arith.addi %[[T6]], %[[T7]]
// CHECK: linalg.yield %[[T8]]
@@ -362,16 +325,15 @@ func.func @reshape_as_consumer_permutation
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
-// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 6)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 7)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 7 + s1 * 42 + s2)>
// CHECK: func @reshape_as_consumer_permutation
// CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32>
// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [5, 6, 7, 2, 3, 4] : tensor<210x6x4xi32> into tensor<5x6x7x2x3x4xi32>
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [5, 6, 7, 4] : tensor<210x4xi32> into tensor<5x6x7x4xi32>
-// CHECK: %[[T3:.+]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
+// CHECK: %[[T3:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
// CHECK: %[[T4:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
@@ -385,13 +347,12 @@ func.func @reshape_as_consumer_permutation
// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
// CHECK-DAG: %[[IDX4:.+]] = linalg.index 4 : index
// CHECK-DAG: %[[IDX5:.+]] = linalg.index 5 : index
-// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP3]](%[[IDX1]], %[[IDX0]])
-// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP4]](%[[IDX3]], %[[IDX2]])
-// CHECK-DAG: %[[T7:.+]] = affine.apply #[[MAP5]](%[[IDX4]], %[[T6]])
+// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP3]]()[%[[IDX1]], %[[IDX0]]]
+// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP4]]()[%[[IDX3]], %[[IDX2]], %[[IDX4]]]
// CHECK-DAG: %[[T8:.+]] = arith.addi %[[ARG8]], %[[ARG9]]
// CHECK: %[[T9:.+]] = arith.index_cast %[[T5]]
// CHECK: %[[T10:.+]] = arith.addi %[[T8]], %[[T9]]
-// CHECK: %[[T11:.+]] = arith.index_cast %[[T7]]
+// CHECK: %[[T11:.+]] = arith.index_cast %[[T6]]
// CHECK: %[[T12:.+]] = arith.addi %[[T10]], %[[T11]]
// CHECK: %[[T13:.+]] = arith.index_cast %[[IDX5]]
// CHECK: %[[T14:.+]] = arith.addi %[[T12]], %[[T13]]
@@ -426,7 +387,7 @@ func.func @reshape_as_producer_projected_permutation(
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 8)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 8)>
// CHECK: @reshape_as_producer_projected_permutation
// CHECK-SAME: %[[ARG0:.+]]: tensor<33x8x?xi32>
// CHECK: %[[RES:.+]] = linalg.generic
@@ -439,7 +400,7 @@ func.func @reshape_as_producer_projected_permutation(
// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
-// CHECK-DAG: %[[T0:.+]] = affine.apply #[[MAP2]](%[[IDX1]], %[[IDX0]])
+// CHECK-DAG: %[[T0:.+]] = affine.apply #[[MAP2]]()[%[[IDX1]], %[[IDX0]]]
// CHECK: %[[T1:.+]] = arith.index_cast %[[T0]] : index to i32
// CHECK: %[[T2:.+]] = arith.addi %[[ARG1]], %[[T1]] : i32
// CHECK: %[[T3:.+]] = arith.index_cast %[[IDX2]] : index to i32
@@ -481,21 +442,9 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-// CHECK: %[[C20:.+]] = arith.constant 20 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C20]] : index
-// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_0]], 4, 5, %[[DIM_0]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_1]], %[[C20]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_1]], 4, 5, %[[DIM_2]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
-// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
+// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[SZ1]], 4, 5, %[[SZ0]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[SZ1]], 4, 5, %[[SZ0]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP4]], #[[MAP5]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -528,9 +477,10 @@ func.func @fuse_collapse_reduction(%arg0: tensor<10x10x20xf32>) -> tensor<100xf3
// CHECK-SAME: ins(%[[ARG0]] : tensor<10x10x20xf32>)
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
// CHECK: return %[[COLLAPSE]]
+
// -----
-func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
+func.func @fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
%c0 = arith.constant 0 : index
%0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
%1 = tensor.dim %0, %c0 : tensor<?xf32>
@@ -546,39 +496,21 @@ func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
return %3 : tensor<?xf32>
}
-// CHECK: func @no_fuse_dynamic_dims
+// CHECK: func @fuse_dynamic_dims
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[EXPAND_SHAPE:.+]] = tensor.expand_shape %[[EMPTY]] {{\[}}[0, 1]{{\]}}
+// CHECK-SAME: output_shape [%[[D0]], %[[D1]]]
// CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?xf32>)
-// CHECK: return %[[GENERIC]]
-
-// -----
-
-func.func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor<?xi64>) -> tensor<2xi64> {
- %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x1xi64> into tensor<2xi64>
- %1 = tensor.empty() : tensor<2xi64>
- %2 = linalg.generic
- {indexing_maps = [affine_map<(d0) -> (d0)>,
- affine_map<(d0) -> (d0)>,
- affine_map<(d0) -> (d0)>],
- iterator_types = ["parallel"]}
- ins(%0, %arg1 : tensor<2xi64>, tensor<?xi64>)
- outs(%1 : tensor<2xi64>) {
- ^bb0(%arg4: i64, %arg5: i64, %arg6: i64):
- %3 = arith.addi %arg4, %arg5 : i64
- linalg.yield %3 : i64
- } -> tensor<2xi64>
- return %2 : tensor<2xi64>
-}
-
-// CHECK: func @no_fuse_mismatched_dynamism
-// CHECK-SAME: %[[ARG0:.+]]: tensor<2x1xi64>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<?xi64>
-// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
-// CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : tensor<2xi64>, tensor<?xi64>)
-// CHECK: return %[[GENERIC]]
+// CHECK-SAME: ins(%[[ARG0]] :
+// CHECK-SAME: outs(%[[EXPAND_SHAPE]] :
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]{{\]}}
+// CHECK: return %[[COLLAPSE]]
// -----
@@ -610,32 +542,10 @@ func.func @reshape_as_consumer_permutation_with_multiple_results
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index, %[[SZ3:.+]]: index, %[[SZ4:.+]]: index
-// CHECK: %[[C12:.+]] = arith.constant 12 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C12]] : index
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C2]] : index
-// CHECK: %[[RESHAPE0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[VAL_0]], %[[VAL_1]], 2, %[[DIM_1]]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_2]], %[[C12]] : index
-// CHECK: %[[RESHAPE1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[VAL_2]], %[[DIM_3]]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
-// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_6:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_7:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_5]], %[[C2]] : index
-// CHECK: %[[VAL_4:.+]] = arith.divsi %[[DIM_7]], %[[C12]] : index
-// CHECK: %[[RESHAPE2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[VAL_3]], 2, %[[DIM_6]], 3, 4, %[[VAL_4]]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
-// CHECK: %[[DIM_9:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_10:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_11:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[VAL_5:.+]] = arith.divsi %[[DIM_10]], %[[C2]] : index
-// CHECK: %[[VAL_6:.+]] = arith.divsi %[[DIM_11]], %[[C12]] : index
-// CHECK: %[[RESHAPE3:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2], [3, 4, 5]] output_shape [%[[DIM_9]], %[[VAL_5]], 2, 3, 4, %[[VAL_6]]] : tensor<?x?x?xf32> into tensor<?x?x2x3x4x?xf32>
+// CHECK: %[[RESHAPE0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[SZ2]], %[[SZ4]], 2, %[[SZ3]]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
+// CHECK: %[[RESHAPE1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[SZ2]], %[[SZ3]]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
+// CHECK: %[[RESHAPE2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[SZ4]], 2, %[[SZ3]], 3, 4, %[[SZ2]]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
+// CHECK: %[[RESHAPE3:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2], [3, 4, 5]] output_shape [%[[SZ3]], %[[SZ4]], 2, 3, 4, %[[SZ2]]] : tensor<?x?x?xf32> into tensor<?x?x2x3x4x?xf32>
// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
// CHECK-SAME: ins(%[[RESHAPE0]], %[[RESHAPE1]] :
@@ -710,17 +620,10 @@ func.func @generic_op_reshape_consumer_fusion_reduction(%arg0 : tensor<?x?xf32>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-// CHECK: %[[C20:.+]] = arith.constant 20 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C20]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_0]], 4, 5, %[[DIM_0]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], %[[VAL_1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[SZ1]], 4, 5, %[[DIM]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "reduction"]
@@ -760,21 +663,12 @@ func.func @generic_op_reshape_producer_fusion_with_reduction(%arg0 : tensor<?x7x
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x4x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C7:.+]] = arith.constant 7 : index
-// CHECK: %[[C8:.+]] = arith.constant 8 : index
// CHECK: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x4x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x4x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [%[[VAL_0]], 8, 4, %[[VAL_1]], 7] : tensor<?x4x?xf32> into tensor<?x8x4x?x7xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C8]] : index
-// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C7]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x7x?x8xf32>
+// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x7x?x8xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [%[[DIM_0]], 8, 4, %[[DIM]], 7] : tensor<?x4x?xf32> into tensor<?x8x4x?x7xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM_0]], 8, %[[DIM]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "reduction", "parallel", "parallel"]
@@ -807,21 +701,9 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-// CHECK: %[[C20:.+]] = arith.constant 20 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C20]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], %[[VAL_0]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], %[[VAL_1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
-// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
-// CHECK: %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
+// CHECK: %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
// CHECK: %[[T4:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -848,20 +730,12 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK: %[[C8:.+]] = arith.constant 8 : index
-// CHECK: %[[C7:.+]] = arith.constant 7 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C7]] : index
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C8]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_0]], 7, %[[VAL_1]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C7]] : index
-// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 7, %[[VAL_3]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x7x?x8xf32>
+// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x7x?x8xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM]], 7, %[[DIM_0]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM]], 7, %[[DIM_0]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -888,15 +762,11 @@ func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// CHECK: func @linalg_copy_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK: %[[C8:.+]] = arith.constant 8 : index
-// CHECK: %[[C7:.+]] = arith.constant 7 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C7]] : index
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C8]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_0]], 7, %[[VAL_1]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM]], 7, %[[DIM_0]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
// CHECK: %[[T2:.+]] = linalg.copy
// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
// CHECK-SAME: outs(%[[T1]] : tensor<?x7x?x8xf32>)
@@ -907,7 +777,6 @@ func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// -----
-
func.func @reshape_as_producer_transpose
(%a : tensor<4x5x6x7x2x3xf32>)
-> tensor<6x4x210xf32> {
@@ -991,3 +860,36 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: i
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
// CHECK-SAME: : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
// CHECK: return %[[COLLAPSE]]
+
+// -----
+
+func.func @move_operand_deps(%arg0 : tensor<?x128xf16>,
+ %arg1 : tensor<4x?x32x128xf16>, %empty : tensor<4x?x32x128xf16>) -> tensor<4x?x32x8x16xf16> {
+ %c0 = arith.constant 0 : index
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg0 : tensor<?x128xf16>)
+ outs(%empty : tensor<4x?x32x128xf16>) {
+ ^bb0(%b0: f16, %b1 : f16) :
+ %iv0 = linalg.index 0 : index
+ %iv1 = linalg.index 1 : index
+ %iv2 = linalg.index 2 : index
+ %iv3 = linalg.index 3 : index
+ %1 = tensor.extract %arg1[%iv0, %iv1, %iv2, %iv3] : tensor<4x?x32x128xf16>
+ %2 = arith.addf %1, %b0 : f16
+ linalg.yield %2 : f16
+ } -> tensor<4x?x32x128xf16>
+ %1 = tensor.dim %arg0, %c0 : tensor<?x128xf16>
+ %2 = tensor.expand_shape %0 [[0], [1], [2], [3, 4]] output_shape [4, %1, 32, 8, 16]
+ : tensor<4x?x32x128xf16> into tensor<4x?x32x8x16xf16>
+ func.return %2 : tensor<4x?x32x8x16xf16>
+}
+// CHECK: func @move_operand_deps(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x128xf16>
+// CHECK-DAG: %[[MOVED_OP:.+]] = tensor.dim %[[ARG0]]
+// CHECK-DAG: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[EXPANDED]] :
+// CHECK: return %[[GENERIC]]
From dbaa97a091a47a724812be410dc5798ee61c0fc2 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 12 Mar 2025 22:11:01 -0700
Subject: [PATCH 2/2] Address comments.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 25 ++++++++-----------
1 file changed, 11 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index cfc5b25fa87a1..afeb162a71e31 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -681,28 +681,21 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
builder.getContext());
}
-/// Return the type of the operand/result to use in the expanded op given the
-/// type in the original op.
+/// Return the shape and type of the operand/result to use in the expanded op
+/// given the type in the original op.
static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
const ExpansionInfo &expansionInfo) {
- SmallVector<int64_t> expandedStaticShape;
SmallVector<OpFoldResult> expandedShape;
for (AffineExpr expr : indexingMap.getResults()) {
unsigned dim = cast<AffineDimExpr>(expr).getPosition();
ArrayRef<OpFoldResult> dimExpansion =
expansionInfo.getExpandedShapeOfDim(dim);
- llvm::append_range(expandedStaticShape,
- llvm::map_range(dimExpansion, [](OpFoldResult ofr) {
- std::optional<int64_t> staticShape =
- getConstantIntValue(ofr);
- if (staticShape) {
- return staticShape.value();
- }
- return ShapedType::kDynamic;
- }));
expandedShape.append(dimExpansion.begin(), dimExpansion.end());
}
+ SmallVector<int64_t> expandedStaticShape;
+ std::tie(expandedStaticShape, std::ignore) =
+ decomposeMixedValues(expandedShape);
return {expandedShape, RankedTensorType::get(expandedStaticShape,
originalType.getElementType())};
}
@@ -761,13 +754,14 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
[&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
OpFoldResult newIndex =
rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
- for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
+ for (auto [expandedShape, expandedIndex] :
+ llvm::zip(expandedDimsShape, expandedIndices)) {
AffineExpr idx, acc, shape;
bindDims(rewriter.getContext(), idx, acc);
bindSymbols(rewriter.getContext(), shape);
newIndex = affine::makeComposedFoldedAffineApply(
rewriter, indexOp.getLoc(), idx + acc * shape,
- ArrayRef<OpFoldResult>{std::get<1>(it), newIndex, std::get<0>(it)});
+ ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
}
Value newIndexVal =
getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
@@ -890,6 +884,9 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
src = expandingReshapeOp.getSrc();
} else {
auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
+ if (!collapsingReshapeOp)
+ return std::nullopt;
+
expandedShape = tensor::getMixedSizes(
rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
reassociationIndices = collapsingReshapeOp.getReassociationMaps();