[Mlir-commits] [mlir] 542668d - [mlir][Linalg] Add support for fusing linalg.tensor_reshape with
llvmlistbot at llvm.org
Thu Apr 23 13:42:19 PDT 2020
Author: MaheshRavishankar
Date: 2020-04-23T13:41:47-07:00
New Revision: 542668d1e2060693279462b67d07756fe93f3eb9
URL: https://github.com/llvm/llvm-project/commit/542668d1e2060693279462b67d07756fe93f3eb9
DIFF: https://github.com/llvm/llvm-project/commit/542668d1e2060693279462b67d07756fe93f3eb9.diff
LOG: [mlir][Linalg] Add support for fusing linalg.tensor_reshape with
linalg.generic operations.
Differential Revision: https://reviews.llvm.org/D78464
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/IR/StandardTypes.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/IR/StandardTypes.cpp
mlir/test/Dialect/Linalg/fusion-tensor.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 3e667d98f822..10883d03b38b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -172,6 +172,10 @@ def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
RankedTensorType getResultType() {
return result().getType().cast<RankedTensorType>();
}
+ SmallVector<AffineMap, 4> getReassociationMaps() {
+ return llvm::to_vector<4>(llvm::map_range(reassociation(),
+ [](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
+ }
}];
}
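
The getReassociationMaps() accessor added above is what the new fusion patterns rely on to read the reassociation as AffineMaps rather than as an ArrayAttr. A minimal sketch of using it (the helper name and includes are illustrative, not part of this patch):

  #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
  #include "mlir/IR/AffineMap.h"
  #include "llvm/Support/raw_ostream.h"

  // Hypothetical helper: print each reassociation map of a tensor_reshape,
  // using the new accessor instead of unpacking the ArrayAttr by hand.
  static void printReassociation(mlir::linalg::TensorReshapeOp reshapeOp) {
    for (mlir::AffineMap map : reshapeOp.getReassociationMaps()) {
      map.print(llvm::errs());
      llvm::errs() << "\n";
    }
  }
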
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index d8886acc5992..5c4868c4c870 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -18,8 +18,10 @@
namespace mlir {
class FuncOp;
+class MLIRContext;
class ModuleOp;
template <typename T> class OperationPass;
+class OwningRewritePatternList;
class Pass;
std::unique_ptr<OperationPass<FuncOp>> createLinalgFusionPass();
@@ -48,6 +50,10 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToParallelLoopsPass();
/// Placeholder for now, this is NYI.
std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
+/// Patterns for fusing linalg operation on tensors.
+void populateLinalgTensorOpsFusionPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns);
+
} // namespace mlir
#endif // MLIR_DIALECT_LINALG_PASSES_H_
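
The populateLinalgTensorOpsFusionPatterns entry point declared here lets clients pull the tensor-op fusion patterns into their own pattern lists; FusionOfTensorOpsPass is switched over to it further down in this diff. A minimal sketch of the intended use, mirroring that pass body (the helper name and includes are illustrative, not part of this patch):

  #include "mlir/Dialect/Linalg/Passes.h"
  #include "mlir/IR/Operation.h"
  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  // Hypothetical driver: populate the newly exposed fusion patterns and apply
  // them greedily to the regions of an operation (e.g. a FuncOp).
  static void fuseLinalgTensorOps(mlir::Operation *op) {
    mlir::OwningRewritePatternList patterns;
    mlir::populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
    mlir::applyPatternsAndFoldGreedily(op->getRegions(), patterns);
  }
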
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index cc94d27dedbb..4c5bbba0aa6a 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -672,9 +672,18 @@ MemRefType canonicalizeStridedLayout(MemRefType t);
/// varying stride is always `1`.
///
/// Examples:
-/// - memref<3x4x5xf32> has canonical stride expression `20*d0 + 5*d1 + d2`.
-/// - memref<3x?x5xf32> has canonical stride expression `s0*d0 + 5*d1 + d2`.
-/// - memref<3x4x?xf32> has canonical stride expression `s1*d0 + s0*d1 + d2`.
+/// - memref<3x4x5xf32> has canonical stride expression
+/// `20*exprs[0] + 5*exprs[1] + exprs[2]`.
+/// - memref<3x?x5xf32> has canonical stride expression
+/// `s0*exprs[0] + 5*exprs[1] + exprs[2]`.
+/// - memref<3x4x?xf32> has canonical stride expression
+/// `s1*exprs[0] + s0*exprs[1] + exprs[2]`.
+AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
+ ArrayRef<AffineExpr> exprs,
+ MLIRContext *context);
+
+/// Return the result of makeCanonicalStridedLayoutExpr for the common case
+/// where `exprs` is {d0, d1, ..., d_(sizes.size()-1)}.
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
MLIRContext *context);
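
As a worked instance of the new overload (a sketch, not part of this patch): for trailing sizes (?, 4, 5) and exprs (d1, d2, d3), the innermost stride is 1, the next is 5, and the outermost is 4 * 5 = 20, so the canonical expression simplifies to d1 * 20 + d2 * 5 + d3. This is exactly the linearized map the fusion code in Fusion.cpp below produces when collapsing the last three dimensions of a tensor<?x?x4x5xf32>:

  #include "mlir/IR/AffineExpr.h"
  #include "mlir/IR/StandardTypes.h"
  #include "llvm/ADT/SmallVector.h"

  // Sketch: linearize the band (d1, d2, d3) against the sizes (?, 4, 5).
  // The result simplifies to d1 * 20 + d2 * 5 + d3.
  static mlir::AffineExpr exampleLinearizedBand(mlir::MLIRContext *ctx) {
    llvm::SmallVector<mlir::AffineExpr, 3> exprs = {
        mlir::getAffineDimExpr(1, ctx), mlir::getAffineDimExpr(2, ctx),
        mlir::getAffineDimExpr(3, ctx)};
    int64_t sizes[] = {mlir::ShapedType::kDynamicSize, 4, 5};
    return mlir::makeCanonicalStridedLayoutExpr(sizes, exprs, ctx);
  }
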
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8fa90f444f63..0aa149ef907f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -554,7 +554,7 @@ computeTensorReshapeCollapsedType(RankedTensorType type,
unsigned currentDim = 0;
for (AffineMap m : reassociation) {
unsigned dim = m.getNumResults();
- auto band = shape.drop_front(currentDim).take_front(dim);
+ auto band = shape.slice(currentDim, dim);
int64_t size = 1;
if (llvm::is_contained(band, ShapedType::kDynamicSize))
size = ShapedType::kDynamicSize;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 1184b5f87ea6..cd6301ae249c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -559,6 +559,187 @@ struct FuseGenericOpsOnTensors {
};
} // namespace
+/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
+/// provided, given the shape of the source tensor that corresponds to the
+/// `sourceMap`. Note that this implicitly assumes that the tensor's dimensions
+/// are "row-major" ordered logically.
+///
+/// For example:
+///
+/// %0 = op ... : tensor<?x?x4x5xf32>
+/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
+///
+/// and reshape:
+/// %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+/// affine_map<(i, j, k, l) -> (j, k, l)>] :
+/// tensor<?x?x4x5xf32> into tensor<?x?xf32>
+///
+/// would be rewritten into:
+/// %0 = op ... : tensor<?x?x4x5xf32>
+/// with output index_map
+/// `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
+static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
+ ArrayRef<int64_t> sourceShape,
+ ArrayRef<AffineMap> reassociationMaps) {
+ SmallVector<AffineExpr, 4> resultExprs;
+ resultExprs.reserve(reassociationMaps.size());
+ ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
+ MLIRContext *context = sourceMap.getContext();
+
+ // Compute the result exprs based on the reassociation maps.
+ for (AffineMap map : reassociationMaps) {
+ ArrayRef<AffineExpr> collapsedDims = map.getResults();
+ // Assume that they are in-order and contiguous (already checked in
+ // verifier).
+ assert(!collapsedDims.empty());
+ unsigned startDim =
+ collapsedDims.front().cast<AffineDimExpr>().getPosition();
+ AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr(
+ sourceShape.slice(startDim, collapsedDims.size()),
+ sourceExprs.slice(startDim, collapsedDims.size()), context);
+ resultExprs.push_back(linearizedExpr);
+ }
+ return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
+ resultExprs, context);
+}
+
+/// Checks if the `reshapeOp` can be fused with its consumer (if `asProducer`
+/// is true) or its producer (if `asProducer` is false) given the indexing map
+/// at its use.
+static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp,
+ AffineMap useIndexMap, bool asProducer) {
+ RankedTensorType returnType = reshapeOp.getResultType();
+ RankedTensorType operandType = reshapeOp.getSrcType();
+  // Reshape is fusible with its consumer (i.e. reshape as a producer) only
+  // when its operand has a rank no higher than the result. Fusing when the
+  // operand has a higher rank would require mods and divs in the indexing
+  // maps of the fused op, which would make it non-invertible. Similarly,
+  // reshape is fused with its producer (i.e. reshape as consumer) only if the
+  // result type has a rank no higher than the operand.
+ if ((asProducer && returnType.getRank() < operandType.getRank()) ||
+ (!asProducer && operandType.getRank() < returnType.getRank()))
+ return false;
+ return useIndexMap.isIdentity();
+}
+
+namespace {
+/// Implementation of fusion on tensor ops when producer is a TensorReshapeOp.
+template <typename LinalgOpTy> struct FuseTensorReshapeOpAsProducer {
+ static bool isFusible(TensorReshapeOp producer, LinalgOpTy consumer,
+ unsigned consumerIdx) {
+ return isTensorReshapeOpFusible(
+ producer, consumer.getInputIndexingMap(consumerIdx), true);
+ }
+
+ static Operation *fuse(TensorReshapeOp producer, LinalgOpTy consumer,
+ unsigned consumerIdx, PatternRewriter &rewriter,
+ OperationFolder *folder = nullptr) {
+ if (!isFusible(producer, consumer, consumerIdx))
+ return nullptr;
+
+    // Compute the fused operands list.
+ SmallVector<Value, 2> fusedOperands(consumer.operand_begin(),
+ consumer.operand_end());
+ fusedOperands[consumerIdx] = producer.src();
+
+    // Compute indexing_maps for the fused operation. The indexing_maps for the
+    // operands of the consumer that aren't fused stay the same.
+ SmallVector<AffineMap, 4> fusedIndexMaps =
+ llvm::to_vector<4>(llvm::map_range(
+ consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
+ return attr.cast<AffineMapAttr>().getValue();
+ }));
+
+ // Compute the indexing map to use for the operand of the producer.
+ AffineMap modifiedMap = linearizeCollapsedDims(
+ fusedIndexMaps[consumerIdx], producer.getResultType().getShape(),
+ producer.getReassociationMaps());
+ for (AffineExpr expr : modifiedMap.getResults()) {
+ if (!expr.isPureAffine())
+ return nullptr;
+ }
+ fusedIndexMaps[consumerIdx] = modifiedMap;
+
+ // Further check that the resulting index maps can be fused and
+ // inverted. Without this the resultant op is not legal.
+ if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
+ return nullptr;
+
+ SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
+ llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
+ return AffineMapAttr::get(map);
+ }));
+ auto fusedOp = rewriter.create<LinalgOpTy>(
+ rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
+ rewriter.getI64IntegerAttr(fusedOperands.size()),
+ rewriter.getI64IntegerAttr(consumer.getNumResults()),
+ rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr);
+ auto &fusedRegion = fusedOp.region();
+ rewriter.cloneRegionBefore(consumer.region(), fusedRegion,
+ fusedRegion.begin());
+ return fusedOp;
+ }
+};
+
+/// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp.
+template <typename LinalgOpTy> struct FuseTensorReshapeOpAsConsumer {
+ static bool isFusible(LinalgOpTy producer, TensorReshapeOp consumer,
+ unsigned consumerIdx) {
+ return isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
+ false);
+ }
+
+ static Operation *fuse(LinalgOpTy producer, TensorReshapeOp consumer,
+ unsigned consumerIdx, PatternRewriter &rewriter,
+ OperationFolder *folder = nullptr) {
+ if (!isFusible(producer, consumer, consumerIdx))
+ return nullptr;
+
+    // The indexing_maps for the operands of the fused operation are the same
+    // as those for the operands of the producer.
+ SmallVector<AffineMap, 4> fusedIndexMaps =
+ llvm::to_vector<4>(llvm::map_range(
+ producer.indexing_maps(), [](Attribute attr) -> AffineMap {
+ return attr.cast<AffineMapAttr>().getValue();
+ }));
+    // Compute the indexing map to use for the result of the fused operation.
+ AffineMap modifiedMap = linearizeCollapsedDims(
+ producer.getOutputIndexingMap(0), consumer.getSrcType().getShape(),
+ consumer.getReassociationMaps());
+ for (AffineExpr expr : modifiedMap.getResults()) {
+ if (!expr.isPureAffine())
+ return nullptr;
+ }
+ fusedIndexMaps.back() = modifiedMap;
+
+ // Further check that the resulting index maps can be fused and
+ // inverted. Without this the resultant op is not legal.
+ if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
+ return nullptr;
+
+ SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
+ llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
+ return AffineMapAttr::get(map);
+ }));
+
+ auto fusedOp = rewriter.create<LinalgOpTy>(
+ rewriter.getUnknownLoc(), consumer.getResultType(),
+ producer.getOperands(),
+ rewriter.getI64IntegerAttr(producer.getNumOperands()),
+ rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(indexMapAttrs),
+ producer.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr);
+ auto &fusedRegion = fusedOp.region();
+ rewriter.cloneRegionBefore(producer.region(), fusedRegion,
+ fusedRegion.begin());
+ return fusedOp;
+ }
+};
+} // namespace
+
Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
Operation *consumer,
unsigned consumerIdx,
@@ -569,6 +750,7 @@ Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
if (!producer || producer->getNumResults() != 1)
return nullptr;
+ // Fuse when consumer is GenericOp.
if (GenericOp genericOp = dyn_cast<GenericOp>(consumer)) {
if (!genericOp.hasTensorSemantics())
return nullptr;
@@ -576,7 +758,21 @@ Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
if (genericOpProducer.hasTensorSemantics())
return FuseGenericOpsOnTensors::fuse(genericOpProducer, genericOp,
consumerIdx, rewriter, folder);
+ } else if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer)) {
+ return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
+ reshapeOpProducer, genericOp, consumerIdx, rewriter, folder);
}
+ return nullptr;
+ }
+
+ // Fuse when consumer is a TensorReshapeOp.
+ if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) {
+ if (auto genericOpProducer = dyn_cast<GenericOp>(producer)) {
+ if (genericOpProducer.hasTensorSemantics())
+ return FuseTensorReshapeOpAsConsumer<GenericOp>::fuse(
+ genericOpProducer, reshapeOp, consumerIdx, rewriter, folder);
+ }
+ return nullptr;
}
return nullptr;
}
@@ -612,7 +808,7 @@ struct FusionOfTensorOpsPass
void runOnOperation() override {
OwningRewritePatternList patterns;
Operation *op = getOperation();
- patterns.insert<FuseTensorOps<GenericOp>>(op->getContext());
+ populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
applyPatternsAndFoldGreedily(op->getRegions(), patterns);
};
};
@@ -622,6 +818,12 @@ struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
};
} // namespace
+void mlir::populateLinalgTensorOpsFusionPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns) {
+ patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<TensorReshapeOp>>(
+ context);
+}
+
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
return std::make_unique<LinalgFusionPass>();
}
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 903ae92e6baf..94156c358eb0 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -728,35 +728,47 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
}
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
+ ArrayRef<AffineExpr> exprs,
MLIRContext *context) {
AffineExpr expr;
bool dynamicPoisonBit = false;
+ unsigned numDims = 0;
unsigned nSymbols = 0;
+ // Compute the number of symbols and dimensions of the passed exprs.
+ for (AffineExpr expr : exprs) {
+ expr.walk([&numDims, &nSymbols](AffineExpr d) {
+ if (AffineDimExpr dim = d.dyn_cast<AffineDimExpr>())
+ numDims = std::max(numDims, dim.getPosition() + 1);
+ else if (AffineSymbolExpr symbol = d.dyn_cast<AffineSymbolExpr>())
+ nSymbols = std::max(nSymbols, symbol.getPosition() + 1);
+ });
+ }
int64_t runningSize = 1;
- unsigned rank = sizes.size();
- for (auto en : llvm::enumerate(llvm::reverse(sizes))) {
- auto size = en.value();
- auto position = rank - 1 - en.index();
+ for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
+ int64_t size = std::get<1>(en);
// Degenerate case, no size =-> no stride
if (size == 0)
continue;
- auto d = getAffineDimExpr(position, context);
- // Static case: stride = runningSize and runningSize *= size.
- if (!dynamicPoisonBit) {
- auto cst = getAffineConstantExpr(runningSize, context);
- expr = expr ? expr + cst * d : cst * d;
- if (size > 0)
- runningSize *= size;
- else
- // From now on bail into dynamic mode.
- dynamicPoisonBit = true;
- continue;
- }
- // Dynamic case, new symbol for each new stride.
- auto sym = getAffineSymbolExpr(nSymbols++, context);
- expr = expr ? expr + d * sym : d * sym;
+ AffineExpr dimExpr = std::get<0>(en);
+ AffineExpr stride = dynamicPoisonBit
+ ? getAffineSymbolExpr(nSymbols++, context)
+ : getAffineConstantExpr(runningSize, context);
+ expr = expr ? expr + dimExpr * stride : dimExpr * stride;
+ if (size > 0)
+ runningSize *= size;
+ else
+ dynamicPoisonBit = true;
}
- return simplifyAffineExpr(expr, rank, nSymbols);
+ return simplifyAffineExpr(expr, numDims, nSymbols);
+}
+
+AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
+ MLIRContext *context) {
+ SmallVector<AffineExpr, 4> exprs;
+ exprs.reserve(sizes.size());
+ for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
+ exprs.push_back(getAffineDimExpr(dim, context));
+ return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}
/// Return true if the layout for `t` is compatible with strided semantics.
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 11c38fcb7601..2c00f77edd3f 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -129,3 +129,93 @@ func @add_mul_scalar_fusion(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tenso
return %1 : tensor<f32>
}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xf32>,
+ %arg1 : tensor<?x?x4x?xf32>) ->
+ tensor<?x?x4x?xf32>
+{
+ %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
+ affine_map<(i, j, k, l) -> (j, k)>,
+ affine_map<(i, j, k, l) -> (l)>] :
+ tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
+ %1 = linalg.generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ %0, %arg1 {
+ ^bb0(%arg3: f32, %arg4: f32): // no predecessors
+ %1 = mulf %arg3, %arg4 : f32
+ linalg.yield %1 : f32
+ }: tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32> -> tensor<?x?x4x?xf32>
+ return %1 : tensor<?x?x4x?xf32>
+}
+
+// CHECK-LABEL: func @generic_op_reshape_producer_fusion
+// CHECK: linalg.generic
+// CHECK-SAME: args_in = 2
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP1]]]
+// CHECK-NOT: linalg.generic
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xf32>,
+ %arg1 : tensor<?x?x4x5xf32>) ->
+ tensor<?x?xf32>
+{
+ %0 = linalg.generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ %arg0, %arg1 {
+ ^bb0(%arg3: f32, %arg4: f32): // no predecessors
+ %1 = mulf %arg3, %arg4 : f32
+ linalg.yield %1 : f32
+ }: tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32> -> tensor<?x?x4x5xf32>
+ %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+ affine_map<(i, j, k, l) -> (j, k, l)>] :
+ tensor<?x?x4x5xf32> into tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @generic_op_reshape_consumer_fusion
+// CHECK: linalg.generic
+// CHECK-SAME: args_in = 2
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP1]]]
+// CHECK-NOT: linalg.generic
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
+ %arg1 : tensor<?x?x?x5xf32>) ->
+ tensor<?x?xf32>
+{
+ %0 = linalg.generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ %arg0, %arg1 {
+ ^bb0(%arg3: f32, %arg4: f32): // no predecessors
+ %1 = mulf %arg3, %arg4 : f32
+ linalg.yield %1 : f32
+ }: tensor<?x?x?x5xf32>, tensor<?x?x?x5xf32> -> tensor<?x?x?x5xf32>
+ %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+ affine_map<(i, j, k, l) -> (j, k, l)>] :
+ tensor<?x?x?x5xf32> into tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion
+// CHECK: linalg.tensor_reshape