[Mlir-commits] [mlir] c4486cf - [mlir][Linalg] Fix reshape fusion to reshape the outs instead of creating new tensors.

llvmlistbot at llvm.org
Mon Jan 11 09:26:45 PST 2021


Author: MaheshRavishankar
Date: 2021-01-11T09:26:22-08:00
New Revision: c4486cfd556869a837911c7719fb6c36018bbd1f

URL: https://github.com/llvm/llvm-project/commit/c4486cfd556869a837911c7719fb6c36018bbd1f
DIFF: https://github.com/llvm/llvm-project/commit/c4486cfd556869a837911c7719fb6c36018bbd1f.diff

LOG: [mlir][Linalg] Fix reshape fusion to reshape the outs instead of creating new tensors.

When fusing tensor_reshape ops with a generic/indexed_generic op, new
linalg.init_tensor operations were created for the `outs` of the fused
op. While technically correct, it is better to just reshape the
original `outs` operands and rely on the init_tensor -> tensor_reshape
canonicalization to achieve the same effect.
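
For illustration, a minimal before/after sketch of the consumer-fusion case
(names, shapes, and reassociation maps are hypothetical, condensed from the
updated tests below):

  // Before: the fused op materialized a fresh tensor for its `outs`.
  %d2 = divi_unsigned %d1, %c20 : index
  %out = linalg.init_tensor [%d0, 4, %d2, 5] : tensor<?x4x?x5xf32>

  // After: the original `outs` operand is reshaped instead, and any
  // resulting init_tensor -> tensor_reshape pair is left to canonicalization.
  %out = linalg.tensor_reshape %arg0 [#map0, #map1]
      : tensor<?x?xf32> into tensor<?x4x?x5xf32>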

Differential Revision: https://reviews.llvm.org/D93774

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/test/Dialect/Linalg/reshape_fusion.mlir
    mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index a2e7d436eeb8..0ce86e403681 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -183,6 +183,14 @@ class Linalg_ReshapeLikeOp<string mnemonic, list<OpTrait> traits = []> :
       return llvm::to_vector<4>(llvm::map_range(reassociation(), [
       ](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
     }
+    SmallVector<ReassociationExprs, 4> getReassociationExprs() {
+      return llvm::to_vector<4>(
+          llvm::map_range(reassociation(),
+            [](Attribute a) {
+              return llvm::to_vector<2>(
+                  a.cast<AffineMapAttr>().getValue().getResults());
+            }));
+    }
   }];
   let assemblyFormat = [{
     $src $reassociation attr-dict `:` type($src) `into` type(results)
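
The new getReassociationExprs accessor mirrors getReassociationMaps but
returns the result expressions of each reassociation map directly, which is
the form taken by the TensorReshapeOp builder used below. A condensed sketch
of the call site this commit introduces in FusionOnTensors.cpp:

  // Reshape the producer's original output instead of materializing a new
  // init_tensor; the reassociation comes straight from the reshape op
  // being folded.
  Value output = rewriter.create<TensorReshapeOp>(
      loc, producer.getOutputs()[0], reshapeOp.getReassociationExprs());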

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 37062ac33e2b..833662d282b6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -566,45 +566,6 @@ static RankedTensorType getExpandedType(RankedTensorType originalType,
   return RankedTensorType::get(expandedShape, originalType.getElementType());
 }
 
-/// Get the value to use for the output in the expanded operation given the
-/// `indexingMap` for the output in the original op. Creates a
-/// `linalg.init_tensor` operation to materialize the tensor that carries the
-/// shape information. This is only used when the tensor_reshape is expanding
-/// and is a consumer. In such cases, the tensor_reshape op semantics guarantees
-/// that the shape of the output is computable from the shape of the input since
-/// at most one of the expanded dims can be dynamic.
-static Value getOutputValueForExpandedOp(OpBuilder &builder, Location loc,
-                                         AffineMap indexingMap, Value result,
-                                         const ExpansionInfo &expansionInfo) {
-  SmallVector<Value, 4> dynamicDims;
-  SmallVector<int64_t, 4> staticDims;
-  ShapedType resultType = result.getType().cast<ShapedType>();
-  ArrayRef<int64_t> origShape = resultType.getShape();
-  for (AffineExpr expr : indexingMap.getResults()) {
-    unsigned origDimPos = expr.cast<AffineDimExpr>().getPosition();
-    bool foundDynamic = false;
-    int64_t linearizedShape = 1;
-    for (int64_t extent : expansionInfo.getExpandedShapeOfDim(origDimPos)) {
-      if (ShapedType::isDynamic(extent)) {
-        assert(!foundDynamic &&
-               "Expanded dimensions of reshape can have only one dynamic dim");
-        staticDims.push_back(ShapedType::kDynamicSize);
-        foundDynamic = true;
-        continue;
-      }
-      staticDims.push_back(extent);
-      linearizedShape *= extent;
-    }
-    if (ShapedType::isDynamic(origShape[origDimPos])) {
-      Value origDim = builder.create<DimOp>(loc, result, origDimPos);
-      dynamicDims.push_back(builder.create<UnsignedDivIOp>(
-          loc, origDim, builder.create<ConstantIndexOp>(loc, linearizedShape)));
-    }
-  }
-  return builder.create<linalg::InitTensorOp>(loc, dynamicDims, staticDims,
-                                              resultType.getElementType());
-}
-
 /// Returns the reassociation maps to use in the `linalg.tensor_reshape`
 /// operation to convert the operands of the original operation to operands of
 /// the expanded operation. The same method is used to compute the
@@ -734,8 +695,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
   SmallVector<Value, 1> outputs;
   for (auto result : llvm::enumerate(linalgOp.getOutputs())) {
     AffineMap indexingMap = linalgOp.getOutputIndexingMap(result.index());
-    outputs.push_back(getOutputValueForExpandedOp(
-        rewriter, loc, indexingMap, result.value(), expansionInfo));
+    RankedTensorType expandedOutputType =
+        getExpandedType(result.value().getType().cast<RankedTensorType>(),
+                        indexingMap, expansionInfo);
+    if (expandedOutputType != result.value().getType()) {
+      SmallVector<ReassociationIndices, 4> reassociation =
+          getReassociationForExpansion(indexingMap, expansionInfo);
+      outputs.push_back(rewriter.create<TensorReshapeOp>(
+          linalgOp.getLoc(), expandedOutputType, result.value(),
+          reassociation));
+    }
   }
 
   // The iterator types of the expanded op are all parallel.
@@ -779,47 +748,6 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
   return resultVals;
 }
 
-static Value
-getOutputValueForLinearization(OpBuilder &builder, Location loc,
-                               Value origOutput,
-                               ArrayRef<AffineMap> reassociationMaps) {
-  SmallVector<Value, 4> dynamicDims;
-  SmallVector<int64_t, 4> staticDims;
-  auto shapedType = origOutput.getType().cast<ShapedType>();
-  ArrayRef<int64_t> origShape = shapedType.getShape();
-  for (auto map : reassociationMaps) {
-    Optional<Value> dynamicDim;
-    int64_t staticLinearizedShape = 1;
-    for (AffineDimExpr expr :
-         llvm::map_range(map.getResults(), [](AffineExpr e) {
-           return e.cast<AffineDimExpr>();
-         })) {
-      unsigned pos = expr.getPosition();
-      if (ShapedType::isDynamic(origShape[pos])) {
-        Value dim = builder.create<DimOp>(loc, origOutput, pos);
-        if (dynamicDim) {
-          dynamicDim = builder.create<MulIOp>(loc, dynamicDim.getValue(), dim);
-        } else {
-          dynamicDim = dim;
-        }
-      } else {
-        staticLinearizedShape *= origShape[pos];
-      }
-    }
-    if (dynamicDim) {
-      dynamicDim = builder.create<MulIOp>(
-          loc, dynamicDim.getValue(),
-          builder.create<ConstantIndexOp>(loc, staticLinearizedShape));
-      dynamicDims.push_back(dynamicDim.getValue());
-      staticDims.push_back(ShapedType::kDynamicSize);
-    } else {
-      staticDims.push_back(staticLinearizedShape);
-    }
-  }
-  return builder.create<InitTensorOp>(loc, dynamicDims, staticDims,
-                                      shapedType.getElementType());
-}
-
 namespace {
 
 /// Pattern to fold tensor_reshape op with its consumer by using the source of
@@ -973,7 +901,7 @@ struct FoldConsumerReshapeOpByLinearization
                                reshapeOp.getReassociationMaps());
     for (AffineExpr expr : modifiedMap.getResults()) {
       if (!expr.isPureAffine())
-        return reshapeOp.emitRemark("fused op indexing map is not affine");
+        return producer.emitRemark("fused op indexing map is not affine");
     }
     fusedIndexMaps.back() = modifiedMap;
 
@@ -983,9 +911,8 @@ struct FoldConsumerReshapeOpByLinearization
       return reshapeOp.emitRemark("fused op loop bound computation failed");
 
     Location loc = producer.getLoc();
-    Value output =
-        getOutputValueForLinearization(rewriter, loc, producer.getOutputs()[0],
-                                       reshapeOp.getReassociationMaps());
+    Value output = rewriter.create<TensorReshapeOp>(
+        loc, producer.getOutputs()[0], reshapeOp.getReassociationExprs());
     LinalgOp fusedOp = createLinalgOpOfSameType(
         producer, rewriter, loc, reshapeOp.getResultType(),
         /*inputs=*/producer.getInputs(),
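
With the fused op now taking a reshape of the original `outs`, a producer
whose output was a linalg.init_tensor leaves behind an init_tensor +
tensor_reshape pair that canonicalization folds away. A minimal sketch,
assuming the static shapes of the generic_op_reshape_consumer_static test
below:

  %0 = linalg.init_tensor [264, 4] : tensor<264x4xf32>
  %1 = linalg.tensor_reshape %0 [#map1, #map2]
      : tensor<264x4xf32> into tensor<8x33x4xf32>
  // folds to:
  %1 = linalg.init_tensor [8, 33, 4] : tensor<8x33x4xf32>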

diff  --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 92805218dde7..c8c3c12a3ea4 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -14,7 +14,7 @@ func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
      indexing_maps = [#map0, #map1, #map1],
      iterator_types = ["parallel", "parallel", "parallel"]}
        ins(%0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
-      outs(%0 : tensor<?x?x?xf32>) {
+       outs(%0 : tensor<?x?x?xf32>) {
     ^bb0(%arg3: f32, %arg4: f32, %s: f32):       // no predecessors
       %1 = mulf %arg3, %arg4 : f32
       linalg.yield %1 : f32
@@ -32,19 +32,12 @@ func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
 //      CHECK: func @generic_op_reshape_producer_fusion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-//  CHECK-DAG:   %[[C0:.+]] = constant 0 : index
-//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
-//  CHECK-DAG:   %[[C2:.+]] = constant 2 : index
-//  CHECK-DAG:   %[[C4:.+]] = constant 4 : index
 //      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
 // CHECK-SAME:     [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
 //      CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
 // CHECK-SAME:     [#[[MAP0]], #[[MAP3]], #[[MAP4]]]
-//  CHECK-DAG:   %[[D0:.+]] = dim %[[T0]], %[[C0]]
-//  CHECK-DAG:   %[[D1:.+]] = dim %[[T0]], %[[C1]]
-//  CHECK-DAG:   %[[D2:.+]] = dim %[[T0]], %[[C2]]
-//      CHECK:   %[[D3:.+]] = divi_unsigned %[[D0]], %[[C4]]
-//      CHECK:   %[[T2:.+]] = linalg.init_tensor [%[[D1]], %[[D2]], %[[D3]], 4]
+//      CHECK:   %[[T2:.+]] = linalg.tensor_reshape %[[T0]]
+// CHECK-SAME:     [#[[MAP0]], #[[MAP3]], #[[MAP4]]]
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP6]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
@@ -66,7 +59,7 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
      indexing_maps = [#map0, #map0, #map0],
      iterator_types = ["parallel", "parallel"]}
        ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%arg0 : tensor<?x?xf32>) {
+       outs(%arg0 : tensor<?x?xf32>) {
     ^bb0(%arg3: f32, %arg4: f32, %s: f32):       // no predecessors
       %1 = mulf %arg3, %arg4 : f32
       linalg.yield %1 : f32
@@ -83,19 +76,14 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
 //      CHECK: func @generic_op_reshape_consumer_fusion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//  CHECK-DAG:   %[[C0:.+]] = constant 0 : index
-//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
-//  CHECK-DAG:   %[[C20:.+]] = constant 20 : index
 //      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
 // CHECK-SAME:     [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x?x5xf32>
 //      CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
 // CHECK-SAME:     [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x?x5xf32>
-//  CHECK-DAG:   %[[D0:.+]] = dim %[[ARG0]], %[[C0]]
-//  CHECK-DAG:   %[[D1:.+]] = dim %[[ARG0]], %[[C1]]
-//      CHECK:   %[[D2:.+]] = divi_unsigned %[[D1]], %[[C20]]
-//      CHECK:   %[[T2:.+]] = linalg.init_tensor [%[[D0]], 4, %[[D2]], 5]
+//      CHECK:   %[[T2:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME:     [#[[MAP0]], #[[MAP1]]]
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
@@ -132,30 +120,25 @@ func @reshape_as_consumer_permutation
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>
 //  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 //  CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
-//  CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
-//  CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
-//  CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+//  CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>
+//  CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+//  CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+//  CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+//  CHECK-DAG: #[[MAP10:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
 //      CHECK: func @reshape_as_consumer_permutation
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//  CHECK-DAG:   %[[C0:.+]] = constant 0 : index
-//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
-//  CHECK-DAG:   %[[C2:.+]] = constant 2 : index
-//  CHECK-DAG:   %[[C12:.+]] = constant 12 : index
 //      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
 // CHECK-SAME:     [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
 // CHECK-SAME:     tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
 //      CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
 // CHECK-SAME:     [#[[MAP3]], #[[MAP4]]]
 // CHECK-SAME:     tensor<?x?xf32> into tensor<3x4x?x?xf32>
-//  CHECK-DAG:   %[[D0:.+]] = dim %[[ARG0]], %[[C0]]
-//      CHECK:   %[[D1:.+]] = divi_unsigned %[[D0]], %[[C2]]
-//  CHECK-DAG:   %[[D2:.+]] = dim %[[ARG0]], %[[C2]]
-//  CHECK-DAG:   %[[D3:.+]] = dim %[[ARG0]], %[[C1]]
-//  CHECK-DAG:   %[[D4:.+]] = divi_unsigned %[[D3]], %[[C12]]
-//      CHECK:   %[[T2:.+]] = linalg.init_tensor [%[[D1]], 2, %[[D2]], 3, 4, %[[D4]]]
+//      CHECK:   %[[T2:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME:     [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
 //      CHECK:   %[[T3:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
+// CHECK-SAME:     indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[T0]], %[[T1]] : tensor<3x4x?x?x2x?xf32>, tensor<3x4x?x?xf32>)
 // CHECK-SAME:     outs(%[[T2]] : tensor<?x2x?x3x4x?xf32>)
@@ -170,18 +153,19 @@ func @reshape_as_consumer_permutation
 func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
                                             -> tensor<8x33x4xf32> {
   %cst = constant dense<2.000000e+00> : tensor<264x4xf32>
-  %0 = linalg.generic {
+  %0 = linalg.init_tensor [264, 4] : tensor<264x4xf32>
+  %1 = linalg.generic {
      indexing_maps = [#map0, #map0, #map0],
      iterator_types = ["parallel", "parallel"]}
        ins(%arg0, %cst : tensor<264x4xf32>, tensor<264x4xf32>)
-      outs(%arg0 : tensor<264x4xf32>) {
+       outs(%0 : tensor<264x4xf32>) {
     ^bb0(%arg1: f32, %arg2: f32, %s: f32):  // no predecessors
       %2 = mulf %arg1, %arg2 : f32
       linalg.yield %2 : f32
     } -> tensor<264x4xf32>
-  %1 = linalg.tensor_reshape %0 [#map1, #map2] :
+  %2 = linalg.tensor_reshape %1 [#map1, #map2] :
     tensor<264x4xf32> into tensor<8x33x4xf32>
-  return %1 : tensor<8x33x4xf32>
+  return %2 : tensor<8x33x4xf32>
 }
 
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -189,51 +173,54 @@ func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 //      CHECK: func @generic_op_reshape_consumer_static
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32>
-//      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+//      CHECK:   %[[T0:.+]] = linalg.init_tensor [264, 4]
+//      CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG0]]
 // CHECK-SAME:     [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME:     tensor<264x4xf32> into tensor<8x33x4xf32>
-//      CHECK:   %[[T1:.+]] = linalg.init_tensor [8, 33, 4] : tensor<8x33x4xf32>
-//      CHECK:   %[[T2:.+]] = linalg.generic
+//      CHECK:   %[[T2:.+]] = linalg.tensor_reshape %[[T0]]
+// CHECK-SAME:     [#[[MAP0]], #[[MAP1]]]
+//      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel"]
-// CHECK-SAME:     ins(%[[T0]] : tensor<8x33x4xf32>)
-// CHECK-SAME:     outs(%[[T1]] : tensor<8x33x4xf32>)
-//      CHECK:   return %[[T2]] : tensor<8x33x4xf32>
+// CHECK-SAME:     ins(%[[T1]] : tensor<8x33x4xf32>)
+// CHECK-SAME:     outs(%[[T2]] : tensor<8x33x4xf32>)
+//      CHECK:   return %[[T3]] : tensor<8x33x4xf32>
 
 // -----
 
 func @scalar_reshape(
-  %arg0 : tensor<1x10xf32>, %arg1 : tensor<1xf32>, %shape : tensor<10xf32>)
-    -> tensor<1x10xf32>
+  %arg0 : tensor<1x10xf32>, %arg1 : tensor<1xf32>) -> tensor<1x10xf32>
 {
   %0 = linalg.tensor_reshape %arg1 [] : tensor<1xf32> into tensor<f32>
-  %1 = linalg.generic
+  %1 = linalg.init_tensor [10] : tensor<10xf32>
+  %2 = linalg.generic
     {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
      ins(%0 : tensor<f32>)
-    outs(%shape : tensor<10xf32>) {
+     outs(%1 : tensor<10xf32>) {
   ^bb0(%arg2: f32, %s: f32):  // no predecessors
     linalg.yield %arg2 : f32
   } -> tensor<10xf32>
-  %2 = linalg.tensor_reshape %1 [affine_map<(d0, d1) -> (d0, d1)>]
+  %3 = linalg.tensor_reshape %2 [affine_map<(d0, d1) -> (d0, d1)>]
     : tensor<10xf32> into tensor<1x10xf32>
-  return %2 : tensor<1x10xf32>
+  return %3 : tensor<1x10xf32>
 }
 
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> ()>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> ()>
 //      CHECK: func @scalar_reshape
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1xf32>
 //      CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG1]] []
 // CHECK-SAME:     tensor<1xf32> into tensor<f32>
-//      CHECK:   %[[T1:.+]] = linalg.init_tensor [1, 10] : tensor<1x10xf32>
-//      CHECK:   %[[T2:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+//      CHECK:   %[[T1:.+]] = linalg.init_tensor [10]
+//      CHECK:   %[[T2:.+]] = linalg.tensor_reshape %[[T1]] [#[[MAP0]]]
+//      CHECK:   %[[T3:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP1]], #[[MAP0]]]
 // CHECK-SAME:     iterator_types = ["parallel", "parallel"]
 // CHECK-SAME:     ins(%[[T0]] : tensor<f32>)
-// CHECK-SAME:     outs(%[[T1]] : tensor<1x10xf32>)
-//      CHECK:   return %[[T2]] : tensor<1x10xf32>
+// CHECK-SAME:     outs(%[[T2]] : tensor<1x10xf32>)
+//      CHECK:   return %[[T3]] : tensor<1x10xf32>
 
 // -----
 
@@ -331,15 +318,16 @@ func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
 // -----
 
 func @reshape_as_consumer_permutation
-  (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>, %shape : tensor<6x4x210xi32>)
+  (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
     -> tensor<2x3x4x5x6x7xi32> {
+  %shape = linalg.init_tensor [6, 4, 210] : tensor<6x4x210xi32>
   %c = linalg.indexed_generic {
          indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
                           affine_map<(d0, d1, d2) -> (d1, d2)>,
                           affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
          iterator_types = ["parallel", "parallel", "parallel"]}
           ins(%a, %b : tensor<210x6x4xi32>, tensor<210x4xi32>)
-         outs(%shape : tensor<6x4x210xi32>) {
+          outs(%shape : tensor<6x4x210xi32>) {
        ^bb0(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : i32, %arg4: i32, %s: i32):
          %1 = addi %arg3, %arg4 : i32
          %2 = index_cast %arg0 : index to i32
@@ -364,38 +352,43 @@ func @reshape_as_consumer_permutation
 //   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>
 //   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 //   CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
-//   CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1) -> (d0 * 3 + d1)>
-//   CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2) -> (d0 * 42 + d1 * 7 + d2)>
-//   CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
-//   CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
-//   CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+//   CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>
+//   CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>
+//   CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+//   CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+//   CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+//   CHECK-DAG: #[[MAP10:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+//   CHECK-DAG: #[[MAP11:.+]] = affine_map<(d0, d1) -> (d0 * 3 + d1)>
+//   CHECK-DAG: #[[MAP12:.+]] = affine_map<(d0, d1, d2) -> (d0 * 42 + d1 * 7 + d2)>
 //       CHECK: func @reshape_as_consumer_permutation
 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<210x6x4xi32>
 //  CHECK-SAME:   %[[ARG1:.+]]: tensor<210x4xi32>
-//   CHECK-DAG:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+//   CHECK-DAG:   %[[T0:.+]] = linalg.init_tensor [6, 4, 210]
+//   CHECK-DAG:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG0]]
 //  CHECK-SAME:     [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-//   CHECK-DAG:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
+//   CHECK-DAG:   %[[T2:.+]] = linalg.tensor_reshape %[[ARG1]]
 //  CHECK-SAME:     [#[[MAP3]], #[[MAP4]]]
-//       CHECK:   %[[T2:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7]
-//       CHECK:   %[[T3:.+]] = linalg.indexed_generic
-//  CHECK-SAME:     indexing_maps = [#[[MAP7]], #[[MAP8]], #[[MAP9]]]
-//  CHECK-SAME:     ins(%[[T0]], %[[T1]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
-//  CHECK-SAME:     outs(%[[T2]] : tensor<2x3x4x5x6x7xi32>)
+//       CHECK:   %[[T3:.+]] = linalg.tensor_reshape %[[T0]]
+//  CHECK-SAME:     [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
+//       CHECK:   %[[T4:.+]] = linalg.indexed_generic
+//  CHECK-SAME:     indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
+//  CHECK-SAME:     ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
+//  CHECK-SAME:     outs(%[[T3]] : tensor<2x3x4x5x6x7xi32>)
 //       CHECK:   ^{{.+}}(
 //  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index,
 //  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index, %[[ARG5:[a-zA-Z0-9]+]]: index,
 //  CHECK-SAME:     %[[ARG6:[a-zA-Z0-9]+]]: index, %[[ARG7:[a-zA-Z0-9]+]]: index,
 //  CHECK-SAME:     %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32,
 //  CHECK-SAME:     %[[ARG10:[a-zA-Z0-9]+]]: i32)
-//   CHECK-DAG:       %[[T4:.+]] = affine.apply #[[MAP5]](%[[ARG2]], %[[ARG3]])
-//   CHECK-DAG:       %[[T5:.+]] = affine.apply #[[MAP6]](%[[ARG4]], %[[ARG5]], %[[ARG6]])
-//   CHECK-DAG:       %[[T6:.+]] = addi %[[ARG8]], %[[ARG9]]
-//       CHECK:       %[[T7:.+]] = index_cast %[[T4]]
-//       CHECK:       %[[T8:.+]] = addi %[[T6]], %[[T7]]
-//       CHECK:       %[[T9:.+]] = index_cast %[[T5]]
-//       CHECK:       %[[T10:.+]] = addi %[[T8]], %[[T9]]
-//       CHECK:       %[[T11:.+]] = index_cast %[[ARG7]]
-//       CHECK:       %[[T12:.+]] = addi %[[T10]], %[[T11]]
+//   CHECK-DAG:       %[[T5:.+]] = affine.apply #[[MAP11]](%[[ARG2]], %[[ARG3]])
+//   CHECK-DAG:       %[[T6:.+]] = affine.apply #[[MAP12]](%[[ARG4]], %[[ARG5]], %[[ARG6]])
+//   CHECK-DAG:       %[[T7:.+]] = addi %[[ARG8]], %[[ARG9]]
+//       CHECK:       %[[T8:.+]] = index_cast %[[T5]]
+//       CHECK:       %[[T9:.+]] = addi %[[T7]], %[[T8]]
+//       CHECK:       %[[T10:.+]] = index_cast %[[T6]]
+//       CHECK:       %[[T11:.+]] = addi %[[T9]], %[[T10]]
+//       CHECK:       %[[T12:.+]] = index_cast %[[ARG7]]
+//       CHECK:       %[[T13:.+]] = addi %[[T11]], %[[T12]]
 
 // -----
 
@@ -466,7 +459,7 @@ func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
      indexing_maps = [#map0, #map0, #map1],
      iterator_types = ["parallel", "parallel"]}
        ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%arg0 : tensor<?x?xf32>) {
+       outs(%arg0 : tensor<?x?xf32>) {
     ^bb0(%arg3: f32, %arg4: f32, %s: f32):       // no predecessors
       %1 = mulf %arg3, %arg4 : f32
       linalg.yield %1 : f32
@@ -479,8 +472,10 @@ func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
 
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
-//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+//  CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//  CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
 //      CHECK: func @generic_op_reshape_consumer_fusion_projected
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -490,8 +485,11 @@ func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
 //      CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
 // CHECK-SAME:     [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x5x?xf32>
-//      CHECK:   %[[T2:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]]]
+//      CHECK:   %[[T2:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME:     [#[[MAP2]], #[[MAP3]]]
+//      CHECK:   %[[T3:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP4]], #[[MAP4]], #[[MAP5]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:     ins(%[[T0]], %[[T1]] : tensor<?x4x5x?xf32>, tensor<?x4x5x?xf32>)
-//      CHECK:   return %[[T2]] : tensor<?x?x4x5xf32>
+// CHECK-SAME:     outs(%[[T2]] : tensor<?x?x4x5xf32>)
+//      CHECK:   return %[[T3]] : tensor<?x?x4x5xf32>
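
Across these tests, the constant/dim/divi_unsigned/init_tensor sequences
previously expected by the CHECK lines collapse into a single
linalg.tensor_reshape of the original `outs` operand; tests that used to
take an extra %shape argument now build it with linalg.init_tensor so the
reshaped output still has a local defining op.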

diff  --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
index aff1447a63c7..cb57f00372c7 100644
--- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization -verify-diagnostics %s | FileCheck %s
 
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xf32>,
@@ -21,14 +21,19 @@ func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xf32>,
   return %1 : tensor<?x?x4x?xf32>
 }
 
-//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+//   CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+//   CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 //       CHECK: func @generic_op_reshape_producer_fusion
 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
+//       CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+//  CHECK-SAME:   [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
 //       CHECK: linalg.generic
-//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]]
+//  CHECK-SAME:   indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP4]]]
 //  CHECK-SAME:   ins(%[[ARG0]], %{{.+}} : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>)
-//  CHECK-SAME:   outs(%{{.+}} : tensor<?x?x4x?xf32>)
+//  CHECK-SAME:   outs(%[[T0]] : tensor<?x?x4x?xf32>)
 
 // -----
 
@@ -52,47 +57,17 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xf32>,
   return %1 : tensor<?x?xf32>
 }
 
-
-//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+//   CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//   CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
 //       CHECK: func @generic_op_reshape_consumer_fusion
 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x5xf32>
-//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
-//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
-//   CHECK-DAG:   %[[C20:.+]] = constant 20 : index
-//       CHECK:   %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
-//       CHECK:   %[[T1:.+]] = dim %[[ARG0]], %[[C1]]
-//       CHECK:   %[[T2:.+]] = muli %[[T1]], %[[C20]]
-//       CHECK:   %[[T3:.+]] = linalg.init_tensor [%[[T0]], %[[T2]]]
+//       CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+//  CHECK-SAME:     [#[[MAP0]], #[[MAP1]]]
 //       CHECK:   linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]]
-//  CHECK-SAME:     outs(%[[T3]] : tensor<?x?xf32>)
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
-                                           %arg1 : tensor<?x?x?x5xf32>) ->
-                                           tensor<?x?xf32>
-{
-  %0 = linalg.generic {
-     indexing_maps = [#map0, #map0, #map0],
-     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-      ins(%arg0, %arg1 : tensor<?x?x?x5xf32>, tensor<?x?x?x5xf32>)
-      outs(%arg0 : tensor<?x?x?x5xf32>) {
-    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):       // no predecessors
-      %1 = mulf %arg3, %arg4 : f32
-      linalg.yield %1 : f32
-  } -> tensor<?x?x?x5xf32>
-  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
-                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
-    tensor<?x?x?x5xf32> into tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-
-// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion
-//       CHECK:   %[[T0:.+]] = linalg.generic
-//       CHECK:   linalg.tensor_reshape %[[T0]]
+//  CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]]]
+//  CHECK-SAME:     outs(%[[T0]] : tensor<?x?xf32>)
 
 // -----
 
@@ -116,13 +91,19 @@ func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
   return %1 : tensor<?x?x4x?xi32>
 }
 
-//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+//   CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 //       CHECK: func @indexed_generic_op_reshape_producer_fusion
 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xi32>
+//       CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+//  CHECK-SAME:     [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
 //       CHECK:   linalg.indexed_generic
-//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+//  CHECK-SAME:     indexing_maps = [#[[MAP3]], #[[MAP4]]]
 //  CHECK-SAME:     ins(%[[ARG0]] : tensor<?x?x?xi32>)
+//  CHECK-SAME:     outs(%[[T0]] : tensor<?x?x4x?xi32>)
 
 // -----
 
@@ -144,20 +125,17 @@ func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
   return %1 : tensor<?x?xi32>
 }
 
-//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
-// CHECK-LABEL: func @indexed_generic_op_reshape_consumer_fusion
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
+//       CHECK: func @indexed_generic_op_reshape_consumer_fusion
 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x5xi32>
-//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
-//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
-//   CHECK-DAG:   %[[C20:.+]] = constant 20 : index
-//       CHECK:   %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
-//       CHECK:   %[[T1:.+]] = dim %[[ARG0]], %[[C1]]
-//       CHECK:   %[[T2:.+]] = muli %[[T1]], %[[C20]]
-//       CHECK:   %[[T3:.+]] = linalg.init_tensor [%[[T0]], %[[T2]]]
+//       CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+//  CHECK-SAME:     [#[[MAP0]], #[[MAP1]]]
 //       CHECK:   linalg.indexed_generic
-//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-//  CHECK-SAME:     outs(%[[T3]] : tensor<?x?xi32>)
+//  CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]]]
+//  CHECK-SAME:     outs(%[[T0]] : tensor<?x?xi32>)
 //   CHECK-NOT:   linalg.tensor_reshape
 
 // -----
@@ -179,12 +157,12 @@ func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf3
     return %2 : tensor<3x7x5xf32>
 }
 
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: func @generic_op_021_permultation_reshape_producer_fusion
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//       CHECK: func @generic_op_021_permultation_reshape_producer_fusion
 //   CHECK-NOT:   linalg.tensor_reshape
 //       CHECK:   linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
 
 // -----
 
@@ -210,7 +188,7 @@ func @generic_op_120_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf3
 //       CHECK: func @generic_op_120_permultation_reshape_producer_fusion
 //   CHECK-NOT:   linalg.tensor_reshape
 //       CHECK:   linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
 
 // -----
 
@@ -237,7 +215,7 @@ func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf3
 //       CHECK: func @generic_op_102_permultation_reshape_producer_fusion
 //   CHECK-NOT:   linalg.tensor_reshape
 //       CHECK:   linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
 
 // -----
 
@@ -258,10 +236,39 @@ func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf
   return %2 : tensor<5x21xf32>
 }
 
-
-//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
 //       CHECK: func @generic_op_102_permultation_reshape_consumer_fusion
-//   CHECK-NOT:   linalg.tensor_reshape
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<3x5x7xf32>
+//       CHECK:   %[[T0:.+]] = linalg.init_tensor [5, 3, 7]
+//       CHECK:   %[[T1:.+]] = linalg.tensor_reshape %[[T0]]
+//  CHECK-SAME:     [#[[MAP0]], #[[MAP1]]]
 //       CHECK:   linalg.generic
-//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+//  CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]]]
+//  CHECK-SAME:     ins(%[[ARG0]] : tensor<3x5x7xf32>)
+//  CHECK-SAME:     outs(%[[T1]] : tensor<5x21xf32>)
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
+                                           %arg1 : tensor<?x?x?x5xf32>) ->
+                                           tensor<?x?xf32>
+{
+  // expected-remark @+1 {{fused op indexing map is not affine}}
+  %0 = linalg.generic {
+     indexing_maps = [#map0, #map0, #map0],
+     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<?x?x?x5xf32>, tensor<?x?x?x5xf32>)
+      outs(%arg0 : tensor<?x?x?x5xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):       // no predecessors
+      %1 = mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  } -> tensor<?x?x?x5xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?x?x5xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
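
The nofusion case moved to the end of the file and now runs under
-verify-diagnostics: because the remark is emitted on the producer op (see
the FoldConsumerReshapeOpByLinearization change above), the
expected-remark @+1 annotation anchors on the linalg.generic, the pattern
fails to apply, and the IR is left unchanged.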


        

