[Mlir-commits] [mlir] f358c37 - [mlir][linalg] Remove IndexedGenericOp support from DropUnitDims...
Tobias Gysi
llvmlistbot at llvm.org
Thu May 13 07:19:20 PDT 2021
Author: Tobias Gysi
Date: 2021-05-13T14:18:59Z
New Revision: f358c372094599bf2a9246a0d2145cd949b4c62d
URL: https://github.com/llvm/llvm-project/commit/f358c372094599bf2a9246a0d2145cd949b4c62d
DIFF: https://github.com/llvm/llvm-project/commit/f358c372094599bf2a9246a0d2145cd949b4c62d.diff
LOG: [mlir][linalg] Remove IndexedGenericOp support from DropUnitDims...
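after introducing the IndexedGenericOp-to-GenericOp canonicalization (https://reviews.llvm.org/D101612). Since IndexedGenericOps are now canonicalized to GenericOps, the IndexedGenericOp-specific patterns and tests are no longer needed.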
Differential Revision: https://reviews.llvm.org/D102235
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 9c4d8afadb6e..623c8245630e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -146,13 +146,13 @@ static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
}
/// Update the index accesses of linalg operations having index semantics.
-template <typename GenericOpTy>
-static void replaceUnitDimIndexOps(GenericOpTy op,
+static void replaceUnitDimIndexOps(GenericOp genericOp,
const DenseSet<unsigned> &unitDims,
PatternRewriter &rewriter) {
- assert(op->getNumRegions() == 1 && op->getRegion(0).getBlocks().size() == 1 &&
+ assert(genericOp->getNumRegions() == 1 &&
+ genericOp->getRegion(0).getBlocks().size() == 1 &&
"expected generic operation to have one block.");
- Block &block = op->getRegion(0).front();
+ Block &block = genericOp->getRegion(0).front();
for (IndexOp indexOp : llvm::make_early_inc_range(block.getOps<IndexOp>())) {
OpBuilder::InsertionGuard guard(rewriter);
@@ -170,39 +170,13 @@ static void replaceUnitDimIndexOps(GenericOpTy op,
}
}
-/// Modify the region of indexed generic op to drop arguments corresponding to
-/// loops that are unit trip count.
-template <typename OpTy>
-static LogicalResult
-replaceBlockArgForUnitDimLoops(OpTy op, const DenseSet<unsigned> &unitDims,
- PatternRewriter &rewriterp) {
- return success();
-}
-
-template <>
-LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
- IndexedGenericOp op, const DenseSet<unsigned> &unitDims,
- PatternRewriter &rewriter) {
- OpBuilder::InsertionGuard guard(rewriter);
- Block *entryBlock = &op->getRegion(0).front();
- rewriter.setInsertionPointToStart(entryBlock);
- Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
- for (unsigned unitDimLoop : unitDims) {
- entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero);
- }
- SmallVector<unsigned, 8> unitDimsToErase(unitDims.begin(), unitDims.end());
- entryBlock->eraseArguments(unitDimsToErase);
- return success();
-}
-
namespace {
/// Pattern to fold unit-trip count loops in GenericOps.
-template <typename GenericOpTy>
-struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
- using OpRewritePattern<GenericOpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(GenericOpTy op,
+struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
+ SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
if (indexingMaps.empty())
return failure();
@@ -213,7 +187,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
if (!invertedMap)
return failure();
SmallVector<int64_t, 4> dims;
- for (ShapedType shapedType : op.getShapedOperandTypes())
+ for (ShapedType shapedType : genericOp.getShapedOperandTypes())
dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
// Find all the reduction iterators. Those need some special consideration
@@ -221,7 +195,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
auto getLoopDimsOfType =
[&](StringRef iteratorTypeName) -> SmallVector<unsigned, 4> {
SmallVector<AffineExpr> dimExprs;
- getDimsOfType(op, iteratorTypeName, dimExprs);
+ getDimsOfType(genericOp, iteratorTypeName, dimExprs);
return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) {
return expr.cast<AffineDimExpr>().getPosition();
}));
@@ -230,7 +204,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
DenseSet<unsigned> unitDims;
SmallVector<unsigned, 4> unitDimsReductionLoops;
- ArrayAttr iteratorTypes = op.iterator_types();
+ ArrayAttr iteratorTypes = genericOp.iterator_types();
for (auto expr : enumerate(invertedMap.getResults())) {
if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
if (dims[dimExpr.getPosition()] == 1) {
@@ -260,7 +234,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
ArrayAttr newIndexingMapAttr =
replaceUnitDims(unitDims, indexingMaps, context);
if (!newIndexingMapAttr)
- return op.emitError("unable to compute modified indexing_maps");
+ return genericOp.emitError("unable to compute modified indexing_maps");
// Compute the iterator types of the modified op by dropping the one-trip
// count loops.
@@ -270,12 +244,11 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
newIteratorTypes.push_back(attr.value());
}
- rewriter.startRootUpdate(op);
- op.indexing_mapsAttr(newIndexingMapAttr);
- op.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
- (void)replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
- replaceUnitDimIndexOps(op, unitDims, rewriter);
- rewriter.finalizeRootUpdate(op);
+ rewriter.startRootUpdate(genericOp);
+ genericOp.indexing_mapsAttr(newIndexingMapAttr);
+ genericOp.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
+ replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
+ rewriter.finalizeRootUpdate(genericOp);
return success();
}
};
@@ -351,23 +324,22 @@ convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
}
/// Pattern to replace tensors operands/results that are unit extents.
-template <typename GenericOpTy>
-struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
- using OpRewritePattern<GenericOpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(GenericOpTy op,
+struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- if (!op.hasTensorSemantics())
+ if (!genericOp.hasTensorSemantics())
return failure();
MLIRContext *context = rewriter.getContext();
- Location loc = op.getLoc();
+ Location loc = genericOp.getLoc();
SmallVector<AffineMap, 4> newIndexingMaps;
SmallVector<ArrayAttr, 4> reassociationMaps;
SmallVector<ShapedType, 4> newInputOutputTypes;
bool doCanonicalization = false;
- for (auto it :
- llvm::zip(op.getIndexingMaps(), op.getShapedOperandTypes())) {
+ for (auto it : llvm::zip(genericOp.getIndexingMaps(),
+ genericOp.getShapedOperandTypes())) {
auto replacementInfo = replaceUnitExtents(
std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
context);
@@ -402,20 +374,20 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
return res;
};
- SmallVector<Value, 4> newInputs = insertReshapes(op.inputs());
- SmallVector<Value, 4> newOutputs = insertReshapes(op.outputs());
+ SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
+ SmallVector<Value, 4> newOutputs = insertReshapes(genericOp.outputs());
// If any result type changes, insert a reshape to convert from the original
// type to the new type.
SmallVector<Type, 4> resultTypes;
- resultTypes.reserve(op.getNumResults());
- for (unsigned i : llvm::seq<unsigned>(0, op.getNumResults()))
- resultTypes.push_back(newInputOutputTypes[i + op.getNumInputs()]);
- GenericOpTy replacementOp = rewriter.create<GenericOpTy>(
+ resultTypes.reserve(genericOp.getNumResults());
+ for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
+ resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
+ GenericOp replacementOp = rewriter.create<GenericOp>(
loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
llvm::to_vector<4>(
- op.iterator_types().template getAsValueRange<StringAttr>()));
- rewriter.inlineRegionBefore(op.region(), replacementOp.region(),
+ genericOp.iterator_types().template getAsValueRange<StringAttr>()));
+ rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
replacementOp.region().begin());
// If any result tensor has a modified shape, then add reshape to recover
@@ -423,7 +395,7 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
SmallVector<Value, 4> resultReplacements;
for (auto result : llvm::enumerate(replacementOp.getResults())) {
unsigned index = result.index() + replacementOp.getNumInputs();
- RankedTensorType origResultType = op.getResult(result.index())
+ RankedTensorType origResultType = genericOp.getResult(result.index())
.getType()
.template cast<RankedTensorType>();
if (origResultType != result.value().getType())
@@ -433,7 +405,7 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
else
resultReplacements.push_back(result.value());
}
- rewriter.replaceOp(op, resultReplacements);
+ rewriter.replaceOp(genericOp, resultReplacements);
return success();
}
};
@@ -528,9 +500,7 @@ struct UseRankReducedSubTensorInsertOp
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
RewritePatternSet &patterns) {
auto *context = patterns.getContext();
- patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
- ReplaceUnitExtentTensors<GenericOp>,
- ReplaceUnitExtentTensors<IndexedGenericOp>,
+ patterns.add<FoldUnitDimLoops, ReplaceUnitExtentTensors,
UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
context);
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
@@ -545,9 +515,7 @@ struct LinalgFoldUnitExtentDimsPass
MLIRContext *context = funcOp.getContext();
RewritePatternSet patterns(context);
if (foldOneTripLoopsOnly)
- patterns
- .add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>>(
- context);
+ patterns.add<FoldUnitDimLoops>(context);
else
populateFoldUnitExtentDimsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 808622bc85c8..5bc11f20c7ca 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -42,48 +42,6 @@ func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %shape: tensor<?x1x?x1x?xf3
library_call = "some_external_func"
}
-func @drop_one_trip_loops_indexed_generic
- (%arg0 : tensor<?x1x?xi32>, %shape: tensor<?x1x?x1x?xi32>) -> tensor<?x1x?x1x?xi32>
-{
- %0 = linalg.indexed_generic #trait
- ins(%arg0 : tensor<?x1x?xi32>)
- outs(%shape: tensor<?x1x?x1x?xi32>) {
- ^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index,
- %arg5 : index, %arg6 : i32, %arg7 : i32) :
- %1 = addi %arg1, %arg2 : index
- %2 = addi %1, %arg3 : index
- %3 = addi %2, %arg4 : index
- %4 = addi %3, %arg5 : index
- %5 = index_cast %4 : index to i32
- %6 = addi %5, %arg6 : i32
- linalg.yield %6 : i32
- } -> tensor<?x1x?x1x?xi32>
- return %0 : tensor<?x1x?x1x?xi32>
-}
-// CHECK-LABEL: func @drop_one_trip_loops_indexed_generic
-// CHECK: linalg.indexed_generic
-// CHECK: ^{{.+}}(
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index, %[[ARG4:[a-zA-Z0-9]+]]: i32, %{{.*}}: i32)
-// CHECK: %[[T3:.+]] = addi %[[ARG1]], %[[ARG2]]
-// CHECK: %[[T4:.+]] = addi %[[T3]], %[[ARG3]]
-// CHECK: %[[T5:.+]] = index_cast %[[T4]] : index to i32
-// CHECK: %[[T6:.+]] = addi %[[T5]], %[[ARG4]] : i32
-// CHECK: linalg.yield %[[T6]] : i32
-
-// -----
-
-#accesses = [
- affine_map<(i, j, k, l, m) -> (i, k, m)>,
- affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
-]
-
-#trait = {
- iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
- indexing_maps = #accesses,
- library_call = "some_external_func"
-}
-
func @drop_one_trip_loops_indexed
(%arg0 : tensor<?x1x?xi32>, %shape: tensor<?x1x?x1x?xi32>) -> tensor<?x1x?x1x?xi32>
{
@@ -158,35 +116,6 @@ func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
library_call = "some_external_func"
}
-func @drop_all_loops_indexed_generic
- (%arg0 : tensor<1x1xi32>) -> tensor<1x1xi32>{
- %0 = linalg.indexed_generic #trait
- ins(%arg0 : tensor<1x1xi32>)
- outs(%arg0 : tensor<1x1xi32>) {
- ^bb0(%arg1 : index, %arg2 : index, %arg3: i32, %arg4: i32) :
- %1 = addi %arg1, %arg2 : index
- %2 = index_cast %1 : index to i32
- %3 = addi %2, %arg3 : i32
- linalg.yield %3 : i32
- } -> tensor<1x1xi32>
- return %0 : tensor<1x1xi32>
-}
-
-// CHECK-LABEL: func @drop_all_loops_indexed_generic
-// CHECK: linalg.indexed_generic
-// CHECK: ^{{.+}}(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
-// CHECK: linalg.yield %[[ARG1]] : i32
-
-// -----
-
-#map0 = affine_map<(i, j) -> (i, j)>
-#access = [#map0, #map0]
-#trait = {
- iterator_types = ["parallel", "parallel"],
- indexing_maps = #access,
- library_call = "some_external_func"
-}
-
func @drop_all_loops_indexed
(%arg0 : tensor<1x1xi32>) -> tensor<1x1xi32>{
%0 = linalg.generic #trait
More information about the Mlir-commits mailing list