[Mlir-commits] [mlir] 6eee66d - [mlir][linalg] Add a new pattern to handle folding unit reduction dims.
Hanhan Wang
llvmlistbot at llvm.org
Wed Nov 23 10:47:24 PST 2022
Author: Hanhan Wang
Date: 2022-11-23T10:47:10-08:00
New Revision: 6eee66d12ab33f35a37a1514342b51ae93d175e8
URL: https://github.com/llvm/llvm-project/commit/6eee66d12ab33f35a37a1514342b51ae93d175e8
DIFF: https://github.com/llvm/llvm-project/commit/6eee66d12ab33f35a37a1514342b51ae93d175e8.diff
LOG: [mlir][linalg] Add a new pattern to handle folding unit reduction dims.
The output operands will be added to input operands if the generic op (on tensors)
becomes an elementwise operation. The outputs of the generic op are still the same.
They will be cleaned up by ReplaceWithEmptyTensorIfUnused pattern.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D138251
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 7b9b7358bb459..5933ff9c65140 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -19,12 +19,15 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -225,6 +228,125 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
}
};
+/// Pattern to add init operands to ins when all the loops are parallel and
+/// the block argument corresponding to an init is used in the region. This is
+/// a fix-up for when unit reduction dimensions are all folded away. In this
+/// context, it becomes an elementwise generic op. E.g., it converts
+///
+/// %0 = tensor.empty() : tensor<1x1xf32>
+/// %1 = linalg.fill
+/// ins(%cst : f32)
+/// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
+/// %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
+/// affine_map<(d0) -> (0, d0)>],
+/// iterator_types = ["parallel"]}
+/// ins(%arg0 : tensor<1x?x1x1xf32>)
+/// outs(%1 : tensor<1x1xf32>) {
+/// ^bb0(%in: f32, %out: f32):
+/// %3 = arith.addf %in, %out : f32
+/// linalg.yield %3 : f32
+/// } -> tensor<1x1xf32>
+///
+/// into
+///
+/// %0 = tensor.empty() : tensor<1x1xf32>
+/// %1 = linalg.fill
+/// ins(%cst : f32)
+/// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
+/// %2 = tensor.empty() : tensor<1x1xf32>
+/// %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
+/// affine_map<(d0) -> (0, d0)>,
+/// affine_map<(d0) -> (0, d0)>],
+/// iterator_types = ["parallel"]}
+/// ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>)
+/// outs(%2 : tensor<1x1xf32>) {
+/// ^bb0(%in: f32, %in_0: f32, %out: f32):
+/// %4 = arith.addf %in, %in_0 : f32
+/// linalg.yield %4 : f32
+/// } -> tensor<1x1xf32>
+///
+/// Each init whose block argument is read is appended to `ins`, and its slot
+/// in `outs` is replaced by a fresh `tensor.empty` with the same shape and
+/// element type. Result types of the generic op are unchanged.
+struct AddInitOperandsToInput : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ if (!genericOp.hasTensorSemantics())
+ return failure();
+ // Only applies when every loop is parallel (no reduction loops remain).
+ if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
+ return failure();
+
+ // Candidates are the inits whose block argument is read by the body.
+ auto outputOperands = genericOp.getDpsInitOperands();
+ SetVector<OpOperand *> candidates;
+ for (OpOperand *op : outputOperands) {
+ if (genericOp.getMatchingBlockArgument(op).use_empty())
+ continue;
+ candidates.insert(op);
+ }
+
+ if (candidates.empty())
+ return failure();
+
+ // Compute the modified indexing maps, ordered as: original inputs, then the
+ // candidate inits (now inputs), then all original outputs.
+ int64_t origNumInput = genericOp.getNumDpsInputs();
+ SmallVector<Value> newInputOperands = genericOp.getDpsInputOperands();
+ SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+ SmallVector<AffineMap> newIndexingMaps;
+ newIndexingMaps.append(indexingMaps.begin(),
+ std::next(indexingMaps.begin(), origNumInput));
+ for (OpOperand *op : candidates) {
+ newInputOperands.push_back(op->get());
+ newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
+ }
+ newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
+ indexingMaps.end());
+
+ // Replace each candidate's `outs` operand with a fresh tensor.empty of the
+ // same shape; its original value is now read through the new input.
+ Location loc = genericOp.getLoc();
+ SmallVector<Value> newOutputOperands = outputOperands;
+ // Loop-invariant: operand number of the first init operand.
+ auto [initStart, initEnd] = genericOp.getDpsInitsPositionRange();
+ for (OpOperand *op : candidates) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointAfterValue(op->get());
+ auto elemType = op->get().getType().cast<ShapedType>().getElementType();
+ auto empty = rewriter.create<tensor::EmptyOp>(
+ loc, tensor::createDimValues(rewriter, loc, op->get()), elemType);
+ newOutputOperands[op->getOperandNumber() - initStart] = empty.getResult();
+ }
+
+ auto newOp = rewriter.create<GenericOp>(
+ loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
+ newIndexingMaps, genericOp.getIteratorTypesArray(),
+ /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+
+ // Build the body by hand. Block arguments are added in operand order:
+ // original inputs, candidate inits (as inputs), then one per output.
+ Region &region = newOp.getRegion();
+ Block *block = new Block();
+ region.push_back(block);
+ BlockAndValueMapping mapper;
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(block);
+ for (auto bbarg : genericOp.getRegionInputArgs())
+ mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
+
+ // Uses of a candidate's old output argument are redirected to its new
+ // input argument.
+ for (OpOperand *op : candidates) {
+ BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
+ mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
+ }
+
+ // A candidate's new output argument is left unmapped, so the cloned body
+ // never reads it; this also keeps the pattern from re-matching the new op.
+ for (OpOperand *op : outputOperands) {
+ BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
+ if (candidates.count(op))
+ block->addArgument(bbarg.getType(), loc);
+ else
+ mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
+ }
+
+ for (auto &op : genericOp.getBody()->getOperations())
+ rewriter.clone(op, mapper);
+ rewriter.replaceOp(genericOp, newOp.getResults());
+
+ return success();
+ }
+};
+
struct UnitExtentReplacementInfo {
Type type;
AffineMap indexMap;
@@ -536,7 +658,8 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
RewritePatternSet &patterns) {
auto *context = patterns.getContext();
- patterns.add<FoldUnitDimLoops, ReplaceUnitExtents, RankReducedExtractSliceOp,
+ patterns.add<FoldUnitDimLoops, AddInitOperandsToInput, ReplaceUnitExtents,
+ RankReducedExtractSliceOp,
RankReducedInsertSliceOp<tensor::InsertSliceOp>,
RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
context);
@@ -544,6 +667,8 @@ void mlir::linalg::populateFoldUnitExtentDimsPatterns(
tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
+ memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
+ memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
namespace {
@@ -555,7 +680,7 @@ struct LinalgFoldUnitExtentDimsPass
MLIRContext *context = op->getContext();
RewritePatternSet patterns(context);
if (foldOneTripLoopsOnly)
- patterns.add<FoldUnitDimLoops>(context);
+ patterns.add<FoldUnitDimLoops, AddInitOperandsToInput>(context);
else
populateFoldUnitExtentDimsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 4ff1f19fe36b5..ffa95633da597 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -384,11 +384,12 @@ func.func @unit_dim_for_both_reduction(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1
// CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3]
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1xf32>
// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
+// CHECK: %[[INIT2:.+]] = tensor.empty() : tensor<1xf32>
// CHECK: %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel"]
-// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?xf32>)
-// CHECK-SAME: outs(%[[FILL]] : tensor<1xf32>)
+// CHECK-SAME: ins(%[[RESHAPE]], %[[FILL]] : tensor<?xf32>, tensor<1xf32>)
+// CHECK-SAME: outs(%[[INIT2]] : tensor<1xf32>)
// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]]
// CHECK: return %[[RESULT_RESHAPE]]
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 72fa3c0b92faf..256b055081472 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8301,6 +8301,7 @@ cc_library(
":LinalgUtils",
":MathDialect",
":MemRefDialect",
+ ":MemRefTransforms",
":Pass",
":SCFDialect",
":SCFTransforms",
More information about the Mlir-commits
mailing list