[Mlir-commits] [mlir] f4eb681 - [mlir][Linalg] Drop unit-trip loops of reductions only if other reduction loops exist.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 8 22:31:48 PDT 2021


Author: MaheshRavishankar
Date: 2021-04-08T22:31:29-07:00
New Revision: f4eb681dc37ae84e08579bf96cd2a6f58c44f260

URL: https://github.com/llvm/llvm-project/commit/f4eb681dc37ae84e08579bf96cd2a6f58c44f260
DIFF: https://github.com/llvm/llvm-project/commit/f4eb681dc37ae84e08579bf96cd2a6f58c44f260.diff

LOG: [mlir][Linalg] Drop unit-trip loops of reductions only if other reduction loops exist.

A recent change enabled dropping unit-trip loops of "reduction" iterator
type as well. This is fine as long as there is at least one other
"reduction" iterator in the operation. Without this, the initial value
(the value of `out`) is not read, which leads to a correctness issue.

Also fix a bug in the `fill` -> `tensor_reshape` folding: the `out`
operand of the original `fill` must be reshaped to produce the `out`
operand of the generated `fill` operation.

Differential Revision: https://reviews.llvm.org/D100145

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
    mlir/test/Dialect/Linalg/reshape_fusion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 786b9ec85dcf..344ffe977caf 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -31,7 +31,9 @@ def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
            "Only folds the one-trip loops from Linalg ops on tensors "
            "(for testing purposes only)">
   ];
-  let dependentDialects = ["linalg::LinalgDialect"];
+  let dependentDialects = [
+    "linalg::LinalgDialect", "AffineDialect", "memref::MemRefDialect"
+  ];
 }
 
 def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
@@ -43,7 +45,9 @@ def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
            "Allow fusing linalg.tensor_reshape ops that performs unit "
            "dimension collapsing">
   ];
-  let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
+  let dependentDialects = [
+    "AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
+  ];
 }
 
 def LinalgFoldReshapeOpsByLinearization :
@@ -51,7 +55,7 @@ def LinalgFoldReshapeOpsByLinearization :
   let summary = "Fold TensorReshapeOps with generic/indexed generic ops by "
                 "linearization";
   let constructor = "mlir::createFoldReshapeOpsByLinearizationPass()";
-  let dependentDialects = ["AffineDialect"];
+  let dependentDialects = ["AffineDialect", "memref::MemRefDialect"];
 }
 
 def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
@@ -64,7 +68,8 @@ def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
                "interchange vector",
                "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
   ];
-  let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
+  let dependentDialects = [
+    "AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"];
 }
 
 def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 76fa96b6346c..d8b512cdeea0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1681,9 +1681,10 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
     if (!oldFill)
       return failure();
 
-    auto newInit = rewriter.create<InitTensorOp>(
-        oldFill.getLoc(), reshapeOp.getResultType().getShape(),
-        reshapeOp.getResultType().getElementType());
+    Location loc = oldFill.getLoc();
+    auto newInit = rewriter.create<TensorReshapeOp>(
+        loc, reshapeOp.getResultType(), oldFill.output(),
+        reshapeOp.reassociation());
     rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, newInit, oldFill.value());
 
     return success();
@@ -1694,7 +1695,8 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
 void TensorReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
   results.add<CollapseReshapeOps<TensorReshapeOp>, FoldFillWithTensorReshape,
-              FoldReshapeWithConstant>(context);
+              FoldInitTensorWithTensorReshapeOp, FoldReshapeWithConstant>(
+      context);
 }
 
 LogicalResult TensorReshapeOp::reifyReturnTypeShapesPerResultDim(

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 7aefefd642ea..b3af82c80cee 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -190,15 +190,43 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
     SmallVector<int64_t, 4> dims;
     for (ShapedType shapedType : op.getShapedOperandTypes())
       dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
+
+    // Find all the reduction iterators. Those need some special consideration
+    // (see below).
+    auto getLoopDimsOfType =
+        [&](StringRef iteratorTypeName) -> SmallVector<unsigned, 4> {
+      SmallVector<AffineExpr> dimExprs;
+      getDimsOfType(op, iteratorTypeName, dimExprs);
+      return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) {
+        return expr.cast<AffineDimExpr>().getPosition();
+      }));
+    };
+    auto reductionDims = getLoopDimsOfType(getReductionIteratorTypeName());
+
     DenseSet<unsigned> unitDims;
+    SmallVector<unsigned, 4> unitDimsReductionLoops;
     ArrayAttr iteratorTypes = op.iterator_types();
     for (auto expr : enumerate(invertedMap.getResults())) {
       if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
-        if (dims[dimExpr.getPosition()] == 1 &&
-            iteratorTypes[expr.index()].dyn_cast<StringAttr>().getValue() ==
-                getParallelIteratorTypeName())
-          unitDims.insert(expr.index());
+        if (dims[dimExpr.getPosition()] == 1) {
+          if (isParallelIterator(iteratorTypes[expr.index()]))
+            unitDims.insert(expr.index());
+          else if (isReductionIterator(iteratorTypes[expr.index()]))
+            unitDimsReductionLoops.push_back(expr.index());
+        }
     }
+
+    // Reduction loops can be dropped if there is at least one other reduction
+    // loop that is not dropped. This accounts for the initial value read in the
+    // reduction loop.
+    if (!unitDimsReductionLoops.empty() && reductionDims.size() > 1) {
+      if (unitDimsReductionLoops.size() == reductionDims.size())
+        unitDims.insert(reductionDims.begin(), std::prev(reductionDims.end()));
+      else
+        unitDims.insert(unitDimsReductionLoops.begin(),
+                        unitDimsReductionLoops.end());
+    }
+
     if (unitDims.empty())
       return failure();
 
@@ -293,7 +321,6 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
   using OpRewritePattern<GenericOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOpTy op,
                                 PatternRewriter &rewriter) const override {
-    // TODO: support reductions.
     if (!op.hasTensorSemantics())
       return failure();
 
@@ -565,7 +592,6 @@ void mlir::linalg::populateFoldUnitExtentDimsPatterns(
                ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldReshapeOpWithUnitExtent>(context);
-  populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
 }
 
 namespace {

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 23c34111f826..c086b212c5c6 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -827,6 +827,23 @@ func @fold_fill_reshape() -> tensor<6x4xf32> {
 
 // -----
 
+//       CHECK: func @fold_fill_reshape_dynamic
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?x?xf32>
+func @fold_fill_reshape_dynamic(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32> {
+  %zero = constant 0.0 : f32
+  // CHECK: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]]
+  %0 = linalg.fill(%arg0, %zero) : tensor<?x?x?x?x?xf32>, f32 -> tensor<?x?x?x?x?xf32>
+  // CHECK: %[[RESULT:.+]] = linalg.fill(%[[RESHAPE]], %{{.+}})
+  %1 = linalg.tensor_reshape %0
+      [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+       affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
+      : tensor<?x?x?x?x?xf32> into tensor<?x?xf32>
+  // CHECK: return %[[RESULT]]
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
 #map0 = affine_map<(d0) -> (24, -d0 + 192)>
 #map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
 #map2 = affine_map<(d0) -> (16, -d0 + 192)>

diff  --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 2a6711018988..5c8866662d2b 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -456,3 +456,113 @@ func @no_fold_subtensor(
 //      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[SUBTENSOR]]
 // CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP4]], #[[MAP5]]]
 //      CHECK:   return %[[RESULT_RESHAPE]]
+
+// -----
+
+func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32> {
+  %cst = constant 1.000000e+00 : f32
+  %c3 = constant 3 : index
+  %0 = memref.dim %arg0, %c3 : tensor<1x?x1x?xf32>
+  %1 = linalg.init_tensor [1, %0] : tensor<1x?xf32>
+  %2 = linalg.fill(%1, %cst) : tensor<1x?xf32>, f32 -> tensor<1x?xf32>
+  %3 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                     affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    ins(%arg0 : tensor<1x?x1x?xf32>)
+    outs(%2 : tensor<1x?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
+    %4 = addf %arg1, %arg2 : f32
+    linalg.yield %4 : f32
+  } -> tensor<1x?xf32>
+  return %3 : tensor<1x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
+//      CHECK: func @unit_dim_for_reduction
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x1x?xf32>
+//  CHECK-DAG:   %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]] [#[[MAP0]], #[[MAP1]]]
+//      CHECK:   %[[INIT:.+]] = linalg.init_tensor [%{{.+}}] : tensor<?xf32>
+//      CHECK:   %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}})
+//      CHECK:   %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]]]
+// CHECK-SAME:     iterator_types = ["parallel", "reduction"]
+// CHECK-SAME:     ins(%[[RESHAPE]] : tensor<?x?xf32>)
+// CHECK-SAME:     outs(%[[FILL]] : tensor<?xf32>)
+//      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] [#[[MAP2]]]
+//      CHECK:   return %[[RESULT_RESHAPE]]
+
+// -----
+
+func @unit_dim_for_reduction_keep_one(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1xf32> {
+  %cst = constant 1.000000e+00 : f32
+  %c3 = constant 3 : index
+  %1 = linalg.init_tensor [1, 1] : tensor<1x1xf32>
+  %2 = linalg.fill(%1, %cst) : tensor<1x1xf32>, f32 -> tensor<1x1xf32>
+  %3 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                     affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    ins(%arg0 : tensor<1x?x1x1xf32>)
+    outs(%2 : tensor<1x1xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
+    %4 = addf %arg1, %arg2 : f32
+    linalg.yield %4 : f32
+  } -> tensor<1x1xf32>
+  return %3 : tensor<1x1xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
+//      CHECK: func @unit_dim_for_reduction_keep_one
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x1x1xf32>
+//  CHECK-DAG:   %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]] [#[[MAP0]], #[[MAP1]]]
+//      CHECK:   %[[INIT:.+]] = linalg.init_tensor [1] : tensor<1xf32>
+//      CHECK:   %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}})
+//      CHECK:   %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]]]
+// CHECK-SAME:     iterator_types = ["parallel", "reduction"]
+// CHECK-SAME:     ins(%[[RESHAPE]] : tensor<?x1xf32>)
+// CHECK-SAME:     outs(%[[FILL]] : tensor<1xf32>)
+//      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] [#[[MAP2]]]
+//      CHECK:   return %[[RESULT_RESHAPE]]
+
+// -----
+
+func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x1xf32> {
+  %cst = constant 1.000000e+00 : f32
+  %c2 = constant 2 : index
+  %0 = memref.dim %arg0, %c2 : tensor<?x1x?x1xf32>
+  %1 = linalg.init_tensor [%0, 1] : tensor<?x1xf32>
+  %2 = linalg.fill(%1, %cst) : tensor<?x1xf32>, f32 -> tensor<?x1xf32>
+  %3 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                     affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    ins(%arg0 : tensor<?x1x?x1xf32>)
+    outs(%2 : tensor<?x1xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
+    %4 = addf %arg1, %arg2 : f32
+    linalg.yield %4 : f32
+  } -> tensor<?x1xf32>
+  return %3 : tensor<?x1xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
+//      CHECK: func @unit_dim_for_reduction_inner
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x1x?x1xf32>
+//  CHECK-DAG:   %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]] [#[[MAP0]], #[[MAP1]]]
+//      CHECK:   %[[INIT:.+]] = linalg.init_tensor [%{{.+}}] : tensor<?xf32>
+//      CHECK:   %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}})
+//      CHECK:   %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]]]
+// CHECK-SAME:     iterator_types = ["parallel", "reduction"]
+// CHECK-SAME:     ins(%[[RESHAPE]] : tensor<?x?xf32>)
+// CHECK-SAME:     outs(%[[FILL]] : tensor<?xf32>)
+//      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] [#[[MAP2]]]
+//      CHECK:   return %[[RESULT_RESHAPE]]

diff  --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index d5dc176f1fdf..9c0fe41684ee 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -174,18 +174,16 @@ func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 //      CHECK: func @generic_op_reshape_consumer_static
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32>
-//      CHECK:   %[[T0:.+]] = linalg.init_tensor [264, 4]
-//      CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG0]]
+//      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
 // CHECK-SAME:     [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME:     tensor<264x4xf32> into tensor<8x33x4xf32>
-//      CHECK:   %[[T2:.+]] = linalg.tensor_reshape %[[T0]]
-// CHECK-SAME:     [#[[MAP0]], #[[MAP1]]]
-//      CHECK:   %[[T3:.+]] = linalg.generic
+//      CHECK:   %[[T1:.+]] = linalg.init_tensor [8, 33, 4]
+//      CHECK:   %[[T2:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel"]
-// CHECK-SAME:     ins(%[[T1]] : tensor<8x33x4xf32>)
-// CHECK-SAME:     outs(%[[T2]] : tensor<8x33x4xf32>)
-//      CHECK:   return %[[T3]] : tensor<8x33x4xf32>
+// CHECK-SAME:     ins(%[[T0]] : tensor<8x33x4xf32>)
+// CHECK-SAME:     outs(%[[T1]] : tensor<8x33x4xf32>)
+//      CHECK:   return %[[T2]] : tensor<8x33x4xf32>
 
 // -----
 
@@ -317,36 +315,31 @@ func @reshape_as_consumer_permutation
 //   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>
 //   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 //   CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
-//   CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>
-//   CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>
-//   CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
-//   CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
-//   CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
-//   CHECK-DAG: #[[MAP10:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
-//   CHECK-DAG: #[[MAP11:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
-//   CHECK-DAG: #[[MAP12:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 7 + d2 * 42)>
+//   CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+//   CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+//   CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+//   CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
+//   CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 7 + d2 * 42)>
 //       CHECK: func @reshape_as_consumer_permutation
 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<210x6x4xi32>
 //  CHECK-SAME:   %[[ARG1:.+]]: tensor<210x4xi32>
-//   CHECK-DAG:   %[[T0:.+]] = linalg.init_tensor [6, 4, 210]
 //   CHECK-DAG:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG0]]
 //  CHECK-SAME:     [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
 //   CHECK-DAG:   %[[T2:.+]] = linalg.tensor_reshape %[[ARG1]]
 //  CHECK-SAME:     [#[[MAP3]], #[[MAP4]]]
-//       CHECK:   %[[T3:.+]] = linalg.tensor_reshape %[[T0]]
-//  CHECK-SAME:     [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
+//   CHECK-DAG:   %[[T0:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7]
 //       CHECK:   %[[T4:.+]] = linalg.indexed_generic
-//  CHECK-SAME:     indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
+//  CHECK-SAME:     indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
 //  CHECK-SAME:     ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
-//  CHECK-SAME:     outs(%[[T3]] : tensor<2x3x4x5x6x7xi32>)
+//  CHECK-SAME:     outs(%[[T0]] : tensor<2x3x4x5x6x7xi32>)
 //       CHECK:   ^{{.+}}(
 //  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index,
 //  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index, %[[ARG5:[a-zA-Z0-9]+]]: index,
 //  CHECK-SAME:     %[[ARG6:[a-zA-Z0-9]+]]: index, %[[ARG7:[a-zA-Z0-9]+]]: index,
 //  CHECK-SAME:     %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32,
 //  CHECK-SAME:     %[[ARG10:[a-zA-Z0-9]+]]: i32)
-//   CHECK-DAG:       %[[T5:.+]] = affine.apply #[[MAP11]](%[[ARG3]], %[[ARG2]])
-//   CHECK-DAG:       %[[T6:.+]] = affine.apply #[[MAP12]](%[[ARG6]], %[[ARG5]], %[[ARG4]])
+//   CHECK-DAG:       %[[T5:.+]] = affine.apply #[[MAP8]](%[[ARG3]], %[[ARG2]])
+//   CHECK-DAG:       %[[T6:.+]] = affine.apply #[[MAP9]](%[[ARG6]], %[[ARG5]], %[[ARG4]])
 //   CHECK-DAG:       %[[T7:.+]] = addi %[[ARG8]], %[[ARG9]]
 //       CHECK:       %[[T8:.+]] = index_cast %[[T5]]
 //       CHECK:       %[[T9:.+]] = addi %[[T7]], %[[T8]]
@@ -535,8 +528,11 @@ func @unit_dim_reshape_expansion_full
 // CHECK-SAME:     ins(%{{.+}}, %{{.+}} : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
 
 //         FOLDUNITDIM: func @unit_dim_reshape_expansion_full
-//         FOLDUNITDIM:   linalg.init_tensor
-// FOLDUNITDIM-COUNT-2:   linalg.tensor_reshape
+//    FOLDUNITDIM-SAME:   %[[ARG0:.+]]: tensor<1x?x1x2x1x4xf32>
+//    FOLDUNITDIM-SAME:   %[[ARG1:.+]]: tensor<?x2x4xf32>
+//     FOLDUNITDIM-DAG:   %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG1]]
+//     FOLDUNITDIM-DAG:   %[[INIT:.+]] = linalg.init_tensor [1, %{{.+}}, 1, 2, 1, 4]
 //         FOLDUNITDIM:   linalg.generic
-//    FOLDUNITDIM-SAME:     ins(%{{.+}}, %{{.+}} : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
+//    FOLDUNITDIM-SAME:     ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
+//    FOLDUNITDIM-SAME:     outs(%[[INIT]] : tensor<1x?x1x2x1x4xf32>)
 


        


More information about the Mlir-commits mailing list