[Mlir-commits] [mlir] c4486cf - [mlir][Linalg] Fix reshape fusion to reshape the outs instead of creating new tensors.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 11 09:26:45 PST 2021
Author: MaheshRavishankar
Date: 2021-01-11T09:26:22-08:00
New Revision: c4486cfd556869a837911c7719fb6c36018bbd1f
URL: https://github.com/llvm/llvm-project/commit/c4486cfd556869a837911c7719fb6c36018bbd1f
DIFF: https://github.com/llvm/llvm-project/commit/c4486cfd556869a837911c7719fb6c36018bbd1f.diff
LOG: [mlir][Linalg] Fix reshape fusion to reshape the outs instead of creating new tensors.
When fusing tensor_reshape ops with generic/indexed_generic ops, new
linalg.init_tensor operations were created for the `outs` of the fused
op. While technically correct, it is better to just reshape the
original `outs` operands and rely on canonicalization of init_tensor
-> tensor_reshape to achieve the same effect.
Differential Revision: https://reviews.llvm.org/D93774
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/test/Dialect/Linalg/reshape_fusion.mlir
mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index a2e7d436eeb8..0ce86e403681 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -183,6 +183,14 @@ class Linalg_ReshapeLikeOp<string mnemonic, list<OpTrait> traits = []> :
return llvm::to_vector<4>(llvm::map_range(reassociation(), [
](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
}
+ SmallVector<ReassociationExprs, 4> getReassociationExprs() {
+ return
+ llvm::to_vector<4>(llvm::map_range(reassociation(),
+ [](Attribute a) {
+ return llvm::to_vector<2>(
+ a.cast<AffineMapAttr>().getValue().getResults());
+ }));
+ }
}];
let assemblyFormat = [{
$src $reassociation attr-dict `:` type($src) `into` type(results)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 37062ac33e2b..833662d282b6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -566,45 +566,6 @@ static RankedTensorType getExpandedType(RankedTensorType originalType,
return RankedTensorType::get(expandedShape, originalType.getElementType());
}
-/// Get the value to use for the output in the expanded operation given the
-/// `indexingMap` for the output in the original op. Creates an
-/// `linalg.init_tensor` operation to materialize the tensor that carries the
-/// shape information. This is only used when the tensor_reshape is expanding
-/// and is a consumer. In such cases, the tensor_reshape op semantics gaurantees
-/// that the shape of the output is computable from the shape of the input since
-/// at most one of the expanded dims can be dynamic.
-static Value getOutputValueForExpandedOp(OpBuilder &builder, Location loc,
- AffineMap indexingMap, Value result,
- const ExpansionInfo &expansionInfo) {
- SmallVector<Value, 4> dynamicDims;
- SmallVector<int64_t, 4> staticDims;
- ShapedType resultType = result.getType().cast<ShapedType>();
- ArrayRef<int64_t> origShape = resultType.getShape();
- for (AffineExpr expr : indexingMap.getResults()) {
- unsigned origDimPos = expr.cast<AffineDimExpr>().getPosition();
- bool foundDynamic = false;
- int64_t linearizedShape = 1;
- for (int64_t extent : expansionInfo.getExpandedShapeOfDim(origDimPos)) {
- if (ShapedType::isDynamic(extent)) {
- assert(!foundDynamic &&
- "Expanded dimensions of reshape can have only one dynamic dim");
- staticDims.push_back(ShapedType::kDynamicSize);
- foundDynamic = true;
- continue;
- }
- staticDims.push_back(extent);
- linearizedShape *= extent;
- }
- if (ShapedType::isDynamic(origShape[origDimPos])) {
- Value origDim = builder.create<DimOp>(loc, result, origDimPos);
- dynamicDims.push_back(builder.create<UnsignedDivIOp>(
- loc, origDim, builder.create<ConstantIndexOp>(loc, linearizedShape)));
- }
- }
- return builder.create<linalg::InitTensorOp>(loc, dynamicDims, staticDims,
- resultType.getElementType());
-}
-
/// Returns the reassociation maps to use in the `linalg.tensor_reshape`
/// operation to convert the operands of the origial operation to operands of
/// the expanded operation. The same method is used to compute the
@@ -734,8 +695,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
SmallVector<Value, 1> outputs;
for (auto result : llvm::enumerate(linalgOp.getOutputs())) {
AffineMap indexingMap = linalgOp.getOutputIndexingMap(result.index());
- outputs.push_back(getOutputValueForExpandedOp(
- rewriter, loc, indexingMap, result.value(), expansionInfo));
+ RankedTensorType expandedOutputType =
+ getExpandedType(result.value().getType().cast<RankedTensorType>(),
+ indexingMap, expansionInfo);
+ if (expandedOutputType != result.value().getType()) {
+ SmallVector<ReassociationIndices, 4> reassociation =
+ getReassociationForExpansion(indexingMap, expansionInfo);
+ outputs.push_back(rewriter.create<TensorReshapeOp>(
+ linalgOp.getLoc(), expandedOutputType, result.value(),
+ reassociation));
+ }
}
// The iterator types of the expanded op are all parallel.
@@ -779,47 +748,6 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
return resultVals;
}
-static Value
-getOutputValueForLinearization(OpBuilder &builder, Location loc,
- Value origOutput,
- ArrayRef<AffineMap> reassociationMaps) {
- SmallVector<Value, 4> dynamicDims;
- SmallVector<int64_t, 4> staticDims;
- auto shapedType = origOutput.getType().cast<ShapedType>();
- ArrayRef<int64_t> origShape = shapedType.getShape();
- for (auto map : reassociationMaps) {
- Optional<Value> dynamicDim;
- int64_t staticLinearizedShape = 1;
- for (AffineDimExpr expr :
- llvm::map_range(map.getResults(), [](AffineExpr e) {
- return e.cast<AffineDimExpr>();
- })) {
- unsigned pos = expr.getPosition();
- if (ShapedType::isDynamic(origShape[pos])) {
- Value dim = builder.create<DimOp>(loc, origOutput, pos);
- if (dynamicDim) {
- dynamicDim = builder.create<MulIOp>(loc, dynamicDim.getValue(), dim);
- } else {
- dynamicDim = dim;
- }
- } else {
- staticLinearizedShape *= origShape[pos];
- }
- }
- if (dynamicDim) {
- dynamicDim = builder.create<MulIOp>(
- loc, dynamicDim.getValue(),
- builder.create<ConstantIndexOp>(loc, staticLinearizedShape));
- dynamicDims.push_back(dynamicDim.getValue());
- staticDims.push_back(ShapedType::kDynamicSize);
- } else {
- staticDims.push_back(staticLinearizedShape);
- }
- }
- return builder.create<InitTensorOp>(loc, dynamicDims, staticDims,
- shapedType.getElementType());
-}
-
namespace {
/// Pattern to fold tensor_reshape op with its consumer by using the source of
@@ -973,7 +901,7 @@ struct FoldConsumerReshapeOpByLinearization
reshapeOp.getReassociationMaps());
for (AffineExpr expr : modifiedMap.getResults()) {
if (!expr.isPureAffine())
- return reshapeOp.emitRemark("fused op indexing map is not affine");
+ return producer.emitRemark("fused op indexing map is not affine");
}
fusedIndexMaps.back() = modifiedMap;
@@ -983,9 +911,8 @@ struct FoldConsumerReshapeOpByLinearization
return reshapeOp.emitRemark("fused op loop bound computation failed");
Location loc = producer.getLoc();
- Value output =
- getOutputValueForLinearization(rewriter, loc, producer.getOutputs()[0],
- reshapeOp.getReassociationMaps());
+ Value output = rewriter.create<TensorReshapeOp>(
+ loc, producer.getOutputs()[0], reshapeOp.getReassociationExprs());
LinalgOp fusedOp = createLinalgOpOfSameType(
producer, rewriter, loc, reshapeOp.getResultType(),
/*inputs=*/producer.getInputs(),
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 92805218dde7..c8c3c12a3ea4 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -14,7 +14,7 @@ func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
indexing_maps = [#map0, #map1, #map1],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
- outs(%0 : tensor<?x?x?xf32>) {
+ outs(%0 : tensor<?x?x?xf32>) {
^bb0(%arg3: f32, %arg4: f32, %s: f32): // no predecessors
%1 = mulf %arg3, %arg4 : f32
linalg.yield %1 : f32
@@ -32,19 +32,12 @@ func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
// CHECK: func @generic_op_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
// CHECK-SAME: [#[[MAP0]], #[[MAP3]], #[[MAP4]]]
-// CHECK-DAG: %[[D0:.+]] = dim %[[T0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = dim %[[T0]], %[[C1]]
-// CHECK-DAG: %[[D2:.+]] = dim %[[T0]], %[[C2]]
-// CHECK: %[[D3:.+]] = divi_unsigned %[[D0]], %[[C4]]
-// CHECK: %[[T2:.+]] = linalg.init_tensor [%[[D1]], %[[D2]], %[[D3]], 4]
+// CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[T0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP3]], #[[MAP4]]]
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP6]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -66,7 +59,7 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
indexing_maps = [#map0, #map0, #map0],
iterator_types = ["parallel", "parallel"]}
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg0 : tensor<?x?xf32>) {
+ outs(%arg0 : tensor<?x?xf32>) {
^bb0(%arg3: f32, %arg4: f32, %s: f32): // no predecessors
%1 = mulf %arg3, %arg4 : f32
linalg.yield %1 : f32
@@ -83,19 +76,14 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
// CHECK: func @generic_op_reshape_consumer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C20:.+]] = constant 20 : index
// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x?x5xf32>
// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x?x5xf32>
-// CHECK-DAG: %[[D0:.+]] = dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = dim %[[ARG0]], %[[C1]]
-// CHECK: %[[D2:.+]] = divi_unsigned %[[D1]], %[[C20]]
-// CHECK: %[[T2:.+]] = linalg.init_tensor [%[[D0]], 4, %[[D2]], 5]
+// CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -132,30 +120,25 @@ func @reshape_as_consumer_permutation
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
-// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
-// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>
+// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>
+// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP10:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
// CHECK: func @reshape_as_consumer_permutation
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK-DAG: %[[C12:.+]] = constant 12 : index
// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
// CHECK-SAME: [#[[MAP3]], #[[MAP4]]]
// CHECK-SAME: tensor<?x?xf32> into tensor<3x4x?x?xf32>
-// CHECK-DAG: %[[D0:.+]] = dim %[[ARG0]], %[[C0]]
-// CHECK: %[[D1:.+]] = divi_unsigned %[[D0]], %[[C2]]
-// CHECK-DAG: %[[D2:.+]] = dim %[[ARG0]], %[[C2]]
-// CHECK-DAG: %[[D3:.+]] = dim %[[ARG0]], %[[C1]]
-// CHECK-DAG: %[[D4:.+]] = divi_unsigned %[[D3]], %[[C12]]
-// CHECK: %[[T2:.+]] = linalg.init_tensor [%[[D1]], 2, %[[D2]], 3, 4, %[[D4]]]
+// CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
// CHECK: %[[T3:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
+// CHECK-SAME: indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<3x4x?x?x2x?xf32>, tensor<3x4x?x?xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<?x2x?x3x4x?xf32>)
@@ -170,18 +153,19 @@ func @reshape_as_consumer_permutation
func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
-> tensor<8x33x4xf32> {
%cst = constant dense<2.000000e+00> : tensor<264x4xf32>
- %0 = linalg.generic {
+ %0 = linalg.init_tensor [264, 4] : tensor<264x4xf32>
+ %1 = linalg.generic {
indexing_maps = [#map0, #map0, #map0],
iterator_types = ["parallel", "parallel"]}
ins(%arg0, %cst : tensor<264x4xf32>, tensor<264x4xf32>)
- outs(%arg0 : tensor<264x4xf32>) {
+ outs(%0 : tensor<264x4xf32>) {
^bb0(%arg1: f32, %arg2: f32, %s: f32): // no predecessors
%2 = mulf %arg1, %arg2 : f32
linalg.yield %2 : f32
} -> tensor<264x4xf32>
- %1 = linalg.tensor_reshape %0 [#map1, #map2] :
+ %2 = linalg.tensor_reshape %1 [#map1, #map2] :
tensor<264x4xf32> into tensor<8x33x4xf32>
- return %1 : tensor<8x33x4xf32>
+ return %2 : tensor<8x33x4xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -189,51 +173,54 @@ func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @generic_op_reshape_consumer_static
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32>
-// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK: %[[T0:.+]] = linalg.init_tensor [264, 4]
+// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG0]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: tensor<264x4xf32> into tensor<8x33x4xf32>
-// CHECK: %[[T1:.+]] = linalg.init_tensor [8, 33, 4] : tensor<8x33x4xf32>
-// CHECK: %[[T2:.+]] = linalg.generic
+// CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[T0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
+// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel"]
-// CHECK-SAME: ins(%[[T0]] : tensor<8x33x4xf32>)
-// CHECK-SAME: outs(%[[T1]] : tensor<8x33x4xf32>)
-// CHECK: return %[[T2]] : tensor<8x33x4xf32>
+// CHECK-SAME: ins(%[[T1]] : tensor<8x33x4xf32>)
+// CHECK-SAME: outs(%[[T2]] : tensor<8x33x4xf32>)
+// CHECK: return %[[T3]] : tensor<8x33x4xf32>
// -----
func @scalar_reshape(
- %arg0 : tensor<1x10xf32>, %arg1 : tensor<1xf32>, %shape : tensor<10xf32>)
- -> tensor<1x10xf32>
+ %arg0 : tensor<1x10xf32>, %arg1 : tensor<1xf32>) -> tensor<1x10xf32>
{
%0 = linalg.tensor_reshape %arg1 [] : tensor<1xf32> into tensor<f32>
- %1 = linalg.generic
+ %1 = linalg.init_tensor [10] : tensor<10xf32>
+ %2 = linalg.generic
{indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%0 : tensor<f32>)
- outs(%shape : tensor<10xf32>) {
+ outs(%1 : tensor<10xf32>) {
^bb0(%arg2: f32, %s: f32): // no predecessors
linalg.yield %arg2 : f32
} -> tensor<10xf32>
- %2 = linalg.tensor_reshape %1 [affine_map<(d0, d1) -> (d0, d1)>]
+ %3 = linalg.tensor_reshape %2 [affine_map<(d0, d1) -> (d0, d1)>]
: tensor<10xf32> into tensor<1x10xf32>
- return %2 : tensor<1x10xf32>
+ return %3 : tensor<1x10xf32>
}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> ()>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> ()>
// CHECK: func @scalar_reshape
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1xf32>
// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG1]] []
// CHECK-SAME: tensor<1xf32> into tensor<f32>
-// CHECK: %[[T1:.+]] = linalg.init_tensor [1, 10] : tensor<1x10xf32>
-// CHECK: %[[T2:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK: %[[T1:.+]] = linalg.init_tensor [10]
+// CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[T1]] [#[[MAP0]]]
+// CHECK: %[[T3:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]] : tensor<f32>)
-// CHECK-SAME: outs(%[[T1]] : tensor<1x10xf32>)
-// CHECK: return %[[T2]] : tensor<1x10xf32>
+// CHECK-SAME: outs(%[[T2]] : tensor<1x10xf32>)
+// CHECK: return %[[T3]] : tensor<1x10xf32>
// -----
@@ -331,15 +318,16 @@ func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
// -----
func @reshape_as_consumer_permutation
- (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>, %shape : tensor<6x4x210xi32>)
+ (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
-> tensor<2x3x4x5x6x7xi32> {
+ %shape = linalg.init_tensor [6, 4, 210] : tensor<6x4x210xi32>
%c = linalg.indexed_generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
affine_map<(d0, d1, d2) -> (d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%a, %b : tensor<210x6x4xi32>, tensor<210x4xi32>)
- outs(%shape : tensor<6x4x210xi32>) {
+ outs(%shape : tensor<6x4x210xi32>) {
^bb0(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : i32, %arg4: i32, %s: i32):
%1 = addi %arg3, %arg4 : i32
%2 = index_cast %arg0 : index to i32
@@ -364,38 +352,43 @@ func @reshape_as_consumer_permutation
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1) -> (d0 * 3 + d1)>
-// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2) -> (d0 * 42 + d1 * 7 + d2)>
-// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
-// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
-// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>
+// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>
+// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP10:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+// CHECK-DAG: #[[MAP11:.+]] = affine_map<(d0, d1) -> (d0 * 3 + d1)>
+// CHECK-DAG: #[[MAP12:.+]] = affine_map<(d0, d1, d2) -> (d0 * 42 + d1 * 7 + d2)>
// CHECK: func @reshape_as_consumer_permutation
// CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32>
-// CHECK-DAG: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [6, 4, 210]
+// CHECK-DAG: %[[T1:.+]] = linalg.tensor_reshape %[[ARG0]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-// CHECK-DAG: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
+// CHECK-DAG: %[[T2:.+]] = linalg.tensor_reshape %[[ARG1]]
// CHECK-SAME: [#[[MAP3]], #[[MAP4]]]
-// CHECK: %[[T2:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7]
-// CHECK: %[[T3:.+]] = linalg.indexed_generic
-// CHECK-SAME: indexing_maps = [#[[MAP7]], #[[MAP8]], #[[MAP9]]]
-// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
-// CHECK-SAME: outs(%[[T2]] : tensor<2x3x4x5x6x7xi32>)
+// CHECK: %[[T3:.+]] = linalg.tensor_reshape %[[T0]]
+// CHECK-SAME: [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
+// CHECK: %[[T4:.+]] = linalg.indexed_generic
+// CHECK-SAME: indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
+// CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
+// CHECK-SAME: outs(%[[T3]] : tensor<2x3x4x5x6x7xi32>)
// CHECK: ^{{.+}}(
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, %[[ARG3:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index, %[[ARG5:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index, %[[ARG7:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32,
// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: i32)
-// CHECK-DAG: %[[T4:.+]] = affine.apply #[[MAP5]](%[[ARG2]], %[[ARG3]])
-// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP6]](%[[ARG4]], %[[ARG5]], %[[ARG6]])
-// CHECK-DAG: %[[T6:.+]] = addi %[[ARG8]], %[[ARG9]]
-// CHECK: %[[T7:.+]] = index_cast %[[T4]]
-// CHECK: %[[T8:.+]] = addi %[[T6]], %[[T7]]
-// CHECK: %[[T9:.+]] = index_cast %[[T5]]
-// CHECK: %[[T10:.+]] = addi %[[T8]], %[[T9]]
-// CHECK: %[[T11:.+]] = index_cast %[[ARG7]]
-// CHECK: %[[T12:.+]] = addi %[[T10]], %[[T11]]
+// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP11]](%[[ARG2]], %[[ARG3]])
+// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP12]](%[[ARG4]], %[[ARG5]], %[[ARG6]])
+// CHECK-DAG: %[[T7:.+]] = addi %[[ARG8]], %[[ARG9]]
+// CHECK: %[[T8:.+]] = index_cast %[[T5]]
+// CHECK: %[[T9:.+]] = addi %[[T7]], %[[T8]]
+// CHECK: %[[T10:.+]] = index_cast %[[T6]]
+// CHECK: %[[T11:.+]] = addi %[[T9]], %[[T10]]
+// CHECK: %[[T12:.+]] = index_cast %[[ARG7]]
+// CHECK: %[[T13:.+]] = addi %[[T11]], %[[T12]]
// -----
@@ -466,7 +459,7 @@ func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
indexing_maps = [#map0, #map0, #map1],
iterator_types = ["parallel", "parallel"]}
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg0 : tensor<?x?xf32>) {
+ outs(%arg0 : tensor<?x?xf32>) {
^bb0(%arg3: f32, %arg4: f32, %s: f32): // no predecessors
%1 = mulf %arg3, %arg4 : f32
linalg.yield %1 : f32
@@ -479,8 +472,10 @@ func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
// CHECK: func @generic_op_reshape_consumer_fusion_projected
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -490,8 +485,11 @@ func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: tensor<?x?xf32> into tensor<?x4x5x?xf32>
-// CHECK: %[[T2:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]]]
+// CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP2]], #[[MAP3]]]
+// CHECK: %[[T3:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP4]], #[[MAP5]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<?x4x5x?xf32>, tensor<?x4x5x?xf32>)
-// CHECK: return %[[T2]] : tensor<?x?x4x5xf32>
+// CHECK-SAME: outs(%[[T2]] : tensor<?x?x4x5xf32>)
+// CHECK: return %[[T3]] : tensor<?x?x4x5xf32>
diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
index aff1447a63c7..cb57f00372c7 100644
--- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization -verify-diagnostics %s | FileCheck %s
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xf32>,
@@ -21,14 +21,19 @@ func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xf32>,
return %1 : tensor<?x?x4x?xf32>
}
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @generic_op_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]]
+// CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP4]]]
// CHECK-SAME: ins(%[[ARG0]], %{{.+}} : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>)
-// CHECK-SAME: outs(%{{.+}} : tensor<?x?x4x?xf32>)
+// CHECK-SAME: outs(%[[T0]] : tensor<?x?x4x?xf32>)
// -----
@@ -52,47 +57,17 @@ func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xf32>,
return %1 : tensor<?x?xf32>
}
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
// CHECK: func @generic_op_reshape_consumer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x5xf32>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C20:.+]] = constant 20 : index
-// CHECK: %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
-// CHECK: %[[T1:.+]] = dim %[[ARG0]], %[[C1]]
-// CHECK: %[[T2:.+]] = muli %[[T1]], %[[C20]]
-// CHECK: %[[T3:.+]] = linalg.init_tensor [%[[T0]], %[[T2]]]
+// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]]
-// CHECK-SAME: outs(%[[T3]] : tensor<?x?xf32>)
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
- %arg1 : tensor<?x?x?x5xf32>) ->
- tensor<?x?xf32>
-{
- %0 = linalg.generic {
- indexing_maps = [#map0, #map0, #map0],
- iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?x?x5xf32>, tensor<?x?x?x5xf32>)
- outs(%arg0 : tensor<?x?x?x5xf32>) {
- ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
- %1 = mulf %arg3, %arg4 : f32
- linalg.yield %1 : f32
- } -> tensor<?x?x?x5xf32>
- %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
- affine_map<(i, j, k, l) -> (j, k, l)>] :
- tensor<?x?x?x5xf32> into tensor<?x?xf32>
- return %1 : tensor<?x?xf32>
-}
-
-// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion
-// CHECK: %[[T0:.+]] = linalg.generic
-// CHECK: linalg.tensor_reshape %[[T0]]
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: outs(%[[T0]] : tensor<?x?xf32>)
// -----
@@ -116,13 +91,19 @@ func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
return %1 : tensor<?x?x4x?xi32>
}
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @indexed_generic_op_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xi32>
+// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK: linalg.indexed_generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]]]
// CHECK-SAME: ins(%[[ARG0]] : tensor<?x?x?xi32>)
+// CHECK-SAME: outs(%[[T0]] : tensor<?x?x4x?xi32>)
// -----
@@ -144,20 +125,17 @@ func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
return %1 : tensor<?x?xi32>
}
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
-// CHECK-LABEL: func @indexed_generic_op_reshape_consumer_fusion
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
+// CHECK: func @indexed_generic_op_reshape_consumer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x5xi32>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C20:.+]] = constant 20 : index
-// CHECK: %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
-// CHECK: %[[T1:.+]] = dim %[[ARG0]], %[[C1]]
-// CHECK: %[[T2:.+]] = muli %[[T1]], %[[C20]]
-// CHECK: %[[T3:.+]] = linalg.init_tensor [%[[T0]], %[[T2]]]
+// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
// CHECK: linalg.indexed_generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
-// CHECK-SAME: outs(%[[T3]] : tensor<?x?xi32>)
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: outs(%[[T0]] : tensor<?x?xi32>)
// CHECK-NOT: linalg.tensor_reshape
// -----
@@ -179,12 +157,12 @@ func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf3
return %2 : tensor<3x7x5xf32>
}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: func @generic_op_021_permultation_reshape_producer_fusion
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: func @generic_op_021_permultation_reshape_producer_fusion
// CHECK-NOT: linalg.tensor_reshape
// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// -----
@@ -210,7 +188,7 @@ func @generic_op_120_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf3
// CHECK: func @generic_op_120_permultation_reshape_producer_fusion
// CHECK-NOT: linalg.tensor_reshape
// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// -----
@@ -237,7 +215,7 @@ func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf3
// CHECK: func @generic_op_102_permultation_reshape_producer_fusion
// CHECK-NOT: linalg.tensor_reshape
// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// -----
@@ -258,10 +236,39 @@ func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf
return %2 : tensor<5x21xf32>
}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
// CHECK: func @generic_op_102_permultation_reshape_consumer_fusion
-// CHECK-NOT: linalg.tensor_reshape
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x5x7xf32>
+// CHECK: %[[T0:.+]] = linalg.init_tensor [5, 3, 7]
+// CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[T0]]
+// CHECK-SAME: [#[[MAP0]], #[[MAP1]]]
// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: ins(%[[ARG0]] : tensor<3x5x7xf32>)
+// CHECK-SAME: outs(%[[T1]] : tensor<5x21xf32>)
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
+ %arg1 : tensor<?x?x?x5xf32>) ->
+ tensor<?x?xf32>
+{
+ // expected-remark @+1 {{fused op indexing map is not affine}}
+ %0 = linalg.generic {
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?x?x5xf32>, tensor<?x?x?x5xf32>)
+ outs(%arg0 : tensor<?x?x?x5xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
+ %1 = mulf %arg3, %arg4 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?x?x5xf32>
+ %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+ affine_map<(i, j, k, l) -> (j, k, l)>] :
+ tensor<?x?x?x5xf32> into tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
More information about the Mlir-commits
mailing list