[Mlir-commits] [mlir] [mlir][Linalg] Allow expand shape propagation across linalg ops with dynamic shapes. (PR #127943)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 11 18:32:02 PDT 2025
https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/127943
>From 00624428051e99ff9206942bbdab444905cf2eca Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Tue, 11 Mar 2025 18:21:25 -0700
Subject: [PATCH 1/2] [mlir][Transforms] Add a utility method to move value
definitions.
https://github.com/llvm/llvm-project/commit/205c5325b3c771d94feb0ec07e8ad89d27c2b29e
added a transform utility that moved all SSA dependences of an
operation before an insertion point. Similar to that, this PR adds a
transform utility function, `moveValueDefinitions` to move the slice
of operations that define all values in a `ValueRange` before the
insertion point. While very similar to `moveOperationDependencies`,
this method differs in a few ways:
1. When computing the backward slice, since the start of the slice is a
value, the computed slice needs to be inclusive.
2. The combined backward slice needs to be sorted topologically before
moving them to avoid SSA use-def violations while moving individual
ops.
The PR also adds a new transform op to test this new utility function.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
mlir/include/mlir/Transforms/RegionUtils.h | 11 +
mlir/lib/Transforms/Utils/RegionUtils.cpp | 71 +++++-
mlir/test/Transforms/move-operation-deps.mlir | 226 ++++++++++++++++++
.../test/lib/Transforms/TestTransformsOps.cpp | 18 ++
mlir/test/lib/Transforms/TestTransformsOps.td | 22 ++
5 files changed, 347 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index e6b928d8ebecc..2ed96afbace81 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -11,6 +11,7 @@
#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/SetVector.h"
@@ -80,6 +81,16 @@ LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
Operation *insertionPoint);
+/// Move definitions of `values` before an insertion point. Current support is
+/// only for movement of definitions within the same basic block. Note that this
+/// is an all-or-nothing approach. Either definitions of all values are moved
+/// before insertion point, or none of them are.
+LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
+ Operation *insertionPoint,
+ DominanceInfo &dominance);
+LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
+ Operation *insertionPoint);
+
/// Run a set of structural simplifications over the given regions. This
/// includes transformations like unreachable block elimination, dead argument
/// elimination, as well as some other DCE. This function returns success if any
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index da0d486f0fdcb..6987a13b309d7 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -1070,7 +1070,7 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
// in different basic blocks.
if (op->getBlock() != insertionPoint->getBlock()) {
return rewriter.notifyMatchFailure(
- op, "unsupported caes where operation and insertion point are not in "
+ op, "unsupported case where operation and insertion point are not in "
"the same basic block");
}
// If `insertionPoint` does not dominate `op`, do nothing
@@ -1115,3 +1115,72 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
DominanceInfo dominance(op);
return moveOperationDependencies(rewriter, op, insertionPoint, dominance);
}
+
+/// Moves the operations defining `values`, along with the operations they
+/// depend on that do not already dominate `insertionPoint`, before it. Every
+/// check runs before the first move, so a failure leaves the IR untouched.
+LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
+                                         ValueRange values,
+                                         Operation *insertionPoint,
+                                         DominanceInfo &dominance) {
+  // Remove the values that already dominate the insertion point.
+  SmallVector<Value> prunedValues;
+  for (auto value : values) {
+    if (dominance.properlyDominates(value, insertionPoint))
+      continue;
+    // Block arguments are not supported.
+    if (isa<BlockArgument>(value)) {
+      return rewriter.notifyMatchFailure(
+          insertionPoint,
+          "unsupported case of moving block argument before insertion point");
+    }
+    // Check for currently unsupported case if the insertion point is in a
+    // different block.
+    if (value.getDefiningOp()->getBlock() != insertionPoint->getBlock()) {
+      return rewriter.notifyMatchFailure(
+          insertionPoint,
+          "unsupported case of moving definition of value before an insertion "
+          "point in a different basic block");
+    }
+    prunedValues.push_back(value);
+  }
+
+  // Find the backward slice of operations that define each remaining `Value`.
+  // The filter keeps only operations that do not already properly dominate
+  // `insertionPoint`; those that do can stay where they are.
+  BackwardSliceOptions options;
+  options.inclusive = true;
+  options.omitUsesFromAbove = false;
+  // Since current support is to only move within the same basic block,
+  // the slices don't need to look past block arguments.
+  options.omitBlockArguments = true;
+  options.filter = [&](Operation *sliceBoundaryOp) {
+    return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
+  };
+  llvm::SetVector<Operation *> slice;
+  for (auto value : prunedValues)
+    getBackwardSlice(value, &slice, options);
+
+  // If the slice contains `insertionPoint` cannot move the dependencies.
+  if (slice.contains(insertionPoint)) {
+    return rewriter.notifyMatchFailure(
+        insertionPoint,
+        "cannot move dependencies before operation in backward slice of op");
+  }
+
+  // Sort the combined slice topologically so ops can be moved one at a time
+  // without SSA use-def violations. `topologicalSort` returns the sorted set
+  // instead of sorting in place, so the result must be assigned back.
+  slice = mlir::topologicalSort(slice);
+
+  for (Operation *op : slice)
+    rewriter.moveOpBefore(op, insertionPoint);
+  return success();
+}
+
+LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
+ ValueRange values,
+ Operation *insertionPoint) {
+ DominanceInfo dominance(insertionPoint);
+ return moveValueDefinitions(rewriter, values, insertionPoint, dominance);
+}
diff --git a/mlir/test/Transforms/move-operation-deps.mlir b/mlir/test/Transforms/move-operation-deps.mlir
index 37637152938f6..aa7b5dc2a240a 100644
--- a/mlir/test/Transforms/move-operation-deps.mlir
+++ b/mlir/test/Transforms/move-operation-deps.mlir
@@ -234,3 +234,229 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// Check simple move value definitions before insertion operation.
+func.func @simple_move_values() -> f32 {
+ %0 = "before"() : () -> (f32)
+ %1 = "moved_op_1"() : () -> (f32)
+ %2 = "moved_op_2"() : () -> (f32)
+ %3 = "foo"(%1, %2) : (f32, f32) -> (f32)
+ return %3 : f32
+}
+// CHECK-LABEL: func @simple_move_values()
+// CHECK: %[[MOVED1:.+]] = "moved_op_1"
+// CHECK: %[[MOVED2:.+]] = "moved_op_2"
+// CHECK: %[[BEFORE:.+]] = "before"
+// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED1]], %[[MOVED2]])
+// CHECK: return %[[FOO]]
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+ %op1 = transform.structured.match ops{["moved_op_1"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %op2 = transform.structured.match ops{["moved_op_2"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %op3 = transform.structured.match ops{["before"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+ %v2 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
+ transform.test.move_value_defns %v1, %v2 before %op3
+ : (!transform.any_value, !transform.any_value), !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Compute slice including the implicitly captured values.
+func.func @move_region_dependencies_values() -> f32 {
+ %0 = "before"() : () -> (f32)
+ %1 = "moved_op_1"() : () -> (f32)
+ %2 = "moved_op_2"() ({
+ %3 = "inner_op"(%1) : (f32) -> (f32)
+ "yield"(%3) : (f32) -> ()
+ }) : () -> (f32)
+ return %2 : f32
+}
+// CHECK-LABEL: func @move_region_dependencies_values()
+// CHECK: %[[MOVED1:.+]] = "moved_op_1"
+// CHECK: %[[MOVED2:.+]] = "moved_op_2"
+// CHECK: %[[BEFORE:.+]] = "before"
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+ %op1 = transform.structured.match ops{["moved_op_2"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %op2 = transform.structured.match ops{["before"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+ transform.test.move_value_defns %v1 before %op2
+ : (!transform.any_value), !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Move operations in topological sort order
+func.func @move_values_in_topological_sort_order() -> f32 {
+ %0 = "before"() : () -> (f32)
+ %1 = "moved_op_1"() : () -> (f32)
+ %2 = "moved_op_2"() : () -> (f32)
+ %3 = "moved_op_3"(%1) : (f32) -> (f32)
+ %4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32)
+ %5 = "moved_op_5"(%2) : (f32) -> (f32)
+ %6 = "foo"(%4, %5) : (f32, f32) -> (f32)
+ return %6 : f32
+}
+// CHECK-LABEL: func @move_values_in_topological_sort_order()
+// CHECK: %[[MOVED_1:.+]] = "moved_op_1"
+// CHECK-DAG: %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]])
+// CHECK-DAG: %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]])
+// CHECK-DAG: %[[MOVED_4:.+]] = "moved_op_2"
+// CHECK-DAG: %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]])
+// CHECK: %[[BEFORE:.+]] = "before"
+// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]])
+// CHECK: return %[[FOO]]
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+ %op1 = transform.structured.match ops{["moved_op_4"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %op2 = transform.structured.match ops{["moved_op_5"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %op3 = transform.structured.match ops{["before"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+ %v2 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
+ transform.test.move_value_defns %v1, %v2 before %op3
+ : (!transform.any_value, !transform.any_value), !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Move only those value definitions that are not dominated by insertion point
+
+func.func @move_only_required_defns() -> (f32, f32, f32, f32) {
+ %0 = "unmoved_op"() : () -> (f32)
+ %1 = "dummy_op"() : () -> (f32)
+ %2 = "before"() : () -> (f32)
+ %3 = "moved_op"() : () -> (f32)
+ return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+// CHECK-LABEL: func @move_only_required_defns()
+// CHECK: %[[UNMOVED:.+]] = "unmoved_op"
+// CHECK: %[[DUMMY:.+]] = "dummy_op"
+// CHECK: %[[MOVED:.+]] = "moved_op"
+// CHECK: %[[BEFORE:.+]] = "before"
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+ %op1 = transform.structured.match ops{["unmoved_op"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %op2 = transform.structured.match ops{["dummy_op"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %op3 = transform.structured.match ops{["before"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %op4 = transform.structured.match ops{["moved_op"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+ %v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value
+ transform.test.move_value_defns %v1, %v2 before %op3
+ : (!transform.any_value, !transform.any_value), !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Move only those value definitions that are not dominated by insertion point
+
+func.func @move_only_required_defns() -> (f32, f32, f32, f32) {
+ %0 = "unmoved_op"() : () -> (f32)
+ %1 = "dummy_op"() : () -> (f32)
+ %2 = "before"() : () -> (f32)
+ %3 = "moved_op"() : () -> (f32)
+ return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+// CHECK-LABEL: func @move_only_required_defns()
+// CHECK: %[[UNMOVED:.+]] = "unmoved_op"
+// CHECK: %[[DUMMY:.+]] = "dummy_op"
+// CHECK: %[[MOVED:.+]] = "moved_op"
+// CHECK: %[[BEFORE:.+]] = "before"
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+ %op1 = transform.structured.match ops{["unmoved_op"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %op2 = transform.structured.match ops{["dummy_op"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %op3 = transform.structured.match ops{["before"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %op4 = transform.structured.match ops{["moved_op"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
+ %v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value
+ transform.test.move_value_defns %v1, %v2 before %op3
+ : (!transform.any_value, !transform.any_value), !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Check handling of block arguments
+func.func @move_only_required_defns() -> (f32, f32) {
+ %0 = "unmoved_op"() : () -> (f32)
+ cf.br ^bb0(%0 : f32)
+ ^bb0(%arg0 : f32) :
+ %1 = "before"() : () -> (f32)
+ %2 = "moved_op"(%arg0) : (f32) -> (f32)
+ return %1, %2 : f32, f32
+}
+// CHECK-LABEL: func @move_only_required_defns()
+// CHECK: %[[MOVED:.+]] = "moved_op"
+// CHECK: %[[BEFORE:.+]] = "before"
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+ %op1 = transform.structured.match ops{["before"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %op2 = transform.structured.match ops{["moved_op"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
+ transform.test.move_value_defns %v1 before %op1
+ : (!transform.any_value), !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Do not move across basic blocks
+func.func @no_move_across_basic_blocks() -> (f32, f32) {
+ %0 = "unmoved_op"() : () -> (f32)
+ %1 = "before"() : () -> (f32)
+ cf.br ^bb0(%0 : f32)
+ ^bb0(%arg0 : f32) :
+ %2 = "moved_op"(%arg0) : (f32) -> (f32)
+ return %1, %2 : f32, f32
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+ %op1 = transform.structured.match ops{["before"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %op2 = transform.structured.match ops{["moved_op"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
+ // expected-remark at +1{{unsupported case of moving definition of value before an insertion point in a different basic block}}
+ transform.test.move_value_defns %v1 before %op1
+ : (!transform.any_value), !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp
index aaa566d9938a3..3d95af59f6da3 100644
--- a/mlir/test/lib/Transforms/TestTransformsOps.cpp
+++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp
@@ -39,6 +39,24 @@ transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
+DiagnosedSilenceableFailure
+transform::TestMoveValueDefns::apply(TransformRewriter &rewriter,
+ TransformResults &TransformResults,
+ TransformState &state) {
+ SmallVector<Value> values;
+ for (auto tdValue : getValues()) {
+ values.push_back(*state.getPayloadValues(tdValue).begin());
+ }
+ Operation *moveBefore = *state.getPayloadOps(getInsertionPoint()).begin();
+ if (failed(moveValueDefinitions(rewriter, values, moveBefore))) {
+ auto listener = cast<ErrorCheckingTrackingListener>(rewriter.getListener());
+ std::string errorMsg = listener->getLatestMatchFailureMessage();
+ (void)emitRemark(errorMsg);
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
+
namespace {
class TestTransformsDialectExtension
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.td b/mlir/test/lib/Transforms/TestTransformsOps.td
index f514702cef5bc..495579b452dfc 100644
--- a/mlir/test/lib/Transforms/TestTransformsOps.td
+++ b/mlir/test/lib/Transforms/TestTransformsOps.td
@@ -38,4 +38,26 @@ def TestMoveOperandDeps :
}];
}
+def TestMoveValueDefns :
+    Op<Transform_Dialect, "test.move_value_defns",
+        [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+         DeclareOpInterfaceMethods<TransformOpInterface>,
+         ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Moves the definitions of `values` before the `insertion_point` operation.
+  }];
+
+  let arguments =
+    (ins Variadic<TransformValueHandleTypeInterface>:$values,
+         TransformHandleTypeInterface:$insertion_point);
+
+  let results = (outs);
+
+  let assemblyFormat = [{
+    $values `before` $insertion_point attr-dict
+    `:` `(` type($values) `)` `` `,` type($insertion_point)
+  }];
+}
+
+
#endif // TEST_TRANSFORM_OPS
>From 0e134fc9c4c36f7d8d8039ea12f2a25e2fd4996c Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mravisha at amd.com>
Date: Mon, 17 Feb 2025 21:03:56 -0600
Subject: [PATCH 2/2] [mlir][Linalg] Allow expand shape propagation across
linalg ops with dynamic shapes.
With `tensor.expand_shape` allowing a dynamic dimension to be expanded into
multiple dynamic dimensions, adapt the reshape propagation through
expansion to handle cases where one dynamic dimension is expanded into
multiple dynamic dimensions.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 186 +++++------
mlir/test/Dialect/Linalg/reshape_fusion.mlir | 296 ++++++------------
2 files changed, 177 insertions(+), 305 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 33667e7ab0c5c..cfc5b25fa87a1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -25,6 +25,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
#include <optional>
#include <utility>
@@ -590,18 +591,17 @@ class ExpansionInfo {
// the expanded op.
LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
- ArrayRef<int64_t> expandedShape,
- ArrayRef<int64_t> collapsedShape,
+ ArrayRef<OpFoldResult> expandedShape,
PatternRewriter &rewriter);
unsigned getOrigOpNumDims() const { return reassociation.size(); }
unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
ReassociationIndicesRef getExpandedDims(unsigned i) const {
return reassociation[i];
}
- ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
+ ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
return expandedShapeMap[i];
}
- ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
+ ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }
private:
/// Reassociation from the dimensions in the original operation to the
@@ -609,9 +609,9 @@ class ExpansionInfo {
SmallVector<ReassociationIndices> reassociation;
/// Mapping from extent of loops in the original operation, to the extent of
/// loops in the expanded operation.
- SmallVector<SmallVector<int64_t>> expandedShapeMap;
+ SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
/// Extent of the loop in the original operation.
- SmallVector<int64_t> originalLoopExtent;
+ SmallVector<OpFoldResult> originalLoopExtent;
unsigned expandedOpNumDims;
};
} // namespace
@@ -619,15 +619,17 @@ class ExpansionInfo {
LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
- ArrayRef<int64_t> expandedShape,
- ArrayRef<int64_t> collapsedShape,
+ ArrayRef<OpFoldResult> expandedShape,
PatternRewriter &rewriter) {
if (reassociationMaps.empty())
return failure();
AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
- SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
- originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(linalgOp);
+ originalLoopExtent = llvm::map_to_vector(
+ linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
+ [](Range r) { return r.size; });
reassociation.clear();
expandedShapeMap.clear();
@@ -639,7 +641,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
AffineMap foldedDims = reassociationMaps[resultExpr.index()];
numExpandedDims[pos] = foldedDims.getNumResults();
- ArrayRef<int64_t> shape =
+ ArrayRef<OpFoldResult> shape =
expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
expandedShapeMap[pos].assign(shape.begin(), shape.end());
}
@@ -660,33 +662,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
return success();
}
-/// Expanding the body of a linalg operation requires adaptations of the
-/// accessed loop indices. Specifically, access of indices in the original
-/// operation need to be replaced with linearizations of indices in the expanded
-/// op. That requires the shape of the expanded dimensions to be static (at
-/// least all but the most significant). For now check that these are all
-/// statically sized. Note that this could be extended to handle dynamic case,
-/// but the implementation below uses `affine.apply` which seems to have issues
-/// when the shapes are not static.
-static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
- const ExpansionInfo &expansionInfo,
- PatternRewriter &rewriter) {
- if (!linalgOp.hasIndexSemantics())
- return success();
- for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
- ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
- if (expandedShape.size() == 1)
- continue;
- for (int64_t shape : expandedShape.drop_front()) {
- if (ShapedType::isDynamic(shape)) {
- return rewriter.notifyMatchFailure(
- linalgOp, "cannot expand due to index semantics and dynamic dims");
- }
- }
- }
- return success();
-}
-
/// Return the indexing map to use in the expanded op for a given the
/// `indexingMap` of the original operation.
static AffineMap
@@ -708,16 +683,28 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
-static RankedTensorType getExpandedType(RankedTensorType originalType,
- AffineMap indexingMap,
- const ExpansionInfo &expansionInfo) {
- SmallVector<int64_t> expandedShape;
+static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
+getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
+ const ExpansionInfo &expansionInfo) {
+ SmallVector<int64_t> expandedStaticShape;
+ SmallVector<OpFoldResult> expandedShape;
for (AffineExpr expr : indexingMap.getResults()) {
unsigned dim = cast<AffineDimExpr>(expr).getPosition();
- auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
+ ArrayRef<OpFoldResult> dimExpansion =
+ expansionInfo.getExpandedShapeOfDim(dim);
+ llvm::append_range(expandedStaticShape,
+ llvm::map_range(dimExpansion, [](OpFoldResult ofr) {
+ std::optional<int64_t> staticShape =
+ getConstantIntValue(ofr);
+ if (staticShape) {
+ return staticShape.value();
+ }
+ return ShapedType::kDynamic;
+ }));
expandedShape.append(dimExpansion.begin(), dimExpansion.end());
}
- return RankedTensorType::get(expandedShape, originalType.getElementType());
+ return {expandedShape, RankedTensorType::get(expandedStaticShape,
+ originalType.getElementType())};
}
/// Returns the reassociation maps to use in the `tensor.expand_shape`
@@ -765,49 +752,27 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
// Linearize the expanded indices of the original index dimension.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(indexOp);
- ArrayRef<int64_t> expandedDimsShape =
+ ArrayRef<OpFoldResult> expandedDimsShape =
expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
SmallVector<Value> expandedIndices;
expandedIndices.reserve(expandedDims.size() - 1);
llvm::transform(
expandedDims.drop_front(), std::back_inserter(expandedIndices),
[&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
- Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
+ OpFoldResult newIndex =
+ rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
- assert(!ShapedType::isDynamic(std::get<0>(it)));
- AffineExpr idx, acc;
+ AffineExpr idx, acc, shape;
bindDims(rewriter.getContext(), idx, acc);
- newIndex = rewriter.create<affine::AffineApplyOp>(
- indexOp.getLoc(), idx + acc * std::get<0>(it),
- ValueRange{std::get<1>(it), newIndex});
- }
- rewriter.replaceOp(indexOp, newIndex);
- }
-}
-
-/// Checks if a single dynamic dimension expanded into multiple dynamic
-/// dimensions.
-static LogicalResult
-validateDynamicDimExpansion(LinalgOp linalgOp,
- const ExpansionInfo &expansionInfo,
- PatternRewriter &rewriter) {
- for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
- ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
- if (expandedShape.size() == 1)
- continue;
- bool foundDynamic = false;
- for (int64_t shape : expandedShape) {
- if (!ShapedType::isDynamic(shape))
- continue;
- if (foundDynamic) {
- return rewriter.notifyMatchFailure(
- linalgOp, "cannot infer expanded shape with multiple dynamic "
- "dims in the same reassociation group");
- }
- foundDynamic = true;
+ bindSymbols(rewriter.getContext(), shape);
+ newIndex = affine::makeComposedFoldedAffineApply(
+ rewriter, indexOp.getLoc(), idx + acc * shape,
+ ArrayRef<OpFoldResult>{std::get<1>(it), newIndex, std::get<0>(it)});
}
+ Value newIndexVal =
+ getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
+ rewriter.replaceOp(indexOp, newIndexVal);
}
- return success();
}
// Create an expanded transpose op.
@@ -910,31 +875,31 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
"preconditions for fuse operation failed");
Location loc = linalgOp.getLoc();
- // Check if reshape is expanding or collapsing.
- auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
- auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
- bool isExpanding = (expandingReshapeOp != nullptr);
- RankedTensorType expandedType = isExpanding
- ? expandingReshapeOp.getResultType()
- : collapsingReshapeOp.getSrcType();
- RankedTensorType collapsedType = isExpanding
- ? expandingReshapeOp.getSrcType()
- : collapsingReshapeOp.getResultType();
+ SmallVector<OpFoldResult> expandedShape, collapsedShape;
+ SmallVector<AffineMap, 4> reassociationIndices;
+ Value src;
+ if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
+ // Try to move the dynamic dimensions in output shape before the `linalgOp`
+ // to maintain SSA validity
+ if (failed(moveValueDefinitions(
+ rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
+ return std::nullopt;
+
+ expandedShape = expandingReshapeOp.getMixedOutputShape();
+ reassociationIndices = expandingReshapeOp.getReassociationMaps();
+ src = expandingReshapeOp.getSrc();
+ } else {
+ auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
+ expandedShape = tensor::getMixedSizes(
+ rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
+ reassociationIndices = collapsingReshapeOp.getReassociationMaps();
+ src = collapsingReshapeOp.getSrc();
+ }
ExpansionInfo expansionInfo;
- if (failed(expansionInfo.compute(
- linalgOp, fusableOpOperand,
- isExpanding ? expandingReshapeOp.getReassociationMaps()
- : collapsingReshapeOp.getReassociationMaps(),
- expandedType.getShape(), collapsedType.getShape(), rewriter)))
- return std::nullopt;
-
- // TODO: With the support of multiple dynamic dims expansion in
- // tensor.expand_shape op, this case can be handled.
- if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
- return std::nullopt;
-
- if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
+ if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
+ reassociationIndices, expandedShape,
+ rewriter)))
return std::nullopt;
SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
@@ -950,15 +915,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
if (opOperand == fusableOpOperand) {
- expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
- : collapsingReshapeOp.getSrc());
+ expandedOpOperands.push_back(src);
continue;
}
if (auto opOperandType =
dyn_cast<RankedTensorType>(opOperand->get().getType())) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
- RankedTensorType expandedOperandType =
- getExpandedType(opOperandType, indexingMap, expansionInfo);
+ SmallVector<OpFoldResult> expandedOperandShape;
+ RankedTensorType expandedOperandType;
+ std::tie(expandedOperandShape, expandedOperandType) =
+ getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
if (expandedOperandType != opOperand->get().getType()) {
// Reshape the operand to get the right type.
SmallVector<ReassociationIndices> reassociation =
@@ -972,7 +938,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
/*isExpandingReshape=*/true)))
return std::nullopt;
expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
- loc, expandedOperandType, opOperand->get(), reassociation));
+ loc, expandedOperandType, opOperand->get(), reassociation,
+ expandedOperandShape));
continue;
}
}
@@ -983,8 +950,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
- RankedTensorType expandedOutputType =
- getExpandedType(opOperandType, indexingMap, expansionInfo);
+ SmallVector<OpFoldResult> expandedOutputShape;
+ RankedTensorType expandedOutputType;
+ std::tie(expandedOutputShape, expandedOutputType) =
+ getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
if (expandedOutputType != opOperand.get().getType()) {
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
@@ -997,7 +966,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
/*isExpandingReshape=*/true)))
return std::nullopt;
outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
- loc, expandedOutputType, opOperand.get(), reassociation));
+ loc, expandedOutputType, opOperand.get(), reassociation,
+ expandedOutputShape));
} else {
outputs.push_back(opOperand.get());
}
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 3244418d445b7..67b4f2b32bad5 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -30,20 +30,14 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32
-// CHECK: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[C3:.+]] = arith.constant 3 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM_1]], %[[C4]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM]], %[[DIM_0]], %[[VAL_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_4]], %[[C4]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_2]], %[[DIM_3]], %[[VAL_1]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x4x?xf32>
+// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x4x?xf32>
+// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C3]] : tensor<?x?x4x?xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_1]], %[[DIM]], %[[DIM_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_1]], %[[DIM]], %[[DIM_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -88,21 +82,9 @@ func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-// CHECK: %[[C20:.+]] = arith.constant 20 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C20]] : index
-// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], 4, %[[VAL_0]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], 4, %[[VAL_1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
-// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], 4, %[[VAL_2]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -137,26 +119,9 @@ func.func @reshape_as_consumer_permutation
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index
-// CHECK: %[[C12:.+]] = arith.constant 12 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C12]] : index
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C2]] : index
-// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[VAL_0]], %[[VAL_1]], 2, %[[DIM_1]]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_2]], %[[C12]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[VAL_2]], %[[DIM_3]]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
-// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_6:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_7:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_5]], %[[C2]] : index
-// CHECK: %[[VAL_4:.+]] = arith.divsi %[[DIM_7]], %[[C12]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[VAL_3]], 2, %[[DIM_6]], 3, 4, %[[VAL_4]]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
+// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[SZ2]], %[[SZ0]], 2, %[[SZ1]]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[SZ2]], %[[SZ1]]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[SZ0]], 2, %[[SZ1]], 3, 4, %[[SZ2]]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
@@ -258,7 +223,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
}
// Only check the body in the indexed version of the test.
-// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 4)>
// CHECK: func @indexed_consumer_reshape_producer_fusion
// CHECK: linalg.generic
// CHECK: ^{{.*}}(
@@ -268,7 +233,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
-// CHECK-DAG: %[[T3:.+]] = affine.apply #[[MAP]](%[[IDX1]], %[[IDX0]])
+// CHECK-DAG: %[[T3:.+]] = affine.apply #[[MAP]]()[%[[IDX1]], %[[IDX0]]]
// CHECK: %[[T4:.+]] = arith.muli %[[ARG3]], %[[ARG4]]
// CHECK: %[[T5:.+]] = arith.index_cast %[[T3]]
// CHECK: %[[T6:.+]] = arith.addi %[[T4]], %[[T5]]
@@ -307,8 +272,7 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
}
// Only check the body in the indexed version of the test.
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 5 + s1 * 20 + s2)>
// CHECK: func @indexed_producer_reshape_consumer_fusion
// CHECK: linalg.generic
// CHECK: ^{{.*}}(
@@ -318,12 +282,11 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
-// CHECK: %[[T1:.+]] = affine.apply #[[MAP1]](%[[IDX2]], %[[IDX1]])
-// CHECK: %[[T2:.+]] = affine.apply #[[MAP2]](%[[IDX3]], %[[T1]])
+// CHECK: %[[T1:.+]] = affine.apply #[[MAP1]]()[%[[IDX2]], %[[IDX1]], %[[IDX3]]]
// CHECK: %[[T4:.+]] = arith.muli %[[ARG3]], %[[ARG4]]
// CHECK: %[[T5:.+]] = arith.index_cast %[[IDX0]]
// CHECK: %[[T6:.+]] = arith.addi %[[T4]], %[[T5]]
-// CHECK: %[[T7:.+]] = arith.index_cast %[[T2]]
+// CHECK: %[[T7:.+]] = arith.index_cast %[[T1]]
// CHECK: %[[T8:.+]] = arith.addi %[[T6]], %[[T7]]
// CHECK: linalg.yield %[[T8]]
@@ -362,16 +325,15 @@ func.func @reshape_as_consumer_permutation
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
-// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 6)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 7)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 7 + s1 * 42 + s2)>
// CHECK: func @reshape_as_consumer_permutation
// CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32>
// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [5, 6, 7, 2, 3, 4] : tensor<210x6x4xi32> into tensor<5x6x7x2x3x4xi32>
// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [5, 6, 7, 4] : tensor<210x4xi32> into tensor<5x6x7x4xi32>
-// CHECK: %[[T3:.+]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
+// CHECK: %[[T3:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
// CHECK: %[[T4:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
@@ -385,13 +347,12 @@ func.func @reshape_as_consumer_permutation
// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
// CHECK-DAG: %[[IDX4:.+]] = linalg.index 4 : index
// CHECK-DAG: %[[IDX5:.+]] = linalg.index 5 : index
-// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP3]](%[[IDX1]], %[[IDX0]])
-// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP4]](%[[IDX3]], %[[IDX2]])
-// CHECK-DAG: %[[T7:.+]] = affine.apply #[[MAP5]](%[[IDX4]], %[[T6]])
+// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP3]]()[%[[IDX1]], %[[IDX0]]]
+// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP4]]()[%[[IDX3]], %[[IDX2]], %[[IDX4]]]
// CHECK-DAG: %[[T8:.+]] = arith.addi %[[ARG8]], %[[ARG9]]
// CHECK: %[[T9:.+]] = arith.index_cast %[[T5]]
// CHECK: %[[T10:.+]] = arith.addi %[[T8]], %[[T9]]
-// CHECK: %[[T11:.+]] = arith.index_cast %[[T7]]
+// CHECK: %[[T11:.+]] = arith.index_cast %[[T6]]
// CHECK: %[[T12:.+]] = arith.addi %[[T10]], %[[T11]]
// CHECK: %[[T13:.+]] = arith.index_cast %[[IDX5]]
// CHECK: %[[T14:.+]] = arith.addi %[[T12]], %[[T13]]
@@ -426,7 +387,7 @@ func.func @reshape_as_producer_projected_permutation(
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 8)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 8)>
// CHECK: @reshape_as_producer_projected_permutation
// CHECK-SAME: %[[ARG0:.+]]: tensor<33x8x?xi32>
// CHECK: %[[RES:.+]] = linalg.generic
@@ -439,7 +400,7 @@ func.func @reshape_as_producer_projected_permutation(
// CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
// CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
// CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
-// CHECK-DAG: %[[T0:.+]] = affine.apply #[[MAP2]](%[[IDX1]], %[[IDX0]])
+// CHECK-DAG: %[[T0:.+]] = affine.apply #[[MAP2]]()[%[[IDX1]], %[[IDX0]]]
// CHECK: %[[T1:.+]] = arith.index_cast %[[T0]] : index to i32
// CHECK: %[[T2:.+]] = arith.addi %[[ARG1]], %[[T1]] : i32
// CHECK: %[[T3:.+]] = arith.index_cast %[[IDX2]] : index to i32
@@ -481,21 +442,9 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-// CHECK: %[[C20:.+]] = arith.constant 20 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C20]] : index
-// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_0]], 4, 5, %[[DIM_0]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_1]], %[[C20]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_1]], 4, 5, %[[DIM_2]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
-// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
+// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[SZ1]], 4, 5, %[[SZ0]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[SZ1]], 4, 5, %[[SZ0]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP4]], #[[MAP5]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -528,9 +477,10 @@ func.func @fuse_collapse_reduction(%arg0: tensor<10x10x20xf32>) -> tensor<100xf3
// CHECK-SAME: ins(%[[ARG0]] : tensor<10x10x20xf32>)
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
// CHECK: return %[[COLLAPSE]]
+
// -----
-func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
+func.func @fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
%c0 = arith.constant 0 : index
%0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
%1 = tensor.dim %0, %c0 : tensor<?xf32>
@@ -546,39 +496,21 @@ func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
return %3 : tensor<?xf32>
}
-// CHECK: func @no_fuse_dynamic_dims
+// CHECK: func @fuse_dynamic_dims
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[EXPAND_SHAPE:.+]] = tensor.expand_shape %[[EMPTY]] {{\[}}[0, 1]{{\]}}
+// CHECK-SAME: output_shape [%[[D0]], %[[D1]]]
// CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?xf32>)
-// CHECK: return %[[GENERIC]]
-
-// -----
-
-func.func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor<?xi64>) -> tensor<2xi64> {
- %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x1xi64> into tensor<2xi64>
- %1 = tensor.empty() : tensor<2xi64>
- %2 = linalg.generic
- {indexing_maps = [affine_map<(d0) -> (d0)>,
- affine_map<(d0) -> (d0)>,
- affine_map<(d0) -> (d0)>],
- iterator_types = ["parallel"]}
- ins(%0, %arg1 : tensor<2xi64>, tensor<?xi64>)
- outs(%1 : tensor<2xi64>) {
- ^bb0(%arg4: i64, %arg5: i64, %arg6: i64):
- %3 = arith.addi %arg4, %arg5 : i64
- linalg.yield %3 : i64
- } -> tensor<2xi64>
- return %2 : tensor<2xi64>
-}
-
-// CHECK: func @no_fuse_mismatched_dynamism
-// CHECK-SAME: %[[ARG0:.+]]: tensor<2x1xi64>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<?xi64>
-// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
-// CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : tensor<2xi64>, tensor<?xi64>)
-// CHECK: return %[[GENERIC]]
+// CHECK-SAME: ins(%[[ARG0]] :
+// CHECK-SAME: outs(%[[EXPAND_SHAPE]] :
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]{{\]}}
+// CHECK: return %[[COLLAPSE]]
// -----
@@ -610,32 +542,10 @@ func.func @reshape_as_consumer_permutation_with_multiple_results
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index, %[[SZ3:.+]]: index, %[[SZ4:.+]]: index
-// CHECK: %[[C12:.+]] = arith.constant 12 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C12]] : index
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C2]] : index
-// CHECK: %[[RESHAPE0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[VAL_0]], %[[VAL_1]], 2, %[[DIM_1]]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_2]], %[[C12]] : index
-// CHECK: %[[RESHAPE1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[VAL_2]], %[[DIM_3]]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
-// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_6:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_7:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_5]], %[[C2]] : index
-// CHECK: %[[VAL_4:.+]] = arith.divsi %[[DIM_7]], %[[C12]] : index
-// CHECK: %[[RESHAPE2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[VAL_3]], 2, %[[DIM_6]], 3, 4, %[[VAL_4]]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
-// CHECK: %[[DIM_9:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_10:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_11:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[VAL_5:.+]] = arith.divsi %[[DIM_10]], %[[C2]] : index
-// CHECK: %[[VAL_6:.+]] = arith.divsi %[[DIM_11]], %[[C12]] : index
-// CHECK: %[[RESHAPE3:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2], [3, 4, 5]] output_shape [%[[DIM_9]], %[[VAL_5]], 2, 3, 4, %[[VAL_6]]] : tensor<?x?x?xf32> into tensor<?x?x2x3x4x?xf32>
+// CHECK: %[[RESHAPE0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[SZ2]], %[[SZ4]], 2, %[[SZ3]]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
+// CHECK: %[[RESHAPE1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[SZ2]], %[[SZ3]]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
+// CHECK: %[[RESHAPE2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[SZ4]], 2, %[[SZ3]], 3, 4, %[[SZ2]]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
+// CHECK: %[[RESHAPE3:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2], [3, 4, 5]] output_shape [%[[SZ3]], %[[SZ4]], 2, 3, 4, %[[SZ2]]] : tensor<?x?x?xf32> into tensor<?x?x2x3x4x?xf32>
// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
// CHECK-SAME: ins(%[[RESHAPE0]], %[[RESHAPE1]] :
@@ -710,17 +620,10 @@ func.func @generic_op_reshape_consumer_fusion_reduction(%arg0 : tensor<?x?xf32>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-// CHECK: %[[C20:.+]] = arith.constant 20 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C20]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_0]], 4, 5, %[[DIM_0]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], %[[VAL_1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[SZ1]], 4, 5, %[[DIM]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "reduction"]
@@ -760,21 +663,12 @@ func.func @generic_op_reshape_producer_fusion_with_reduction(%arg0 : tensor<?x7x
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x4x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C7:.+]] = arith.constant 7 : index
-// CHECK: %[[C8:.+]] = arith.constant 8 : index
// CHECK: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x4x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x4x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [%[[VAL_0]], 8, 4, %[[VAL_1]], 7] : tensor<?x4x?xf32> into tensor<?x8x4x?x7xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C8]] : index
-// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C7]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x7x?x8xf32>
+// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x7x?x8xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [%[[DIM_0]], 8, 4, %[[DIM]], 7] : tensor<?x4x?xf32> into tensor<?x8x4x?x7xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM_0]], 8, %[[DIM]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "reduction", "parallel", "parallel"]
@@ -807,21 +701,9 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-// CHECK: %[[C20:.+]] = arith.constant 20 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C20]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], %[[VAL_0]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], %[[VAL_1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
-// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
-// CHECK: %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
+// CHECK: %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
// CHECK: %[[T4:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -848,20 +730,12 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK: %[[C8:.+]] = arith.constant 8 : index
-// CHECK: %[[C7:.+]] = arith.constant 7 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C7]] : index
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C8]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_0]], 7, %[[VAL_1]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C7]] : index
-// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 7, %[[VAL_3]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x7x?x8xf32>
+// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x7x?x8xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM]], 7, %[[DIM_0]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM]], 7, %[[DIM_0]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -888,15 +762,11 @@ func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// CHECK: func @linalg_copy_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK: %[[C8:.+]] = arith.constant 8 : index
-// CHECK: %[[C7:.+]] = arith.constant 7 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C7]] : index
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C8]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_0]], 7, %[[VAL_1]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM]], 7, %[[DIM_0]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
// CHECK: %[[T2:.+]] = linalg.copy
// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
// CHECK-SAME: outs(%[[T1]] : tensor<?x7x?x8xf32>)
@@ -907,7 +777,6 @@ func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// -----
-
func.func @reshape_as_producer_transpose
(%a : tensor<4x5x6x7x2x3xf32>)
-> tensor<6x4x210xf32> {
@@ -991,3 +860,36 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: i
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
// CHECK-SAME: : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
// CHECK: return %[[COLLAPSE]]
+
+// -----
+
+func.func @move_operand_deps(%arg0 : tensor<?x128xf16>,
+ %arg1 : tensor<4x?x32x128xf16>, %empty : tensor<4x?x32x128xf16>) -> tensor<4x?x32x8x16xf16> {
+ %c0 = arith.constant 0 : index
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg0 : tensor<?x128xf16>)
+ outs(%empty : tensor<4x?x32x128xf16>) {
+ ^bb0(%b0: f16, %b1 : f16) :
+ %iv0 = linalg.index 0 : index
+ %iv1 = linalg.index 1 : index
+ %iv2 = linalg.index 2 : index
+ %iv3 = linalg.index 3 : index
+ %1 = tensor.extract %arg1[%iv0, %iv1, %iv2, %iv3] : tensor<4x?x32x128xf16>
+ %2 = arith.addf %1, %b0 : f16
+ linalg.yield %2 : f16
+ } -> tensor<4x?x32x128xf16>
+ %1 = tensor.dim %arg0, %c0 : tensor<?x128xf16>
+ %2 = tensor.expand_shape %0 [[0], [1], [2], [3, 4]] output_shape [4, %1, 32, 8, 16]
+ : tensor<4x?x32x128xf16> into tensor<4x?x32x8x16xf16>
+ func.return %2 : tensor<4x?x32x8x16xf16>
+}
+// CHECK: func @move_operand_deps(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x128xf16>
+// CHECK-DAG: %[[MOVED_OP:.+]] = tensor.dim %[[ARG0]]
+// CHECK-DAG: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[EXPANDED]] :
+// CHECK: return %[[GENERIC]]
More information about the Mlir-commits
mailing list