[Mlir-commits] [mlir] f4eb681 - [mlir][Linalg] Drop unit-trip loops of reductions only if other reduction loops exists.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 8 22:31:48 PDT 2021
Author: MaheshRavishankar
Date: 2021-04-08T22:31:29-07:00
New Revision: f4eb681dc37ae84e08579bf96cd2a6f58c44f260
URL: https://github.com/llvm/llvm-project/commit/f4eb681dc37ae84e08579bf96cd2a6f58c44f260
DIFF: https://github.com/llvm/llvm-project/commit/f4eb681dc37ae84e08579bf96cd2a6f58c44f260.diff
LOG: [mlir][Linalg] Drop unit-trip loops of reductions only if other reduction loops exists.
A recent change enabled dropping unit-trip loops of the "reduction" iterator
type as well. This is fine as long as there is at least one other "reduction"
iterator in the operation. Without one, the initialized value (the value
of `out`) is not read, which leads to a correctness issue.
Also fix a bug in the `fill` -> `tensor_reshape` folding. The `out`
operand of the `fill` needs to be reshaped to get the `out` operand of
the generated `fill` operation.
Differential Revision: https://reviews.llvm.org/D100145
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
mlir/test/Dialect/Linalg/reshape_fusion.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 786b9ec85dcf..344ffe977caf 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -31,7 +31,9 @@ def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
"Only folds the one-trip loops from Linalg ops on tensors "
"(for testing purposes only)">
];
- let dependentDialects = ["linalg::LinalgDialect"];
+ let dependentDialects = [
+ "linalg::LinalgDialect", "AffineDialect", "memref::MemRefDialect"
+ ];
}
def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
@@ -43,7 +45,9 @@ def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
"Allow fusing linalg.tensor_reshape ops that performs unit "
"dimension collapsing">
];
- let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
+ let dependentDialects = [
+ "AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
+ ];
}
def LinalgFoldReshapeOpsByLinearization :
@@ -51,7 +55,7 @@ def LinalgFoldReshapeOpsByLinearization :
let summary = "Fold TensorReshapeOps with generic/indexed generic ops by "
"linearization";
let constructor = "mlir::createFoldReshapeOpsByLinearizationPass()";
- let dependentDialects = ["AffineDialect"];
+ let dependentDialects = ["AffineDialect", "memref::MemRefDialect"];
}
def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
@@ -64,7 +68,8 @@ def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
"interchange vector",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
];
- let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
+ let dependentDialects = [
+ "AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"];
}
def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 76fa96b6346c..d8b512cdeea0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1681,9 +1681,10 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
if (!oldFill)
return failure();
- auto newInit = rewriter.create<InitTensorOp>(
- oldFill.getLoc(), reshapeOp.getResultType().getShape(),
- reshapeOp.getResultType().getElementType());
+ Location loc = oldFill.getLoc();
+ auto newInit = rewriter.create<TensorReshapeOp>(
+ loc, reshapeOp.getResultType(), oldFill.output(),
+ reshapeOp.reassociation());
rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, newInit, oldFill.value());
return success();
@@ -1694,7 +1695,8 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
void TensorReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CollapseReshapeOps<TensorReshapeOp>, FoldFillWithTensorReshape,
- FoldReshapeWithConstant>(context);
+ FoldInitTensorWithTensorReshapeOp, FoldReshapeWithConstant>(
+ context);
}
LogicalResult TensorReshapeOp::reifyReturnTypeShapesPerResultDim(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 7aefefd642ea..b3af82c80cee 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -190,15 +190,43 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
SmallVector<int64_t, 4> dims;
for (ShapedType shapedType : op.getShapedOperandTypes())
dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
+
+ // Find all the reduction iterators. Those need some special consideration
+ // (see below).
+ auto getLoopDimsOfType =
+ [&](StringRef iteratorTypeName) -> SmallVector<unsigned, 4> {
+ SmallVector<AffineExpr> dimExprs;
+ getDimsOfType(op, iteratorTypeName, dimExprs);
+ return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) {
+ return expr.cast<AffineDimExpr>().getPosition();
+ }));
+ };
+ auto reductionDims = getLoopDimsOfType(getReductionIteratorTypeName());
+
DenseSet<unsigned> unitDims;
+ SmallVector<unsigned, 4> unitDimsReductionLoops;
ArrayAttr iteratorTypes = op.iterator_types();
for (auto expr : enumerate(invertedMap.getResults())) {
if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
- if (dims[dimExpr.getPosition()] == 1 &&
- iteratorTypes[expr.index()].dyn_cast<StringAttr>().getValue() ==
- getParallelIteratorTypeName())
- unitDims.insert(expr.index());
+ if (dims[dimExpr.getPosition()] == 1) {
+ if (isParallelIterator(iteratorTypes[expr.index()]))
+ unitDims.insert(expr.index());
+ else if (isReductionIterator(iteratorTypes[expr.index()]))
+ unitDimsReductionLoops.push_back(expr.index());
+ }
}
+
+ // Reduction loops can be dropped if there is at least one other reduction
+ // loop that is not dropped. This accounts for the initial value read in the
+ // reduction loop.
+ if (!unitDimsReductionLoops.empty() && reductionDims.size() > 1) {
+ if (unitDimsReductionLoops.size() == reductionDims.size())
+ unitDims.insert(reductionDims.begin(), std::prev(reductionDims.end()));
+ else
+ unitDims.insert(unitDimsReductionLoops.begin(),
+ unitDimsReductionLoops.end());
+ }
+
if (unitDims.empty())
return failure();
@@ -293,7 +321,6 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
using OpRewritePattern<GenericOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOpTy op,
PatternRewriter &rewriter) const override {
- // TODO: support reductions.
if (!op.hasTensorSemantics())
return failure();
@@ -565,7 +592,6 @@ void mlir::linalg::populateFoldUnitExtentDimsPatterns(
ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
patterns.add<FoldReshapeOpWithUnitExtent>(context);
- populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
}
namespace {
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 23c34111f826..c086b212c5c6 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -827,6 +827,23 @@ func @fold_fill_reshape() -> tensor<6x4xf32> {
// -----
+// CHECK: func @fold_fill_reshape_dynamic
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?x?xf32>
+func @fold_fill_reshape_dynamic(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32> {
+ %zero = constant 0.0 : f32
+ // CHECK: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]]
+ %0 = linalg.fill(%arg0, %zero) : tensor<?x?x?x?x?xf32>, f32 -> tensor<?x?x?x?x?xf32>
+ // CHECK: %[[RESULT:.+]] = linalg.fill(%[[RESHAPE]], %{{.+}})
+ %1 = linalg.tensor_reshape %0
+ [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
+ : tensor<?x?x?x?x?xf32> into tensor<?x?xf32>
+ // CHECK: return %[[RESULT]]
+ return %1 : tensor<?x?xf32>
+}
+
+// -----
+
#map0 = affine_map<(d0) -> (24, -d0 + 192)>
#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
#map2 = affine_map<(d0) -> (16, -d0 + 192)>
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 2a6711018988..5c8866662d2b 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -456,3 +456,113 @@ func @no_fold_subtensor(
// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[SUBTENSOR]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP4]], #[[MAP5]]]
// CHECK: return %[[RESULT_RESHAPE]]
+
+// -----
+
+func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32> {
+ %cst = constant 1.000000e+00 : f32
+ %c3 = constant 3 : index
+ %0 = memref.dim %arg0, %c3 : tensor<1x?x1x?xf32>
+ %1 = linalg.init_tensor [1, %0] : tensor<1x?xf32>
+ %2 = linalg.fill(%1, %cst) : tensor<1x?xf32>, f32 -> tensor<1x?xf32>
+ %3 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ ins(%arg0 : tensor<1x?x1x?xf32>)
+ outs(%2 : tensor<1x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32): // no predecessors
+ %4 = addf %arg1, %arg2 : f32
+ linalg.yield %4 : f32
+ } -> tensor<1x?xf32>
+ return %3 : tensor<1x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK: func @unit_dim_for_reduction
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x?xf32>
+// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]] [#[[MAP0]], #[[MAP1]]]
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%{{.+}}] : tensor<?xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}})
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: iterator_types = ["parallel", "reduction"]
+// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)
+// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] [#[[MAP2]]]
+// CHECK: return %[[RESULT_RESHAPE]]
+
+// -----
+
+func @unit_dim_for_reduction_keep_one(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1xf32> {
+ %cst = constant 1.000000e+00 : f32
+ %c3 = constant 3 : index
+ %1 = linalg.init_tensor [1, 1] : tensor<1x1xf32>
+ %2 = linalg.fill(%1, %cst) : tensor<1x1xf32>, f32 -> tensor<1x1xf32>
+ %3 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ ins(%arg0 : tensor<1x?x1x1xf32>)
+ outs(%2 : tensor<1x1xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32): // no predecessors
+ %4 = addf %arg1, %arg2 : f32
+ linalg.yield %4 : f32
+ } -> tensor<1x1xf32>
+ return %3 : tensor<1x1xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK: func @unit_dim_for_reduction_keep_one
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x1xf32>
+// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]] [#[[MAP0]], #[[MAP1]]]
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [1] : tensor<1xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}})
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: iterator_types = ["parallel", "reduction"]
+// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?x1xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<1xf32>)
+// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] [#[[MAP2]]]
+// CHECK: return %[[RESULT_RESHAPE]]
+
+// -----
+
+func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x1xf32> {
+ %cst = constant 1.000000e+00 : f32
+ %c2 = constant 2 : index
+ %0 = memref.dim %arg0, %c2 : tensor<?x1x?x1xf32>
+ %1 = linalg.init_tensor [%0, 1] : tensor<?x1xf32>
+ %2 = linalg.fill(%1, %cst) : tensor<?x1xf32>, f32 -> tensor<?x1xf32>
+ %3 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ ins(%arg0 : tensor<?x1x?x1xf32>)
+ outs(%2 : tensor<?x1xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32): // no predecessors
+ %4 = addf %arg1, %arg2 : f32
+ linalg.yield %4 : f32
+ } -> tensor<?x1xf32>
+ return %3 : tensor<?x1xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK: func @unit_dim_for_reduction_inner
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x1xf32>
+// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]] [#[[MAP0]], #[[MAP1]]]
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%{{.+}}] : tensor<?xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}})
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: iterator_types = ["parallel", "reduction"]
+// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)
+// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] [#[[MAP2]]]
+// CHECK: return %[[RESULT_RESHAPE]]
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index d5dc176f1fdf..9c0fe41684ee 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -174,18 +174,16 @@ func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @generic_op_reshape_consumer_static
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32>
-// CHECK: %[[T0:.+]] = linalg.init_tensor [264, 4]
-// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: tensor<264x4xf32> into tensor<8x33x4xf32>
-// CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[T0]]
-// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
-// CHECK: %[[T3:.+]] = linalg.generic
+// CHECK: %[[T1:.+]] = linalg.init_tensor [8, 33, 4]
+// CHECK: %[[T2:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel"]
-// CHECK-SAME: ins(%[[T1]] : tensor<8x33x4xf32>)
-// CHECK-SAME: outs(%[[T2]] : tensor<8x33x4xf32>)
-// CHECK: return %[[T3]] : tensor<8x33x4xf32>
+// CHECK-SAME: ins(%[[T0]] : tensor<8x33x4xf32>)
+// CHECK-SAME: outs(%[[T1]] : tensor<8x33x4xf32>)
+// CHECK: return %[[T2]] : tensor<8x33x4xf32>
// -----
@@ -317,36 +315,31 @@ func @reshape_as_consumer_permutation
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>
-// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>
-// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
-// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
-// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
-// CHECK-DAG: #[[MAP10:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
-// CHECK-DAG: #[[MAP11:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
-// CHECK-DAG: #[[MAP12:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 7 + d2 * 42)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
+// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 7 + d2 * 42)>
// CHECK: func @reshape_as_consumer_permutation
// CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32>
-// CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [6, 4, 210]
// CHECK-DAG: %[[T1:.+]] = linalg.tensor_reshape %[[ARG0]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-DAG: %[[T2:.+]] = linalg.tensor_reshape %[[ARG1]]
// CHECK-SAME: [#[[MAP3]], #[[MAP4]]]
-// CHECK: %[[T3:.+]] = linalg.tensor_reshape %[[T0]]
-// CHECK-SAME: [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
+// CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7]
// CHECK: %[[T4:.+]] = linalg.indexed_generic
-// CHECK-SAME: indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
+// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
// CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
-// CHECK-SAME: outs(%[[T3]] : tensor<2x3x4x5x6x7xi32>)
+// CHECK-SAME: outs(%[[T0]] : tensor<2x3x4x5x6x7xi32>)
// CHECK: ^{{.+}}(
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index, %[[ARG5:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index, %[[ARG7:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32,
// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: i32)
-// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP11]](%[[ARG3]], %[[ARG2]])
-// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP12]](%[[ARG6]], %[[ARG5]], %[[ARG4]])
+// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP8]](%[[ARG3]], %[[ARG2]])
+// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP9]](%[[ARG6]], %[[ARG5]], %[[ARG4]])
// CHECK-DAG: %[[T7:.+]] = addi %[[ARG8]], %[[ARG9]]
// CHECK: %[[T8:.+]] = index_cast %[[T5]]
// CHECK: %[[T9:.+]] = addi %[[T7]], %[[T8]]
@@ -535,8 +528,11 @@ func @unit_dim_reshape_expansion_full
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
// FOLDUNITDIM: func @unit_dim_reshape_expansion_full
-// FOLDUNITDIM: linalg.init_tensor
-// FOLDUNITDIM-COUNT-2: linalg.tensor_reshape
+// FOLDUNITDIM-SAME: %[[ARG0:.+]]: tensor<1x?x1x2x1x4xf32>
+// FOLDUNITDIM-SAME: %[[ARG1:.+]]: tensor<?x2x4xf32>
+// FOLDUNITDIM-DAG: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG1]]
+// FOLDUNITDIM-DAG: %[[INIT:.+]] = linalg.init_tensor [1, %{{.+}}, 1, 2, 1, 4]
// FOLDUNITDIM: linalg.generic
-// FOLDUNITDIM-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
+// FOLDUNITDIM-SAME: ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
+// FOLDUNITDIM-SAME: outs(%[[INIT]] : tensor<1x?x1x2x1x4xf32>)
More information about the Mlir-commits
mailing list