[Mlir-commits] [mlir] cf194da - [mlir][linalg] Remove IndexedGenericOp support from FusionOnTensors...
Tobias Gysi
llvmlistbot at llvm.org
Thu May 13 07:58:09 PDT 2021
Author: Tobias Gysi
Date: 2021-05-13T14:57:16Z
New Revision: cf194da1bbf79d392688dba0c74875829e9873f2
URL: https://github.com/llvm/llvm-project/commit/cf194da1bbf79d392688dba0c74875829e9873f2
DIFF: https://github.com/llvm/llvm-project/commit/cf194da1bbf79d392688dba0c74875829e9873f2.diff
LOG: [mlir][linalg] Remove IndexedGenericOp support from FusionOnTensors...
after introducing the IndexedGenericOp to GenericOp canonicalization (https://reviews.llvm.org/D101612).
Differential Revision: https://reviews.llvm.org/D102163
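For context, the canonicalization referenced above (D101612) rewrites linalg.indexed_generic into linalg.generic by dropping the leading index block arguments and recovering them with linalg.index operations inside the body, which is why the fusion patterns below no longer need a separate IndexedGenericOp code path. A minimal sketch of that rewrite (the map, shapes, and value names here are illustrative, not taken from this patch):

  #map0 = affine_map<(d0, d1) -> (d0, d1)>

  // Before: the iteration indices are leading block arguments.
  %0 = linalg.indexed_generic {
         indexing_maps = [#map0, #map0],
         iterator_types = ["parallel", "parallel"] }
        ins(%arg0 : tensor<?x?xi32>)
       outs(%init : tensor<?x?xi32>) {
  ^bb0(%i: index, %j: index, %in: i32, %out: i32):
    %1 = index_cast %i : index to i32
    %2 = addi %in, %1 : i32
    linalg.yield %2 : i32
  } -> tensor<?x?xi32>

  // After: the same computation as a linalg.generic that materializes the
  // indices with linalg.index.
  %0 = linalg.generic {
         indexing_maps = [#map0, #map0],
         iterator_types = ["parallel", "parallel"] }
        ins(%arg0 : tensor<?x?xi32>)
       outs(%init : tensor<?x?xi32>) {
  ^bb0(%in: i32, %out: i32):
    %i = linalg.index 0 : index
    %1 = index_cast %i : index to i32
    %2 = addi %in, %1 : i32
    linalg.yield %2 : i32
  } -> tensor<?x?xi32>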
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/test/Dialect/Linalg/fusion-tensor.mlir
mlir/test/Dialect/Linalg/reshape_fusion.mlir
mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 93c99038e322a..501c34f5c46b0 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -57,17 +57,16 @@ void populateFoldReshapeOpsByExpansionPatterns(
ControlElementwiseOpsFusionFn controlFoldingReshapes = skipUnitDimReshape);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
-/// producer (consumer) generic/indexed_generic operation by linearizing the
-/// indexing map used to access the source (target) of the reshape operation in
-/// the generic/indexed_generic operation.
+/// producer (consumer) generic operation by linearizing the indexing map used
+/// to access the source (target) of the reshape operation in the generic
+/// operation.
void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
-/// producer (consumer) generic/indexed_generic operation by linearizing the
-/// indexing map used to access the source (target) of the reshape operation in
-/// the generic/indexed_generic operation. The patterns are applied only when
-/// the tensor reshape involved is collapsing (introducing) unit-extent
-/// dimensions.
+/// producer (consumer) generic operation by linearizing the indexing map used
+/// to access the source (target) of the reshape operation in the generic
+/// operation. The patterns are applied only when the tensor reshape involved is
+/// collapsing (introducing) unit-extent dimensions.
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 6a0c597b7c3cd..4ecc72ca3d094 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -26,8 +26,8 @@
using namespace mlir;
using namespace mlir::linalg;
-/// Implementation of fusion of generic ops and indexed_generic ops.
-static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
+/// Conditions for elementwise fusion of generic operations.
+static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
unsigned consumerIdx) {
// Producer and consumer must have tensor semantics.
if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
@@ -95,57 +95,20 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void
-generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
- LinalgOp producer, LinalgOp consumer,
+generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
+ GenericOp producer, GenericOp consumer,
AffineMap consumerToProducerLoopsMap,
unsigned consumerIdx, unsigned nloops) {
// Build the region of the fused op.
Block &producerBlock = producer->getRegion(0).front();
Block &consumerBlock = consumer->getRegion(0).front();
Block *fusedBlock = new Block();
- fusedOp->getRegion(0).push_back(fusedBlock);
+ fusedOp.region().push_back(fusedBlock);
BlockAndValueMapping mapper;
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(fusedBlock);
- // The block arguments are
- // [index_0, index_1, ... ,
- // consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1),
- // producer_operand_0, ... , producer_operand_(n-1)],
- // consumer_operand_(`consumerIdx`), .. consumer_operand_(m-1)]
- // , where n is the number of producer's operand and m is the number
- // consumer's operand.
- // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
- // generic op. In this case, there are no indices in block arguments.
- unsigned numProducerIndices = isa<IndexedGenericOp>(producer.getOperation())
- ? producer.getNumLoops()
- : 0;
- unsigned numConsumerIndices = isa<IndexedGenericOp>(consumer.getOperation())
- ? consumer.getNumLoops()
- : 0;
- unsigned numFusedOpIndices =
- (isa<IndexedGenericOp>(producer.getOperation()) ||
- isa<IndexedGenericOp>(consumer.getOperation()))
- ? std::max(producer.getNumLoops(), consumer.getNumLoops())
- : 0;
-
- // 0. Firstly, add all the indices to the block arguments.
- for (unsigned i = 0, e = numFusedOpIndices; i < e; ++i)
- fusedBlock->addArgument(rewriter.getIndexType());
- // 1. Map consumer indices to fusedBlock indices 1-1.
- mapper.map(consumerBlock.getArguments().take_front(numConsumerIndices),
- fusedBlock->getArguments().take_front(numConsumerIndices));
- // 2a. Embed producer indices into fusedBlock index space 1-1.
- for (auto it :
- llvm::zip(producerBlock.getArguments().take_front(numProducerIndices),
- fusedBlock->getArguments().take_front(numProducerIndices))) {
- auto newIndex = rewriter.create<mlir::AffineApplyOp>(
- producer.getLoc(),
- consumerToProducerLoopsMap.getSubMap(std::get<0>(it).getArgNumber()),
- fusedBlock->getArguments().take_front(numFusedOpIndices));
- mapper.map(std::get<0>(it), newIndex);
- }
- // 2b. Add an index operation for every fused loop dimension and use the
+ // 2. Add an index operation for every fused loop dimension and use the
// `consumerToProducerLoopsMap` to map the producer indices.
if (producer.hasIndexSemantics()) {
// Add an index operation for every fused loop dimension.
@@ -169,34 +132,30 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
assert(consumerIdx < consumer.getNumInputs() &&
"expected producer of input operand");
// 3. Consumer input operands up to consumerIdx (exclusive).
- for (BlockArgument bbArg : consumerBlock.getArguments()
- .drop_front(numConsumerIndices)
- .take_front(consumerIdx)) // input assumption.
+ for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
+ consumerIdx)) // input assumption.
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
// Replacing consumerIdx requires getting the cloned, yielded, value from
// the (cloned) producer block. This happens in step 9.
// 4. Splice in producer's input operands.
- for (BlockArgument bbArg : producerBlock.getArguments()
- .drop_front(numProducerIndices)
- .take_front(producer.getNumInputs()))
+ for (BlockArgument bbArg :
+ producerBlock.getArguments().take_front(producer.getNumInputs()))
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
// 4.b. Producer output operand/map that is fused needs to be mapped to the
// producer bbArg if it is an "initTensor" (i.e. its value is actually read).
assert(producer->getNumResults() == 1 && "expected single result producer");
if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) {
- BlockArgument bbArg =
- producerBlock.getArguments()
- .drop_front(numConsumerIndices + producer.getNumInputs())
- // TODO: bbArg index of
- .front();
+ BlockArgument bbArg = producerBlock.getArguments()
+ .drop_front(producer.getNumInputs())
+ // TODO: bbArg index of
+ .front();
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
}
// 5. Remaining consumer's input operands (drop past index `consumerIdx`).
for (BlockArgument bbArg : consumerBlock.getArguments()
- .drop_front(numConsumerIndices)
.take_front(consumer.getNumInputs())
.drop_front(consumerIdx + 1))
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
@@ -232,23 +191,21 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
assert(!producer->isAncestor(replacement.getDefiningOp()) &&
"yielded value must have been mapped");
}
- mapper.map(consumerBlock.getArgument(consumerIdx + numConsumerIndices),
- replacement);
+ mapper.map(consumerBlock.getArgument(consumerIdx), replacement);
// 10. Clone operations from the consumer to the fused op.
for (auto &op : consumerBlock.getOperations())
rewriter.clone(op, mapper);
// Sanity checks.
- assert(fusedBlock->getNumArguments() ==
- fusedOp->getNumOperands() + numFusedOpIndices &&
- "Ill-formed LinalgOp region");
+ assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
+ "Ill-formed GenericOp region");
}
-static Optional<SmallVector<Value, 1>>
-fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
+static Optional<SmallVector<Value>>
+fuseElementwiseOpsImpl(GenericOp producer, OpOperand &consumerOpOperand,
const ControlElementwiseOpsFusionFn &controlFn,
PatternRewriter &rewriter) {
- LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
+ auto consumer = cast<GenericOp>(consumerOpOperand.getOwner());
unsigned consumerIdx = consumerOpOperand.getOperandNumber();
if (!areElementwiseOpsFusable(producer, consumer, consumerIdx) ||
!controlFn(producer->getResult(0), consumerOpOperand))
@@ -311,27 +268,14 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
assert(producer->getNumResults() == 1 && "expected single result producer");
// Generate the fused op.
- Operation *fusedOp;
- if (isa<GenericOp>(producer.getOperation()) &&
- isa<GenericOp>(consumer.getOperation())) {
- fusedOp = rewriter.create<GenericOp>(
- consumer.getLoc(), consumer->getResultTypes(),
- /*inputs=*/fusedOperands,
- // TODO: handle outputs.
- consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
- consumer.iterator_types(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr);
- } else {
- fusedOp = rewriter.create<IndexedGenericOp>(
- consumer.getLoc(), consumer->getResultTypes(),
- /*inputs=*/fusedOperands,
- // TODO: handle outputs.
- consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
- consumer.iterator_types(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr);
- }
+ auto fusedOp = rewriter.create<GenericOp>(
+ consumer.getLoc(), consumer->getResultTypes(),
+ /*inputs=*/fusedOperands,
+ // TODO: handle outputs.
+ consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+ consumer.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr);
// Construct an AffineMap from consumer loops to producer loops.
// consumer loop -> tensor index
@@ -348,7 +292,7 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
generateFusedElementwiseOpRegion(rewriter, fusedOp, producer, consumer,
consumerToProducerLoopsMap, consumerIdx,
consumer.getNumLoops());
- return SmallVector<Value, 1>(fusedOp->getResults());
+ return SmallVector<Value>(fusedOp->getResults());
}
/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
@@ -373,7 +317,7 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
ArrayRef<int64_t> sourceShape,
ArrayRef<AffineMap> reassociationMaps) {
- SmallVector<AffineExpr, 4> resultExprs;
+ SmallVector<AffineExpr> resultExprs;
resultExprs.reserve(reassociationMaps.size());
ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
MLIRContext *context = sourceMap.getContext();
@@ -386,8 +330,8 @@ static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
assert(!collapsedDims.empty());
unsigned startDim =
collapsedDims.front().cast<AffineDimExpr>().getPosition();
- SmallVector<int64_t, 4> sizes;
- SmallVector<AffineExpr, 4> dimExprs;
+ SmallVector<int64_t> sizes;
+ SmallVector<AffineExpr> dimExprs;
for (auto en :
llvm::zip(sourceShape.slice(startDim, collapsedDims.size()),
sourceExprs.slice(startDim, collapsedDims.size()))) {
@@ -426,22 +370,6 @@ static bool isTensorReshapeOpFoldableByLinearization(TensorReshapeOp reshapeOp,
return useIndexMap.isPermutation();
}
-/// Based on the type of `op` create a linalg op of the same type, i.e. if `op`
-/// is a linalg.generic operation, the create a `linalg.generic` operation with
-/// the given `args`. Expects `op` to be `linalg.generic` or
-/// `linalg.indexed_generic`.
-template <typename... Args>
-static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
- Args... args) {
- if (isa<GenericOp>(op.getOperation()))
- return rewriter.create<GenericOp>(args...);
- if (isa<IndexedGenericOp>(op.getOperation()))
- return rewriter.create<IndexedGenericOp>(args...);
- llvm_unreachable(
- "expected only linalg.generic or linalg.indexed_generic ops");
- return nullptr;
-}
-
/// Check if the reshape operation is only expansion into/collapsing of
/// unit-dimension.
static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
@@ -459,10 +387,10 @@ static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
return true;
}
-/// Conditions for folding a generic/indexed-generic operation with a reshape op
-/// by expanding the iteration space dimensionality for tensor operations. These
-/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
-/// the following fusion pattern.
+/// Conditions for folding a generic operation with a reshape op by expanding
+/// the iteration space dimensionality for tensor operations. These are
+/// preconditions assumed by `foldReshapeByDimExpansion` which implements the
+/// following fusion pattern.
///
/// Consider
///
@@ -476,12 +404,12 @@ static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
/// affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>]
/// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
-/// The reshape can be folded into the `linalgOp` if the
-/// generic/indexed-generic op loop dimensionality is increased to match the
-/// result (operand) of the tensor_reshape when the reshape is expanding
-/// (folding). The indexing_map of the fused tensor in the `linalgOp` and the
-/// reassociation map helps compute the indexing maps of the modified op. For
-/// the above example, based on the reassociation map it can be concluded that
+/// The reshape can be folded into the `genericOp` if its loop dimensionality
+/// is increased to match the result (operand) of the tensor_reshape when the
+/// reshape is expanding (folding). The indexing_map of the fused tensor in the
+/// `genericOp` and the reassociation map help compute the indexing maps of
+/// the modified op. For the above example, based on the reassociation map it
+/// can be concluded that
///
/// - The loop used to access the first dimension of the fused tensor is split
/// into two.
@@ -520,41 +448,40 @@ static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with its producers if possible.
-static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
+static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
unsigned fusedTensorIndex) {
// Is fusable only if:
- // - The linalgOp is a generic op, or an indexed_generic.
- // - All the indexing maps for operands and results in linalgOp are projected
+ // - All the indexing maps for operands and results are projected
// permutations.
// - The fused tensor is not a scalar.
- // - All the loops in linalgOp are parallel loops.
- return isa<GenericOp, IndexedGenericOp>(linalgOp.getOperation()) &&
- linalgOp.hasTensorSemantics() &&
- llvm::all_of(linalgOp.indexing_maps().getValue(),
+ // - All the loops are parallel loops.
+ return genericOp.hasTensorSemantics() &&
+ llvm::all_of(genericOp.indexing_maps().getValue(),
[](Attribute attr) {
return attr.cast<AffineMapAttr>()
.getValue()
.isProjectedPermutation();
}) &&
- linalgOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 &&
- llvm::all_of(linalgOp.iterator_types(), [](Attribute attr) {
+ genericOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 &&
+ llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
return attr.cast<StringAttr>().getValue() ==
getParallelIteratorTypeName();
});
}
namespace {
-/// Information needed to expand a generic/indexed_generic operation to fold the
-/// reshape with it.
+/// Information needed to expand a generic operation to fold the reshape with
+/// it.
class ExpansionInfo {
public:
// Computes the mapping from original dimensions of the op to the dimensions
// of the expanded op given the `indexingMap` of the fused operand/result of
- // the generic/indexed_generic op, the `reassocationMaps` of the reshape op
- // and the shape of the expanded op.
+ // the generic op, the `reassociationMaps` of the reshape op and the shape of
+ // the expanded op.
LogicalResult compute(LinalgOp linalgOp, unsigned fusedTensorIndex,
ArrayRef<AffineMap> reassociationMaps,
- ArrayRef<int64_t> expandedShape);
+ ArrayRef<int64_t> expandedShape,
+ PatternRewriter &rewriter);
unsigned getOrigOpNumDims() const { return reassociation.size(); }
unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
ReassociationIndicesRef getExpandedDims(unsigned i) const {
@@ -567,10 +494,10 @@ class ExpansionInfo {
private:
/// Reassociation from the dimensions in the original operation to the
/// dimension of the expanded operation.
- SmallVector<ReassociationIndices, 4> reassociation;
+ SmallVector<ReassociationIndices> reassociation;
/// Mapping from extent of loops in the original operation, to the extent of
/// loops in the expanded operation.
- SmallVector<SmallVector<int64_t, 4>, 4> expandedShapeMap;
+ SmallVector<SmallVector<int64_t>> expandedShapeMap;
unsigned expandedOpNumDims;
};
} // namespace
@@ -578,7 +505,8 @@ class ExpansionInfo {
LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
unsigned fusedTensorIndex,
ArrayRef<AffineMap> reassociationMaps,
- ArrayRef<int64_t> expandedShape) {
+ ArrayRef<int64_t> expandedShape,
+ PatternRewriter &rewriter) {
if (reassociationMaps.empty())
return failure();
AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
@@ -586,13 +514,13 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
Optional<SmallVector<int64_t, 4>> originalLoopRange =
linalgOp.getStaticLoopRanges();
if (!originalLoopRange)
- return linalgOp.emitError("unable to find loop range for operation");
+ return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
reassociation.clear();
expandedShapeMap.clear();
// Compute the number of dimension in the expanded op that correspond to each
// dimension of the original op.
- SmallVector<unsigned, 4> numExpandedDims(fusedIndexMap.getNumDims(), 1);
+ SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
expandedShapeMap.resize(fusedIndexMap.getNumDims());
for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
@@ -627,17 +555,19 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
/// Note that this could be extended to handle dynamic case, but the
/// implementation below uses `affine.apply` which seems to have issues when the
/// shapes are not static.
-LogicalResult isIndexedOpExpandable(LinalgOp linalgOp,
- const ExpansionInfo &expansionInfo) {
+LogicalResult isGenericOpExpandable(GenericOp genericOp,
+ const ExpansionInfo &expansionInfo,
+ PatternRewriter &rewriter) {
+ if (!genericOp.hasIndexSemantics())
+ return success();
for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
if (expandedShape.size() == 1)
continue;
for (int64_t shape : expandedShape.drop_front()) {
if (ShapedType::isDynamic(shape)) {
- return linalgOp.emitError(
- "unable to fuse indexed generic op where the expanded dim is "
- "dynamic");
+ return rewriter.notifyMatchFailure(
+ genericOp, "cannot expand due to index semantics and dynamic dims");
}
}
}
@@ -649,7 +579,7 @@ LogicalResult isIndexedOpExpandable(LinalgOp linalgOp,
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
const ExpansionInfo &expansionInfo) {
- SmallVector<AffineExpr, 4> newExprs;
+ SmallVector<AffineExpr> newExprs;
for (AffineExpr expr : indexingMap.getResults()) {
unsigned pos = expr.cast<AffineDimExpr>().getPosition();
SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
@@ -668,7 +598,7 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
static RankedTensorType getExpandedType(RankedTensorType originalType,
AffineMap indexingMap,
const ExpansionInfo &expansionInfo) {
- SmallVector<int64_t, 4> expandedShape;
+ SmallVector<int64_t> expandedShape;
for (AffineExpr expr : indexingMap.getResults()) {
unsigned dim = expr.cast<AffineDimExpr>().getPosition();
auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
@@ -682,15 +612,15 @@ static RankedTensorType getExpandedType(RankedTensorType originalType,
/// the expanded operation. The same method is used to compute the
/// `linalg.tensor_reshape` used to collapse the result of the expanded op to
/// get the value that can replace all uses of the results of the original op.
-static SmallVector<ReassociationIndices, 4>
+static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
const ExpansionInfo &expansionInfo) {
- SmallVector<ReassociationIndices, 4> reassociation;
+ SmallVector<ReassociationIndices> reassociation;
unsigned numReshapeDims = 0;
for (AffineExpr expr : indexingMap.getResults()) {
unsigned dim = expr.cast<AffineDimExpr>().getPosition();
auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
- auto indices = llvm::to_vector<2>(
+ SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
reassociation.emplace_back(std::move(indices));
numReshapeDims += numExpandedDims;
@@ -698,66 +628,14 @@ getReassociationForExpansion(AffineMap indexingMap,
return reassociation;
}
-/// Build the body of the expanded IndexedGenericOp. The arguments for the
-/// induction variables of the original operation need to be recovered by
-/// linearizing the arguments of the corresponding dimensions of the expanded
-/// op. For now it is assumed that the shapes of the expanded op needed for
-/// linearization are static.
-static void buildExpandedIndexedGenericOpRegion(
- PatternRewriter &rewriter, Location loc, Region &originalOpRegion,
- Region &fusedOpRegion, const ExpansionInfo &expansionInfo) {
- assert(fusedOpRegion.empty() && "expected fused op to have empty region");
- // Create an entry block in the fused region with same number of arguments
- // as the fused op
- Block *fusedEntryBlock = new Block;
- fusedOpRegion.push_back(fusedEntryBlock);
- rewriter.cloneRegionBefore(originalOpRegion, fusedOpRegion,
- fusedOpRegion.end());
-
- // Merge the entry block of the fused op with the cloned blocks. For this
- // compute the value for arguments of the region in the original operation
- // in terms of the arguments of the fused op. Since the original operation
- // is expanded, the expanded dimensions need to be folded back to get the
- // replacement value for the arguments corresponding to interation index.
- // For now this expects that all the loop ranges are constants, which is
- // true if the shapes are all static. This has already been checked in the
- // precondition.
- using namespace edsc::op;
- using namespace edsc::intrinsics;
- OpBuilder::InsertionGuard guard(rewriter);
- SmallVector<Value, 4> argReplacements(originalOpRegion.getNumArguments());
- rewriter.setInsertionPointToStart(fusedEntryBlock);
- edsc::ScopedContext scopedContext(rewriter, loc);
- IndexType indexType = rewriter.getIndexType();
- for (auto i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
- Value linearizedIndex = fusedEntryBlock->addArgument(indexType);
- ArrayRef<int64_t> expandedDimsShape =
- expansionInfo.getExpandedShapeOfDim(i).drop_front();
- for (unsigned shape : expandedDimsShape) {
- assert(!ShapedType::isDynamic(shape));
- linearizedIndex = linearizedIndex * std_constant_index(shape);
- linearizedIndex =
- linearizedIndex + fusedEntryBlock->addArgument(indexType);
- }
- argReplacements[i] = linearizedIndex;
- }
- for (auto i : llvm::seq<unsigned>(expansionInfo.getOrigOpNumDims(),
- argReplacements.size())) {
- argReplacements[i] =
- fusedEntryBlock->addArgument(originalOpRegion.getArgument(i).getType());
- }
- rewriter.mergeBlocks(fusedEntryBlock->getNextNode(), fusedEntryBlock,
- argReplacements);
-}
-
/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now it
/// is assumed that the shapes of the expanded operation needed for
/// linearization are static.
-static void updateExpandedIndexOpRegion(PatternRewriter &rewriter, Location loc,
- Region &fusedRegion,
- const ExpansionInfo &expansionInfo) {
+static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
+ Location loc, Region &fusedRegion,
+ const ExpansionInfo &expansionInfo) {
// Replace the original indices by the linearization of the expanded indices.
for (IndexOp indexOp :
llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
@@ -793,112 +671,100 @@ static void updateExpandedIndexOpRegion(PatternRewriter &rewriter, Location loc,
}
}
-/// Implements the fusion of a tensor_reshape op and a generic/indexed_generic
-/// op as explained in `isFusableWithReshapeByExpansion`. Assumes that those
-/// conditions have been satisfied.
-static Optional<SmallVector<Value, 1>>
-fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
+/// Implements the fusion of a tensor_reshape op and a generic op as explained
+/// in `isFusableWithReshapeByExpansion`. Assumes that those conditions have
+/// been satisfied.
+static Optional<SmallVector<Value>>
+fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp,
unsigned fusedTensorIndex,
PatternRewriter &rewriter) {
- assert(isFusableWithReshapeByDimExpansion(linalgOp, fusedTensorIndex) &&
+ assert(isFusableWithReshapeByDimExpansion(genericOp, fusedTensorIndex) &&
"preconditions for fuse operation failed");
// Check if reshape is expanding or collapsing.
bool isExpanding =
reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
RankedTensorType expandedType =
isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
- bool hasIndexSemantics = linalgOp.hasIndexSemantics() ||
- isa<IndexedGenericOp>(linalgOp.getOperation());
ExpansionInfo expansionInfo;
- if (failed(expansionInfo.compute(linalgOp, fusedTensorIndex,
+ if (failed(expansionInfo.compute(genericOp, fusedTensorIndex,
reshapeOp.getReassociationMaps(),
- expandedType.getShape())))
+ expandedType.getShape(), rewriter)))
return llvm::None;
- if (hasIndexSemantics &&
- failed(isIndexedOpExpandable(linalgOp, expansionInfo)))
+ if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
return llvm::None;
SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
- llvm::map_range(linalgOp.getIndexingMaps(), [&](AffineMap m) {
+ llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
}));
- SmallVector<Value, 4> expandedOpOperands;
- for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
+ SmallVector<Value> expandedOpOperands;
+ for (auto operand : llvm::enumerate(genericOp.getInputs())) {
if (operand.index() == fusedTensorIndex) {
expandedOpOperands.push_back(reshapeOp.src());
continue;
}
- AffineMap indexingMap = linalgOp.getInputIndexingMap(operand.index());
+ AffineMap indexingMap = genericOp.getInputIndexingMap(operand.index());
RankedTensorType expandedOperandType =
getExpandedType(operand.value().getType().cast<RankedTensorType>(),
indexingMap, expansionInfo);
if (expandedOperandType != operand.value().getType()) {
// Reshape the operand to get the right type.
- SmallVector<ReassociationIndices, 4> reassociation =
+ SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
expandedOpOperands.push_back(rewriter.create<TensorReshapeOp>(
- linalgOp.getLoc(), expandedOperandType, operand.value(),
+ genericOp.getLoc(), expandedOperandType, operand.value(),
reassociation));
continue;
}
expandedOpOperands.push_back(operand.value());
}
- Location loc = linalgOp.getLoc();
- SmallVector<Value, 1> outputs;
- for (auto result : llvm::enumerate(linalgOp.getOutputs())) {
- AffineMap indexingMap = linalgOp.getOutputIndexingMap(result.index());
+ Location loc = genericOp.getLoc();
+ SmallVector<Value> outputs;
+ for (auto result : llvm::enumerate(genericOp.getOutputs())) {
+ AffineMap indexingMap = genericOp.getOutputIndexingMap(result.index());
RankedTensorType expandedOutputType =
getExpandedType(result.value().getType().cast<RankedTensorType>(),
indexingMap, expansionInfo);
if (expandedOutputType != result.value().getType()) {
- SmallVector<ReassociationIndices, 4> reassociation =
+ SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
outputs.push_back(rewriter.create<TensorReshapeOp>(
- linalgOp.getLoc(), expandedOutputType, result.value(),
+ genericOp.getLoc(), expandedOutputType, result.value(),
reassociation));
}
}
// The iterator types of the expanded op are all parallel.
- SmallVector<StringRef, 4> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
- getParallelIteratorTypeName());
+ SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
+ getParallelIteratorTypeName());
TypeRange resultTypes = ValueRange(outputs).getTypes();
- LinalgOp fusedOp = createLinalgOpOfSameType(
- linalgOp, rewriter, linalgOp.getLoc(), resultTypes,
- /*inputs=*/expandedOpOperands, outputs, expandedOpIndexingMaps,
- iteratorTypes);
+ auto fusedOp =
+ rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
+ /*inputs=*/expandedOpOperands, outputs,
+ expandedOpIndexingMaps, iteratorTypes);
Region &fusedRegion = fusedOp->getRegion(0);
- Region &originalRegion = linalgOp->getRegion(0);
-
- if (isa<GenericOp>(linalgOp.getOperation())) {
- rewriter.cloneRegionBefore(originalRegion, fusedRegion,
- fusedRegion.begin());
- } else {
- assert(isa<IndexedGenericOp>(linalgOp.getOperation()));
- buildExpandedIndexedGenericOpRegion(rewriter, loc, originalRegion,
- fusedRegion, expansionInfo);
- }
+ Region &originalRegion = genericOp->getRegion(0);
+ rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
// Update the index accesses after the expansion.
- if (linalgOp.hasIndexSemantics())
- updateExpandedIndexOpRegion(rewriter, loc, fusedRegion, expansionInfo);
+ updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
// Reshape the result values to their original shape if this is a collapsing
// reshape folded into its consumer.
- SmallVector<Value, 1> resultVals;
- for (auto result : llvm::enumerate(linalgOp->getResults())) {
+ SmallVector<Value> resultVals;
+ for (auto result : llvm::enumerate(genericOp->getResults())) {
if (!isExpanding &&
resultTypes[result.index()] != result.value().getType()) {
- SmallVector<ReassociationIndices, 4> reassociation =
+ SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(
- linalgOp.getOutputIndexingMap(result.index()), expansionInfo);
+ genericOp.getOutputIndexingMap(result.index()), expansionInfo);
resultVals.push_back(rewriter.create<TensorReshapeOp>(
- linalgOp.getLoc(), result.value().getType(),
+ genericOp.getLoc(), result.value().getType(),
fusedOp->getResult(result.index()), reassociation));
} else {
resultVals.push_back(fusedOp->getResult(result.index()));
@@ -934,22 +800,21 @@ namespace {
/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
/// ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
/// -> tensor<?x?x4x?xf32>
-template <typename LinalgOpTy, bool foldUnitDimReshapesOnly>
+template <bool foldUnitDimReshapesOnly>
struct FoldProducerReshapeOpByLinearization
- : public OpRewritePattern<LinalgOpTy> {
- using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
+ : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(LinalgOpTy op,
+ LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- if (!op.hasTensorSemantics())
+ if (!genericOp.hasTensorSemantics())
return failure();
- LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
- for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
+ for (auto operand : llvm::enumerate(genericOp.getInputs())) {
TensorReshapeOp reshapeOp =
operand.value().getDefiningOp<TensorReshapeOp>();
if (!reshapeOp ||
!isTensorReshapeOpFoldableByLinearization(
- reshapeOp, linalgOp.getInputIndexingMap(operand.index()),
+ reshapeOp, genericOp.getInputIndexingMap(operand.index()),
/*asProducer =*/true) ||
(foldUnitDimReshapesOnly &&
!isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
@@ -957,15 +822,15 @@ struct FoldProducerReshapeOpByLinearization
continue;
// Compute the fused operands list,
- SmallVector<Value, 2> fusedOperands(linalgOp.getInputs());
+ SmallVector<Value> fusedOperands(genericOp.getInputs());
fusedOperands[operand.index()] = reshapeOp.src();
- fusedOperands.append(linalgOp.getOutputs().begin(),
- linalgOp.getOutputs().end());
+ fusedOperands.append(genericOp.getOutputs().begin(),
+ genericOp.getOutputs().end());
// Compute indexing_maps for the fused operation. The indexing_maps for
// the operands of the consumers that aren't fused are the same.
SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
- op.indexing_maps().template getAsValueRange<AffineMapAttr>());
+ genericOp.indexing_maps().template getAsValueRange<AffineMapAttr>());
// Accepted consumer maps are either identity or permutation.
auto invMap = inversePermutation(fusedIndexMaps[operand.index()]);
@@ -984,13 +849,14 @@ struct FoldProducerReshapeOpByLinearization
// inverted. Without this the resultant op is not legal.
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
return rewriter.notifyMatchFailure(
- op, "fused op loop bound computation failed");
+ genericOp, "fused op loop bound computation failed");
}
- rewriter.startRootUpdate(op);
- op->setOperands(fusedOperands);
- op.indexing_mapsAttr(rewriter.getAffineMapArrayAttr(fusedIndexMaps));
- rewriter.finalizeRootUpdate(op);
+ rewriter.startRootUpdate(genericOp);
+ genericOp->setOperands(fusedOperands);
+ genericOp.indexing_mapsAttr(
+ rewriter.getAffineMapArrayAttr(fusedIndexMaps));
+ rewriter.finalizeRootUpdate(genericOp);
return success();
}
return failure();
@@ -1013,7 +879,7 @@ getReassociationIndices(ArrayRef<AffineMap> maps) {
/// Pattern to move rank reducing reshape after an elementwise linalg generic
/// op. This is useful to expose more fusion opportunities between named ops and
-/// generic op. This can only be done if there is no broadcast or permuation
+/// generic ops. This can only be done if there is no broadcast or permutation
/// within the dimensions we need to merge.
///
/// For example,
@@ -1040,27 +906,27 @@ getReassociationIndices(ArrayRef<AffineMap> maps) {
/// %3 = linalg.tensor_reshape %2 [
/// #affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
/// : tensor<12544x16xf32> into tensor<112x112x16xf32>
-template <typename GenericOpTy>
-struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
- using OpRewritePattern<GenericOpTy>::OpRewritePattern;
+struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(GenericOpTy op,
+ LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
// Only apply to elementwise linalg on tensor.
- if (!op.hasTensorSemantics() ||
- op.getNumParallelLoops() != op.getNumLoops())
+ if (!genericOp.hasTensorSemantics() ||
+ genericOp.getNumParallelLoops() != genericOp.getNumLoops())
return failure();
// Only support identity output maps. It could be extended to permutations if
// needed.
- if (llvm::any_of(op.getOutputIndexingMaps(),
+ if (llvm::any_of(genericOp.getOutputIndexingMaps(),
[](AffineMap map) { return !map.isIdentity(); }))
return failure();
- int64_t destRank = op.getNumParallelLoops();
- SmallVector<Value, 4> newOperands = llvm::to_vector<4>(op.getInputs());
+ int64_t destRank = genericOp.getNumParallelLoops();
+ SmallVector<Value, 4> newOperands =
+ llvm::to_vector<4>(genericOp.getInputs());
TensorReshapeOp reshapeFound;
// 1. Look for tensor_reshape operands and record which dimensions get merged.
- for (auto operand : llvm::enumerate(op.getInputs())) {
+ for (auto operand : llvm::enumerate(genericOp.getInputs())) {
TensorReshapeOp reshapeOp =
operand.value().template getDefiningOp<TensorReshapeOp>();
if (!reshapeOp || reshapeOp.getSrcType().getRank() >
@@ -1069,7 +935,7 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
}
// TODO: We could support non-identity map as long as the merged
// dimensions are still contiguous.
- if (!op.getIndexingMaps()[operand.index()].isIdentity())
+ if (!genericOp.getIndexingMaps()[operand.index()].isIdentity())
continue;
if (reshapeFound) {
// Only support a second reshape op if it has the same reassociate maps.
@@ -1087,7 +953,7 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
// Calculate the reassociation indices and the reassociated reverse map.
SmallVector<ReassociationIndices> reassociation =
getReassociationIndices(reshapeFound.getReassociationMaps());
- SmallVector<unsigned, 4> remap(destRank);
+ SmallVector<unsigned> remap(destRank);
for (auto &indices : llvm::enumerate(reassociation)) {
for (int64_t index : indices.value()) {
remap[index] = indices.index();
@@ -1096,9 +962,9 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
// 2. Verify that we can merge the dimensions in the linalg op and that we
// don't need to create new reshape operands. Inserting new reshape
// operands would defeat the purpose of the transformation.
- for (auto operand : llvm::enumerate(op.getInputs())) {
+ for (auto operand : llvm::enumerate(genericOp.getInputs())) {
if (operand.value() == newOperands[operand.index()]) {
- AffineMap map = op.getIndexingMaps()[operand.index()];
+ AffineMap map = genericOp.getIndexingMaps()[operand.index()];
for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
return failure();
@@ -1108,70 +974,69 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
// 3. Calculate the affine map remapping and the reassociation to apply to
// output tensors.
- SmallVector<AffineMap, 4> newMaps;
+ SmallVector<AffineMap> newMaps;
unsigned newRank = reassociation.size();
- for (auto map : op.getIndexingMaps()) {
+ for (auto map : genericOp.getIndexingMaps()) {
SmallVector<AffineExpr> newExprs;
for (auto expr : map.getResults()) {
unsigned position = expr.template cast<AffineDimExpr>().getPosition();
// Skip dimension merged except for the last of the group.
if (reassociation[remap[position]].back() == position) {
newExprs.push_back(
- getAffineDimExpr(remap[position], op.getContext()));
+ getAffineDimExpr(remap[position], genericOp.getContext()));
}
}
- newMaps.push_back(AffineMap::get(newRank, 0, newExprs, op.getContext()));
+ newMaps.push_back(
+ AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
}
// 4. Reshape the output tensors.
SmallVector<Value> newOutputs;
SmallVector<Type> newOutputTypes;
- for (auto output : op.outputs()) {
+ for (auto output : genericOp.outputs()) {
auto newOutputType = RankedTensorType::get(
reshapeFound.getSrcType().getShape(),
output.getType().template cast<RankedTensorType>().getElementType());
Value newOutput = rewriter.create<TensorReshapeOp>(
- op->getLoc(), newOutputType, output, reassociation);
+ genericOp->getLoc(), newOutputType, output, reassociation);
newOutputTypes.push_back(newOutputType);
newOutputs.push_back(newOutput);
}
// 5. Create a new generic op with lower rank.
- SmallVector<StringRef, 4> iteratorTypes(newRank,
- getParallelIteratorTypeName());
- auto newOp =
- rewriter.create<GenericOpTy>(op->getLoc(), newOutputTypes, newOperands,
- newOutputs, newMaps, iteratorTypes);
- rewriter.inlineRegionBefore(op.region(), newOp.region(),
+ SmallVector<StringRef> iteratorTypes(newRank,
+ getParallelIteratorTypeName());
+ auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
+ newOperands, newOutputs, newMaps,
+ iteratorTypes);
+ rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
newOp.region().begin());
// 6. Reshape the results so that their types match the uses.
SmallVector<Value> newResults;
for (auto result : llvm::enumerate(newOp->getResults())) {
newResults.push_back(rewriter.create<TensorReshapeOp>(
- op->getLoc(), op.getOutputTensorTypes()[result.index()],
+ genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
result.value(), reassociation));
}
- rewriter.replaceOp(op, newResults);
+ rewriter.replaceOp(genericOp, newResults);
return success();
}
};
-/// Pattern to fuse a tensor_reshape op with its consumer
-/// generic/indexed_generic op, when the reshape op is collapsing
-/// dimensions. The dimensionality of the loop in the consumer is expanded.
-template <typename GenericOpTy>
+/// Pattern to fuse a tensor_reshape op with its consumer generic op, when the
+/// reshape op is collapsing dimensions. The dimensionality of the loop in the
+/// consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
- : public OpRewritePattern<GenericOpTy> {
+ : public OpRewritePattern<GenericOp> {
public:
FoldWithProducerReshapeOpByExpansion(
MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
PatternBenefit benefit = 1)
- : OpRewritePattern<GenericOpTy>(context, benefit),
+ : OpRewritePattern<GenericOp>(context, benefit),
controlFoldingReshapes(foldReshapes) {}
- LogicalResult matchAndRewrite(GenericOpTy genericOp,
+ LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- LinalgOp linalgOp = cast<LinalgOp>(genericOp.getOperation());
- for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
+ for (auto operand : llvm::enumerate(genericOp.getInputs())) {
TensorReshapeOp reshapeOp =
operand.value().getDefiningOp<TensorReshapeOp>();
if (!reshapeOp)
@@ -1181,14 +1046,14 @@ class FoldWithProducerReshapeOpByExpansion
// - All constraints of fusing with reshape by expansion are met.
if (reshapeOp.getSrcType().getRank() <
reshapeOp.getResultType().getRank() ||
- !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
+ !isFusableWithReshapeByDimExpansion(genericOp, operand.index()) ||
(!controlFoldingReshapes(
reshapeOp->getResult(0),
- linalgOp.getInputOpOperands()[operand.index()])))
+ genericOp.getInputOpOperands()[operand.index()])))
continue;
- Optional<SmallVector<Value, 1>> replacementValues =
- fuseWithReshapeByExpansion(linalgOp, reshapeOp, operand.index(),
+ Optional<SmallVector<Value>> replacementValues =
+ fuseWithReshapeByExpansion(genericOp, reshapeOp, operand.index(),
rewriter);
if (!replacementValues)
return failure();
@@ -1211,10 +1076,9 @@ struct FoldConsumerReshapeOpByLinearization
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
PatternRewriter &rewriter) const override {
- LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
- if (!producer ||
- !isa<GenericOp, IndexedGenericOp>(producer.getOperation()) ||
- !producer.hasTensorSemantics() || producer.getNumOutputs() != 1 ||
+ GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
+ if (!producer || !producer.hasTensorSemantics() ||
+ producer.getNumOutputs() != 1 ||
!isTensorReshapeOpFoldableByLinearization(
reshapeOp, producer.getOutputIndexingMap(0),
/*asProducer =*/false) ||
@@ -1251,8 +1115,8 @@ struct FoldConsumerReshapeOpByLinearization
Location loc = producer.getLoc();
Value output = rewriter.create<TensorReshapeOp>(
loc, producer.getOutputs()[0], reshapeOp.getReassociationExprs());
- LinalgOp fusedOp = createLinalgOpOfSameType(
- producer, rewriter, loc, reshapeOp.getResultType(),
+ auto fusedOp = rewriter.create<GenericOp>(
+ loc, reshapeOp.getResultType(),
/*inputs=*/producer.getInputs(),
// TODO: handle outputs.
/*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
@@ -1280,16 +1144,15 @@ struct FoldReshapeWithGenericOpByExpansion
// - All constraints of fusing with reshape by expansion are met.
if (reshapeOp.getSrcType().getRank() > reshapeOp.getResultType().getRank())
return failure();
- LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
+ GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
if (!producer || producer.getNumOutputs() != 1 ||
!isFusableWithReshapeByDimExpansion(producer,
producer.getNumInputs()) ||
isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
reshapeOp.getReassociationMaps()))
return failure();
- Optional<SmallVector<Value, 1>> replacementValues =
- fuseWithReshapeByExpansion(producer, reshapeOp, producer.getNumInputs(),
- rewriter);
+ Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
+ producer, reshapeOp, producer.getNumInputs(), rewriter);
if (!replacementValues)
return failure();
rewriter.replaceOp(reshapeOp, replacementValues.getValue());
@@ -1297,20 +1160,18 @@ struct FoldReshapeWithGenericOpByExpansion
}
};
-/// Pattern to fold a GenericOp/IndexedGenericOp with a splat constant.
-template <typename LinalgOpTy>
-class FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
+/// Pattern to fold a generic op with a splat constant.
+class FoldSplatConstants : public OpRewritePattern<GenericOp> {
public:
FoldSplatConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
PatternBenefit benefit = 1)
- : OpRewritePattern<LinalgOpTy>(context, benefit), controlFn(fun) {}
+ : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
- LogicalResult matchAndRewrite(LinalgOpTy op,
+ LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- if (!op.hasTensorSemantics())
+ if (!genericOp.hasTensorSemantics())
return failure();
- LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
- for (auto operand : llvm::enumerate(linalgOp.getInputOpOperands())) {
+ for (auto operand : llvm::enumerate(genericOp.getInputOpOperands())) {
Operation *def = operand.value().get().getDefiningOp();
DenseElementsAttr constantAttr;
if (!def ||
@@ -1320,49 +1181,46 @@ class FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
continue;
// The indexing_maps for the operands of the fused operation are same as
- // those for the operands of the linalgOp without the indexing map at
+ // those for the operands of the genericOp without the indexing map at
// operand.index()
SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
- linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
+ genericOp.indexing_maps().getAsValueRange<AffineMapAttr>());
fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index()));
// Check if the operation shapes to loops map is computable.
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
return rewriter.notifyMatchFailure(
- linalgOp, "fused op loop bound computation failed");
+ genericOp, "fused op loop bound computation failed");
}
- // The operands list is same as the linalgOp with the argument for
+ // The operands list is same as the genericOp with the argument for
// constant index dropped.
- SmallVector<Value, 4> fusedOperands(linalgOp.getInputs());
+ SmallVector<Value> fusedOperands(genericOp.getInputs());
fusedOperands.erase(std::next(fusedOperands.begin(), operand.index()));
// Create a constant scalar value from the splat constant.
Value scalarConstant = rewriter.create<ConstantOp>(
def->getLoc(), constantAttr.getSplatValue());
- LinalgOp fusedOp = createLinalgOpOfSameType(
- linalgOp, rewriter, rewriter.getUnknownLoc(),
- linalgOp->getResultTypes(),
+ auto fusedOp = rewriter.create<GenericOp>(
+ rewriter.getUnknownLoc(), genericOp->getResultTypes(),
/*inputs=*/fusedOperands,
- /*outputs=*/linalgOp.getOutputs(),
+ /*outputs=*/genericOp.getOutputs(),
rewriter.getAffineMapArrayAttr(fusedIndexMaps),
- linalgOp.iterator_types(),
+ genericOp.iterator_types(),
/*doc=*/nullptr,
/*library_call=*/nullptr);
// Map the block argument corresponding to the replaced argument with the
// scalar constant.
- Region &linalgOpRegion = linalgOp->getRegion(0);
- Block &entryBlock = *linalgOpRegion.begin();
- unsigned argIndex = entryBlock.getNumArguments() -
- linalgOp.getNumShapedOperands() + operand.index();
+ Region ®ion = genericOp->getRegion(0);
+ Block &entryBlock = *region.begin();
BlockAndValueMapping mapping;
- mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
+ mapping.map(entryBlock.getArgument(operand.index()), scalarConstant);
Region &fusedRegion = fusedOp->getRegion(0);
- rewriter.cloneRegionBefore(linalgOpRegion, fusedRegion,
- fusedRegion.begin(), mapping);
- rewriter.replaceOp(linalgOp, fusedOp->getResults());
+ rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
+ mapping);
+ rewriter.replaceOp(genericOp, fusedOp->getResults());
return success();
}
return failure();
@@ -1373,20 +1231,15 @@ class FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
};
} // namespace
-static Optional<SmallVector<Value, 1>>
+static Optional<SmallVector<Value>>
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand,
+ GenericOp producer,
const ControlElementwiseOpsFusionFn &controlFn) {
- Operation *producer = consumerOpOperand.get().getDefiningOp();
- if (!producer || producer->getNumResults() != 1)
- return llvm::None;
-
- // Fuse when consumer is GenericOp or IndexedGenericOp.
- if (!isa<GenericOp, IndexedGenericOp>(consumerOpOperand.getOwner()) ||
- !isa<GenericOp, IndexedGenericOp>(producer))
+ if (producer->getNumResults() != 1)
return llvm::None;
- return fuseElementwiseOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
- controlFn, rewriter);
+ return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
+ rewriter);
}
bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
@@ -1398,25 +1251,24 @@ bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
-template <typename LinalgOpTy>
-class FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
+class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
PatternBenefit benefit = 1)
- : OpRewritePattern<LinalgOpTy>(context, benefit), controlFn(fun) {}
+ : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
- LogicalResult matchAndRewrite(LinalgOpTy op,
+ LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
- for (OpOperand &opOperand : op.getShapedOpOperands()) {
- LinalgOp producerOp =
- dyn_cast_or_null<LinalgOp>(opOperand.get().getDefiningOp());
- if (!producerOp || !producerOp.hasTensorSemantics())
+ for (OpOperand &opOperand : genericOp.getShapedOpOperands()) {
+ auto producer =
+ dyn_cast_or_null<GenericOp>(opOperand.get().getDefiningOp());
+ if (!producer || !producer.hasTensorSemantics())
continue;
- Optional<SmallVector<Value, 1>> fusedOpResults =
- fuseElementwiseOps(rewriter, opOperand, controlFn);
+ Optional<SmallVector<Value>> fusedOpResults =
+ fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
if (fusedOpResults) {
- rewriter.replaceOp(op, *fusedOpResults);
+ rewriter.replaceOp(genericOp, *fusedOpResults);
return success();
}
}
@@ -1445,8 +1297,7 @@ struct FusionOfTensorOpsPass
}
};
-/// Pass to test folding of reshape op with generic/indexed_generic ops by
-/// linearization.
+/// Pass to test folding of reshape ops with generic ops by linearization.
struct FoldReshapeOpsByLinearizationPass
: public LinalgFoldReshapeOpsByLinearizationBase<
FoldReshapeOpsByLinearizationPass> {
@@ -1462,16 +1313,14 @@ struct FoldReshapeOpsByLinearizationPass
void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns) {
- patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, false>,
- FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
+ patterns.add<FoldProducerReshapeOpByLinearization<false>,
FoldConsumerReshapeOpByLinearization<false>>(
patterns.getContext());
}
void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns) {
- patterns.add<FoldProducerReshapeOpByLinearization<GenericOp, true>,
- FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
+ patterns.add<FoldProducerReshapeOpByLinearization<true>,
FoldConsumerReshapeOpByLinearization<true>>(
patterns.getContext());
}
@@ -1480,18 +1329,15 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
RewritePatternSet &patterns,
ControlElementwiseOpsFusionFn controlFoldingReshapes) {
patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
- patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
- FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
- patterns.getContext(), controlFoldingReshapes);
+ patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
+ controlFoldingReshapes);
}
void mlir::linalg::populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
auto *context = patterns.getContext();
- patterns
- .add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
- FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
- context, options.controlElementwiseOpsFusionFn);
+ patterns.add<FuseElementwiseOps, FoldSplatConstants>(
+ context, options.controlElementwiseOpsFusionFn);
populateFoldReshapeOpsByExpansionPatterns(patterns,
options.controlFoldingReshapesFn);
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
@@ -1502,8 +1348,7 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
auto *context = patterns.getContext();
- patterns.add<PushExpandingReshape<GenericOp>,
- PushExpandingReshape<IndexedGenericOp>>(context);
+ patterns.add<PushExpandingReshape>(context);
}
std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 7b43c0ffbd5b0..36a1e45839ec5 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -205,39 +205,6 @@ func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
// -----
-#map0 = affine_map<(d0, d1, d2) -> (d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func @indexed_generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>)
- -> tensor<5x?x?xf32>
-{
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %c2 = constant 2 : index
- %cst = constant dense<42.0> : tensor<5xf32>
- %0 = memref.dim %arg0, %c1 : tensor<5x?x?xf32>
- %1 = memref.dim %arg0, %c2 : tensor<5x?x?xf32>
- %2 = linalg.init_tensor [5, %0, %1] : tensor<5x?x?xf32>
- %3 = linalg.indexed_generic {
- indexing_maps = [#map0, #map1, #map1],
- iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%cst, %arg0 : tensor<5xf32>, tensor<5x?x?xf32>)
- outs(%2 : tensor<5x?x?xf32>) {
- ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: f32, %arg5 : f32, %arg6 : f32):
- %4 = mulf %arg4, %arg5 : f32
- linalg.yield %4 : f32
- } -> tensor<5x?x?xf32>
- return %3 : tensor<5x?x?xf32>
-}
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: func @indexed_generic_op_constant_fusion
-// CHECK: %[[CST:.*]] = constant {{.*}} : f32
-// CHECK: linalg.generic
-// CHECK: ^{{[a-zA-Z0-9_]*}}
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]*]]: f32, %{{.*}}: f32)
-// CHECK: mulf %[[CST]], %[[ARG4]]
-
-// -----
-
#map0 = affine_map<(d0, d1, d2) -> ()>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>)
@@ -270,89 +237,6 @@ func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>)
// -----
-#map0 = affine_map<(d0, d1, d2) -> ()>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func @indexed_generic_op_zero_dim_constant_fusion
- (%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
-{
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %c2 = constant 2 : index
- %cst = constant dense<42.0> : tensor<f32>
- %0 = memref.dim %arg0, %c1 : tensor<5x?x?xf32>
- %1 = memref.dim %arg0, %c2 : tensor<5x?x?xf32>
- %2 = linalg.init_tensor [5, %0, %1] : tensor<5x?x?xf32>
- %3 = linalg.indexed_generic {
- indexing_maps = [#map0, #map1, #map1],
- iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%cst, %arg0 : tensor<f32>, tensor<5x?x?xf32>)
- outs(%2 : tensor<5x?x?xf32>) {
- ^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4: f32, %arg5: f32, %arg6: f32):
- %4 = mulf %arg4, %arg5 : f32
- linalg.yield %4 : f32
- } -> tensor<5x?x?xf32>
- return %3 : tensor<5x?x?xf32>
-}
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: func @indexed_generic_op_zero_dim_constant_fusion
-// CHECK: %[[CST:.*]] = constant {{.*}} : f32
-// CHECK: linalg.generic
-// CHECK: ^{{[a-zA-Z0-9_]*}}
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]*]]: f32, %{{.*}}: f32)
-// CHECK: mulf %[[CST]], %[[ARG4]]
-
-// -----
-
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-func @generic_op_indexed_generic_op_fusion(%arg0: tensor<?x?xi32>,
- %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
- %1 = memref.dim %arg0, %c1 : tensor<?x?xi32>
- %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
- %3 = linalg.generic {
- indexing_maps = [#map0, #map0, #map0],
- iterator_types = ["parallel", "parallel"] }
- ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
- outs(%2 : tensor<?x?xi32>) {
- ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): // no predecessors
- %10 = addi %arg2, %arg3 : i32
- linalg.yield %10 : i32
- } -> tensor<?x?xi32>
- %4 = linalg.indexed_generic {
- indexing_maps = [#map0, #map0],
- iterator_types = ["parallel", "parallel"] }
- ins(%3 : tensor<?x?xi32>)
- outs(%2 : tensor<?x?xi32>) {
- ^bb0(%arg2: index, %arg3: index, %arg4: i32, %arg5: i32): // no predecessors
- %5 = index_cast %arg2 : index to i32
- %6 = index_cast %arg3 : index to i32
- %7 = addi %arg4, %5 : i32
- %8 = subi %7, %6 : i32
- linalg.yield %8 : i32
- } -> tensor<?x?xi32>
- return %4 : tensor<?x?xi32>
-}
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: func @generic_op_indexed_generic_op_fusion
-// CHECK-NOT: linalg.indexed_generic
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]]
-// CHECK: ^{{[a-zA-Z0-9_]*}}
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: i32
-// CHECK: %[[ARG0:.+]] = linalg.index 0 : index
-// CHECK: %[[ARG1:.+]] = linalg.index 1 : index
-// CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ARG3]] : i32
-// CHECK: %[[ADD_OPERAND:.+]] = index_cast %[[ARG0]] : index to i32
-// CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[ARG1]] : index to i32
-// CHECK: %[[VAL2:.+]] = addi %[[VAL1]], %[[ADD_OPERAND]] : i32
-// CHECK: %[[VAL3:.+]] = subi %[[VAL2]], %[[SUB_OPERAND]] : i32
-// CHECK: linalg.yield %[[VAL3]] : i32
-
-// -----
-
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func @producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>,
%arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
@@ -405,56 +289,6 @@ func @producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>,
// -----
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
- %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
- %1 = memref.dim %arg0, %c1 : tensor<?x?xi32>
- %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
- %3 = linalg.indexed_generic {
- indexing_maps = [#map0, #map0],
- iterator_types = ["parallel", "parallel"] }
- ins(%arg0 : tensor<?x?xi32>)
- outs(%2 : tensor<?x?xi32>) {
- ^bb0(%arg2: index, %arg3: index, %arg4: i32, %arg5: i32): // no predecessors
- %4 = index_cast %arg2 : index to i32
- %5 = index_cast %arg3 : index to i32
- %6 = addi %arg4, %4 : i32
- %7 = subi %6, %5 : i32
- linalg.yield %7 : i32
- } -> tensor<?x?xi32>
- %4 = linalg.generic {
- indexing_maps = [#map0, #map0, #map0],
- iterator_types = ["parallel", "parallel"] }
- ins(%3, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
- outs(%2 : tensor<?x?xi32>) {
- ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): // no predecessors
- %10 = addi %arg2, %arg3 : i32
- linalg.yield %10 : i32
- } -> tensor<?x?xi32>
- return %4 : tensor<?x?xi32>
-}
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: func @indexed_generic_op_generic_op_fusion
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]]
-// CHECK: ^{{[a-zA-Z0-9_]*}}
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: i32
-// CHECK: %[[ARG0:.+]] = linalg.index 0 : index
-// CHECK: %[[ARG1:.+]] = linalg.index 1 : index
-// CHECK: %[[ADD_OPERAND:.+]] = index_cast %[[ARG0]] : index to i32
-// CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[ARG1]] : index to i32
-// CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ADD_OPERAND]] : i32
-// CHECK: %[[VAL2:.+]] = subi %[[VAL1]], %[[SUB_OPERAND]] : i32
-// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ARG3]] : i32
-// CHECK: linalg.yield %[[VAL3]] : i32
-// CHECK-NOT: linalg.generic
-
-// -----
-
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
%c0 = constant 0 : index
@@ -506,63 +340,7 @@ func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32
// -----
-// The indices of the first indexed_generic op are swapped after fusion.
-#map0 = affine_map<(d0, d1) -> (d1, d0)>
-#map1 = affine_map<(d0, d1) -> (d0, d1)>
-func @indexed_generic_op_fusion(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
- %1 = memref.dim %arg0, %c1 : tensor<?x?xi32>
- %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
- %3 = linalg.indexed_generic {
- indexing_maps = [#map0, #map0],
- iterator_types = ["parallel", "parallel"] }
- ins(%arg0 : tensor<?x?xi32>)
- outs(%2 : tensor<?x?xi32>) {
- ^bb0(%arg2: index, %arg3: index, %arg4: i32, %arg5: i32): // no predecessors
- %4 = index_cast %arg2 : index to i32
- %5 = index_cast %arg3 : index to i32
- %6 = addi %arg4, %4 : i32
- %7 = subi %5, %6 : i32
- linalg.yield %7 : i32
- } -> tensor<?x?xi32>
- %4= linalg.indexed_generic {
- indexing_maps = [#map1, #map1],
- iterator_types = ["parallel", "parallel"] }
- ins(%3 : tensor<?x?xi32>)
- outs(%2 : tensor<?x?xi32>) {
- ^bb0(%arg2: index, %arg3: index, %arg4: i32, %arg5: i32): // no predecessors
- %5 = index_cast %arg2 : index to i32
- %6 = index_cast %arg3 : index to i32
- %7 = addi %arg4, %5 : i32
- %8 = subi %7, %6 : i32
- linalg.yield %8 : i32
- } -> tensor<?x?xi32>
- return %4 : tensor<?x?xi32>
-}
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: func @indexed_generic_op_fusion
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
-// CHECK: ^{{[a-zA-Z0-9_]*}}
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32
-// CHECK: %[[ARG0:.+]] = linalg.index 0 : index
-// CHECK: %[[ARG1:.+]] = linalg.index 1 : index
-// CHECK: %[[ADD_OPERAND1:.+]] = index_cast %[[ARG1]] : index to i32
-// CHECK: %[[SUB_OPERAND1:.+]] = index_cast %[[ARG0]] : index to i32
-// CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ADD_OPERAND1]] : i32
-// CHECK: %[[VAL2:.+]] = subi %[[SUB_OPERAND1]], %[[VAL1]] : i32
-// CHECK: %[[ADD_OPERAND2:.+]] = index_cast %[[ARG0]] : index to i32
-// CHECK: %[[SUB_OPERAND2:.+]] = index_cast %[[ARG1]] : index to i32
-// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ADD_OPERAND2]] : i32
-// CHECK: %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
-// CHECK: linalg.yield %[[VAL4]] : i32
-// CHECK-NOT: linalg.generic
-
-// -----
-
-// The indices of the first indexed_generic op are swapped after fusion.
+// The indices of the first generic op are swapped after fusion.
#map0 = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
func @indexed_producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>)
@@ -625,45 +403,52 @@ func @indexed_producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>)
// -----
-func @scalar_indexed_generic_fusion
- (%arg0: tensor<5x1x1xf32>, %arg1 : tensor<i32>) -> tensor<10xf32>
-{
+#map1 = affine_map<(d0) -> (d0)>
+#map2 = affine_map<(d0, d1) -> (d0, d1)>
+#map3 = affine_map<(d0, d1) -> (d1)>
+func @one_dim_indexed_producer_consumer_fusion(%arg0 : tensor<?xi32>,
+ %arg1 : tensor<?x?xi32>) -> tensor<?x?xi32> {
%c0 = constant 0 : index
- %cst = constant dense<1.000000e+00> : tensor<10xf32>
- %0 = linalg.init_tensor [] : tensor<f32>
- %1 = linalg.indexed_generic
- {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
- iterator_types = []}
- ins(%arg1 : tensor<i32>) outs(%0 : tensor<f32>) {
- ^bb0(%arg2: i32, %arg3: f32): // no predecessors
- %3 = index_cast %arg2 : i32 to index
- %4 = tensor.extract %arg0[%3, %c0, %c0] : tensor<5x1x1xf32>
- linalg.yield %4 : f32
- } -> tensor<f32>
- %2 = linalg.init_tensor [10] : tensor<10xf32>
- %3 = linalg.generic
- {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>,
- affine_map<(d0) -> (d0)>],
- iterator_types = ["parallel"]}
- ins(%1, %cst : tensor<f32>, tensor<10xf32>) outs(%2 : tensor<10xf32>) {
- ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors
- %4 = mulf %arg2, %arg3 : f32
- linalg.yield %4 : f32
- } -> tensor<10xf32>
- return %3 : tensor<10xf32>
+ %c1 = constant 1 : index
+ %d0 = memref.dim %arg0, %c0 : tensor<?xi32>
+ %0 = linalg.init_tensor [%d0] : tensor<?xi32>
+ %1 = linalg.generic
+ {indexing_maps = [#map1, #map1],
+ iterator_types = ["parallel"]}
+ ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xi32>) {
+ ^bb0(%arg2 : i32, %arg3 : i32):
+ %2 = linalg.index 0 : index
+ %3 = index_cast %2 : index to i32
+ %4 = addi %arg2, %3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<?xi32>
+ %2 = memref.dim %arg1, %c0 : tensor<?x?xi32>
+ %3 = memref.dim %arg1, %c1 : tensor<?x?xi32>
+ %4 = linalg.init_tensor [%2, %3] : tensor<?x?xi32>
+ %5 = linalg.generic
+ {indexing_maps = [#map2, #map3, #map2],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg1, %1 : tensor<?x?xi32>, tensor<?xi32>)
+ outs(%4 : tensor<?x?xi32>) {
+ ^bb0(%arg2 : i32, %arg3 : i32, %arg4: i32):
+ %6 = addi %arg2, %arg3 : i32
+ linalg.yield %6 : i32
+ } -> tensor<?x?xi32>
+ return %5 : tensor<?x?xi32>
}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
-// CHECK: func @scalar_indexed_generic_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<5x1x1xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<i32>
-// CHECK: %[[T0:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME: iterator_types = ["parallel"]
-// CHECK-SAME: ins(%[[ARG1]] : tensor<i32>)
-// CHECK: tensor.extract %[[ARG0]]
-// CHECK: linalg.yield
-// CHECK return %[[T0]]
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-LABEL: func @one_dim_indexed_producer_consumer_fusion
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]]
+// CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]*]]: i32, %[[ARG1:[a-zA-Z0-9_]*]]: i32
+// CHECK: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK: %[[VAL1:.+]] = index_cast %[[IDX1]] : index to i32
+// CHECK: %[[VAL2:.+]] = addi %[[ARG1]], %[[VAL1]] : i32
+// CHECK: %[[VAL3:.+]] = addi %[[ARG0]], %[[VAL2]] : i32
+// CHECK: linalg.yield %[[VAL3]] : i32
+// CHECK-NOT: linalg.generic
// -----
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index ff9aeeb986f65..9ff534c8b6549 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -164,56 +164,6 @@ func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
// -----
-#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
-func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
- %arg1 : tensor<?x?x?xi32>) ->
- tensor<?x?x?xi32>
-{
- %0 = linalg.tensor_reshape %arg0 [[0], [1, 2], [3]] :
- tensor<?x?x4x?xi32> into tensor<?x?x?xi32>
- %1 = linalg.indexed_generic {
- indexing_maps = [#map0, #map1, #map1],
- iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%0, %arg1 : tensor<?x?x?xi32>, tensor<?x?x?xi32>)
- outs(%0 : tensor<?x?x?xi32>) {
- ^bb0(%arg3 : index, %arg4 : index, %arg5 : index, %arg6: i32, %arg7: i32, %s: i32):
- %1 = muli %arg6, %arg7 : i32
- %2 = index_cast %arg3 : index to i32
- %3 = addi %1, %2 : i32
- %4 = index_cast %arg4 : index to i32
- %5 = addi %3, %4 : i32
- %6 = index_cast %arg5 : index to i32
- %7 = addi %5, %6 : i32
- linalg.yield %7 : i32
- } -> tensor<?x?x?xi32>
- return %1 : tensor<?x?x?xi32>
-}
-
-// The generic op version of the test check for the op structure. Only
-// checking the op body here.
-// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
-// CHECK: func @indexed_generic_op_reshape_producer_fusion
-// CHECK: linalg.generic
-// CHECK: ^{{.*}}(
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32, %[[ARG7:[a-zA-Z0-9]+]]: i32,
-// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32)
-// CHECK: %[[ARG2:.+]] = linalg.index 0 : index
-// CHECK: %[[ARG3:.+]] = linalg.index 1 : index
-// CHECK: %[[ARG4:.+]] = linalg.index 2 : index
-// CHECK: %[[ARG5:.+]] = linalg.index 3 : index
-// CHECK: %[[T3:.+]] = affine.apply #[[MAP]](%[[ARG3]], %[[ARG2]])
-// CHECK: %[[T4:.+]] = muli %[[ARG6]], %[[ARG7]]
-// CHECK: %[[T5:.+]] = index_cast %[[T3]]
-// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]]
-// CHECK: %[[T7:.+]] = index_cast %[[ARG4]]
-// CHECK: %[[T8:.+]] = addi %[[T6]], %[[T7]]
-// CHECK: %[[T9:.+]] = index_cast %[[ARG5]]
-// CHECK: %[[T10:.+]] = addi %[[T8]], %[[T9]]
-// CHECK: linalg.yield %[[T10]]
-
-// -----
-
#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
@@ -266,50 +216,6 @@ func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
// -----
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
- %arg1 : tensor<?x?xi32>) ->
- tensor<?x?x4x5xi32>
-{
- %0 = linalg.indexed_generic {
- indexing_maps = [#map0, #map0, #map0],
- iterator_types = ["parallel", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
- outs(%arg0 : tensor<?x?xi32>) {
- ^bb0(%arg3 : index, %arg4 : index, %arg5: i32, %arg6: i32, %s: i32): // no predecessors
- %1 = muli %arg5, %arg6 : i32
- %2 = index_cast %arg3 : index to i32
- %3 = addi %1, %2 : i32
- %4 = index_cast %arg4 : index to i32
- %5 = addi %3, %4 : i32
- linalg.yield %5 : i32
- } -> tensor<?x?xi32>
- %1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]] :
- tensor<?x?xi32> into tensor<?x?x4x5xi32>
- return %1 : tensor<?x?x4x5xi32>
-}
-// The generic op version of the test check for the op structure. Only
-// checking the op body here.
-// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 5 + d2 * 20)>
-// CHECK: func @indexed_generic_op_reshape_consumer_fusion
-// CHECK: linalg.generic
-// CHECK: ^{{.*}}(
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32, %[[ARG7:[a-zA-Z0-9]+]]: i32,
-// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32)
-// CHECK: %[[ARG2:.+]] = linalg.index 0 : index
-// CHECK: %[[ARG3:.+]] = linalg.index 1 : index
-// CHECK: %[[ARG4:.+]] = linalg.index 2 : index
-// CHECK: %[[ARG5:.+]] = linalg.index 3 : index
-// CHECK: %[[T3:.+]] = affine.apply #[[MAP]](%[[ARG5]], %[[ARG4]], %[[ARG3]])
-// CHECK: %[[T4:.+]] = muli %[[ARG6]], %[[ARG7]]
-// CHECK: %[[T5:.+]] = index_cast %[[ARG2]]
-// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]]
-// CHECK: %[[T7:.+]] = index_cast %[[T3]]
-// CHECK: %[[T8:.+]] = addi %[[T6]], %[[T7]]
-// CHECK: linalg.yield %[[T8]]
-
-// -----
-
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
%arg1 : tensor<?x?xi32>) ->
@@ -356,69 +262,6 @@ func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
// -----
-func @reshape_as_consumer_permutation
- (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
- -> tensor<2x3x4x5x6x7xi32> {
- %shape = linalg.init_tensor [6, 4, 210] : tensor<6x4x210xi32>
- %c = linalg.indexed_generic {
- indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
- affine_map<(d0, d1, d2) -> (d1, d2)>,
- affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
- iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%a, %b : tensor<210x6x4xi32>, tensor<210x4xi32>)
- outs(%shape : tensor<6x4x210xi32>) {
- ^bb0(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : i32, %arg4: i32, %s: i32):
- %1 = addi %arg3, %arg4 : i32
- %2 = index_cast %arg0 : index to i32
- %3 = addi %1, %2 : i32
- %4 = index_cast %arg1 : index to i32
- %5 = addi %3, %4 : i32
- %6 = index_cast %arg2 : index to i32
- %7 = addi %5, %6 : i32
- linalg.yield %7 : i32
- } -> tensor<6x4x210xi32>
- %d = linalg.tensor_reshape %c [[0, 1], [2], [3, 4, 5]]
- : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
- return %d : tensor<2x3x4x5x6x7xi32>
-}
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
-// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
-// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
-// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
-// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 7 + d2 * 42)>
-// CHECK: func @reshape_as_consumer_permutation
-// CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32>
-// CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32>
-// CHECK-DAG: %[[T1:.+]] = linalg.tensor_reshape %[[ARG0]]
-// CHECK-SAME: [0, 1, 2], [3, 4], [5]
-// CHECK-DAG: %[[T2:.+]] = linalg.tensor_reshape %[[ARG1]]
-// CHECK-SAME: [0, 1, 2], [3]
-// CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7]
-// CHECK: %[[T4:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
-// CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
-// CHECK-SAME: outs(%[[T0]] : tensor<2x3x4x5x6x7xi32>)
-// CHECK: ^{{.+}}(
-// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32,
-// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: i32)
-// CHECK: %[[ARG2:.+]] = linalg.index 0 : index
-// CHECK: %[[ARG3:.+]] = linalg.index 1 : index
-// CHECK: %[[ARG4:.+]] = linalg.index 2 : index
-// CHECK: %[[ARG5:.+]] = linalg.index 3 : index
-// CHECK: %[[ARG6:.+]] = linalg.index 4 : index
-// CHECK: %[[ARG7:.+]] = linalg.index 5 : index
-// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP8]](%[[ARG3]], %[[ARG2]])
-// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP9]](%[[ARG6]], %[[ARG5]], %[[ARG4]])
-// CHECK-DAG: %[[T7:.+]] = addi %[[ARG8]], %[[ARG9]]
-// CHECK: %[[T8:.+]] = index_cast %[[T5]]
-// CHECK: %[[T9:.+]] = addi %[[T7]], %[[T8]]
-// CHECK: %[[T10:.+]] = index_cast %[[T6]]
-// CHECK: %[[T11:.+]] = addi %[[T9]], %[[T10]]
-// CHECK: %[[T12:.+]] = index_cast %[[ARG7]]
-// CHECK: %[[T13:.+]] = addi %[[T11]], %[[T12]]
-
-// -----
-
func @reshape_as_consumer_permutation
(%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
-> tensor<2x3x4x5x6x7xi32> {
@@ -487,59 +330,6 @@ func @reshape_as_consumer_permutation
// -----
-func @reshape_as_producer_projected_permutation(
- %arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32>
-{
- %0 = linalg.tensor_reshape %arg0 [[0, 1], [2]]
- : tensor<33x8x?xi32> into tensor<264x?xi32>
- %1 = linalg.indexed_generic
- {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
- affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
- iterator_types = ["parallel", "parallel", "parallel"]}
- ins(%0 : tensor<264x?xi32>)
- outs(%shape : tensor<264x?x4xi32>) {
- ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: i32, %s: i32): // no predecessors
- %2 = index_cast %arg1 : index to i32
- %3 = addi %arg4, %2 : i32
- %4 = index_cast %arg2 : index to i32
- %5 = addi %3, %4 : i32
- %6 = index_cast %arg3 : index to i32
- %7 = addi %5, %6 : i32
- linalg.yield %7 : i32
- } -> tensor<264x?x4xi32>
- return %1 : tensor<264x?x4xi32>
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 8)>
-// CHECK: @reshape_as_producer_projected_permutation
-// CHECK-SAME: %[[ARG0:.+]]: tensor<33x8x?xi32>
-// CHECK: %[[RES:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME: ins(%[[ARG0]] : tensor<33x8x?xi32>)
-// CHECK: ^{{.+}}(
-// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32,
-// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: i32)
-// CHECK: %[[ARG1:.+]] = linalg.index 0 : index
-// CHECK: %[[ARG2:.+]] = linalg.index 1 : index
-// CHECK: %[[ARG3:.+]] = linalg.index 2 : index
-// CHECK: %[[ARG4:.+]] = linalg.index 3 : index
-// CHECK: %[[T0:.+]] = affine.apply #[[MAP2]](%[[ARG2]], %[[ARG1]])
-// CHECK: %[[T1:.+]] = index_cast %[[T0]] : index to i32
-// CHECK: %[[T2:.+]] = addi %[[ARG5]], %[[T1]] : i32
-// CHECK: %[[T3:.+]] = index_cast %[[ARG3]] : index to i32
-// CHECK: %[[T4:.+]] = addi %[[T2]], %[[T3]] : i32
-// CHECK: %[[T5:.+]] = index_cast %[[ARG4]] : index to i32
-// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]] : i32
-// CHECK: linalg.yield %[[T6]] : i32
-// CHECK: %[[RES2:.+]] = linalg.tensor_reshape %[[RES]]
-// CHECK-SAME: [0, 1], [2], [3]
-// CHECK-SAME: : tensor<33x8x?x4xi32> into tensor<264x?x4xi32>
-// CHECK: return %[[RES2]] : tensor<264x?x4xi32>
-
-// -----
-
func @reshape_as_producer_projected_permutation(
%arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32>
{
diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
index 623350de97a8f..15fc2b5de0ae2 100644
--- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
@@ -1,75 +1,18 @@
// RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization %s | FileCheck %s
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xf32>,
- %arg1 : tensor<?x?x4x?xf32>) -> tensor<?x?x4x?xf32> {
- %0 = linalg.tensor_reshape %arg0 [[0], [1, 2], [3]] :
- tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
- %1 = linalg.generic {
- indexing_maps = [#map0, #map0, #map0],
- iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
- ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>)
- outs(%0 : tensor<?x?x4x?xf32>) {
- ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
- %1 = mulf %arg3, %arg4 : f32
- linalg.yield %1 : f32
- } -> tensor<?x?x4x?xf32>
- return %1 : tensor<?x?x4x?xf32>
-}
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK: func @generic_op_reshape_producer_fusion
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
-// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
-// CHECK-SAME: [0], [1, 2], [3]
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP4]]]
-// CHECK-SAME: ins(%[[ARG0]], %{{.+}} : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>)
-// CHECK-SAME: outs(%[[T0]] : tensor<?x?x4x?xf32>)
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xf32>,
- %arg1 : tensor<?x?x4x5xf32>) -> tensor<?x?xf32> {
- %0 = linalg.generic {
- indexing_maps = [#map0, #map0, #map0],
- iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
- outs(%arg0 : tensor<?x?x4x5xf32>){
- ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
- %1 = mulf %arg3, %arg4 : f32
- linalg.yield %1 : f32
- } -> tensor<?x?x4x5xf32>
- %1 = linalg.tensor_reshape %0 [[0], [1, 2, 3]] :
- tensor<?x?x4x5xf32> into tensor<?x?xf32>
- return %1 : tensor<?x?xf32>
-}
-
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
-// CHECK: func @generic_op_reshape_consumer_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x5xf32>
-// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
-// CHECK-SAME: [0], [1, 2, 3]
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]]]
-// CHECK-SAME: outs(%[[T0]] : tensor<?x?xf32>)
-
-// -----
-
-#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
+func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
-> tensor<?x?x4x?xi32> {
%0 = linalg.tensor_reshape %arg0 [[0], [1, 2], [3]] :
tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
- %1 = linalg.indexed_generic {
+ %1 = linalg.generic {
indexing_maps = [#map0, #map0],
iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
ins(%0 : tensor<?x?x4x?xi32>)
outs(%0 : tensor<?x?x4x?xi32>) {
- ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32, %arg7 : i32): // no predecessors
- %2 = index_cast %arg2 : index to i32
+ ^bb0(%arg6: i32, %arg7 : i32): // no predecessors
+ %idx = linalg.index 0 : index
+ %2 = index_cast %idx : index to i32
%3 = addi %arg6, %2 : i32
linalg.yield %3 : i32
} -> tensor<?x?x4x?xi32>
@@ -77,26 +20,29 @@ func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
}
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK: func @indexed_generic_op_reshape_producer_fusion
+// CHECK: func @generic_op_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xi32>
// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
// CHECK-SAME: [0], [1, 2], [3]
-// CHECK: linalg.indexed_generic
+// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]]]
// CHECK-SAME: ins(%[[ARG0]] : tensor<?x?x?xi32>)
// CHECK-SAME: outs(%[[T0]] : tensor<?x?x4x?xi32>)
+// CHECK: %[[IDX:.+]] = linalg.index 0 : index
+// CHECK-NEXT: %[[IDX_CASTED:.+]] = index_cast %[[IDX]] : index to i32
// -----
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
+func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
-> tensor<?x?xi32> {
- %0 = linalg.indexed_generic {
+ %0 = linalg.generic {
indexing_maps = [#map0, #map0],
iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
ins(%arg0 : tensor<?x?x4x5xi32>) outs(%arg0 : tensor<?x?x4x5xi32>) {
- ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32, %arg7: i32): // no predecessors
- %2 = index_cast %arg2 : index to i32
+ ^bb0(%arg6: i32, %arg7: i32): // no predecessors
+ %idx = linalg.index 0 : index
+ %2 = index_cast %idx : index to i32
%3 = addi %arg6, %2 : i32
linalg.yield %3 : i32
} -> tensor<?x?x4x5xi32>
@@ -106,13 +52,15 @@ func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
}
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
-// CHECK: func @indexed_generic_op_reshape_consumer_fusion
+// CHECK: func @generic_op_reshape_consumer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x5xi32>
// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG0]]
// CHECK-SAME: [0], [1, 2, 3]
-// CHECK: linalg.indexed_generic
+// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
// CHECK-SAME: outs(%[[T0]] : tensor<?x?xi32>)
+// CHECK: %[[IDX:.+]] = linalg.index 0 : index
+// CHECK-NEXT: %[[IDX_CASTED:.+]] = index_cast %[[IDX]] : index to i32
// CHECK-NOT: linalg.tensor_reshape
// -----