[llvm-branch-commits] [mlir] b8d211f - [MLIR][Linalg] Canonicalization patterns for linalg.generic.
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Mar 15 11:44:20 PDT 2022
Author: Nirvedh
Date: 2022-03-15T18:42:43Z
New Revision: b8d211fc317ffefaed1d65b226cda6c464f7d216
URL: https://github.com/llvm/llvm-project/commit/b8d211fc317ffefaed1d65b226cda6c464f7d216
DIFF: https://github.com/llvm/llvm-project/commit/b8d211fc317ffefaed1d65b226cda6c464f7d216.diff
LOG: [MLIR][Linalg] Canonicalization patterns for linalg.generic.
Fold linalg.fill into linalg.generic.
Remove dead (unused) arguments from linalg.generic.
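
As a rough sketch of the combined effect (mirroring the new tests below; value names here are illustrative, not taken from the patch), a fill feeding a generic:

  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf32>) -> tensor<?xf32>
  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]}
      ins(%arg0, %2 : tensor<?xf32>, tensor<?xf32>) outs(%3 : tensor<?xf32>) {
  ^bb0(%a: f32, %b: f32, %out: f32):
    %5 = arith.addf %a, %b : f32
    linalg.yield %5 : f32
  } -> tensor<?xf32>

canonicalizes to a generic that reads the fill value directly, after which the dead input operand is dropped:

  %4 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
      ins(%arg0 : tensor<?xf32>) outs(%3 : tensor<?xf32>) {
  ^bb0(%a: f32, %out: f32):
    %5 = arith.addf %a, %cst : f32
    linalg.yield %5 : f32
  } -> tensor<?xf32>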
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D121535
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/fusion-indexed.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 331a8b91bd330..2c9b1e0c53553 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -105,6 +105,33 @@ static LogicalResult foldMemRefCast(Operation *op) {
return success(folded);
}
+/// Helper function to check whether the map at `testMapLocation` in `maps`
+/// uses at least one dimension that does not appear in any of the other
+/// maps.
+static bool hasaUniqueDim(ArrayRef<AffineMap> maps, unsigned testMapLocation) {
+ AffineMap testMap = maps[testMapLocation];
+ llvm::SmallDenseSet<unsigned> dimsToCheck;
+ for (auto result : testMap.getResults()) {
+ auto expr = result.dyn_cast<AffineDimExpr>();
+ if (expr != nullptr)
+ dimsToCheck.insert(expr.getPosition());
+ }
+ for (auto It : llvm::enumerate(maps)) {
+ if (It.index() == testMapLocation)
+ continue;
+ auto map = It.value();
+ for (auto result : map.getResults()) {
+ auto expr = result.dyn_cast<AffineDimExpr>();
+ if (expr != nullptr) {
+ dimsToCheck.erase(expr.getPosition());
+ }
+ if (dimsToCheck.empty())
+ return false;
+ }
+ }
+ return true;
+}
+
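
To make the new check concrete (hypothetical maps, not from the patch): given

  maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>]

hasaUniqueDim(maps, 0) returns true because d1 appears only in the map at location 0; erasing that operand would leave no indexing map from which the d1 loop bound could be recovered. hasaUniqueDim(maps, 1) returns false because d0 also appears in the other map.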
//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
@@ -826,11 +853,95 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
return success();
}
};
+
+/// Remove dead arguments from a linalg.generic op. An argument is dead if it
+/// has no uses in the payload region; it is dropped only when its indexing
+/// map contributes no unique loop dimension (see hasaUniqueDim).
+struct DeadArgsGenericOpInputs : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ SmallVector<AffineMap> oldIndexingMaps = genericOp.getIndexingMaps();
+ // Maps must be projected permutations.
+ if (llvm::any_of(genericOp.getIndexingMaps(), [](AffineMap map) {
+ return !map.isProjectedPermutation();
+ }))
+ return failure();
+ Block &payload = genericOp.region().front();
+ SmallVector<Value> newInputOperands;
+ SmallVector<AffineMap> newIndexingMaps;
+ bool deadArgFound = false;
+ int inputSize = genericOp.getInputOperands().size();
+ // Iterate in reverse so that later arguments are erased first; erasing
+ // from the back keeps the indices of the remaining arguments and operands
+ // stable as we walk the list.
+ for (int i = inputSize - 1; i >= 0; i--) {
+ OpOperand *opOperand = genericOp.getInputOperand(i);
+ if (payload.getArgument(i).use_empty() &&
+ !hasaUniqueDim(oldIndexingMaps, i)) {
+ payload.eraseArgument(i);
+ deadArgFound = true;
+ // Remove this indexing map from consideration for the hasaUniqueDim check.
+ oldIndexingMaps.erase(oldIndexingMaps.begin() + i);
+ } else {
+ newInputOperands.insert(newInputOperands.begin(), opOperand->get());
+ newIndexingMaps.insert(newIndexingMaps.begin(),
+ genericOp.getTiedIndexingMap(opOperand));
+ }
+ }
+ // Bail out if there are no dead args.
+ if (!deadArgFound)
+ return failure();
+ for (OpOperand *opOperand : genericOp.getOutputOperands())
+ newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
+ SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+
+ auto newOp = rewriter.create<GenericOp>(
+ genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
+ outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
+ genericOp.iterator_types(), genericOp.docAttr(),
+ genericOp.library_callAttr());
+ // Copy over unknown attributes. They might be load-bearing for some flow.
+ ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
+ for (NamedAttribute kv : genericOp->getAttrs()) {
+ if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) {
+ newOp->setAttr(kv.getName(), kv.getValue());
+ }
+ }
+ rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
+ newOp.region().begin());
+ rewriter.replaceOp(genericOp, newOp->getResults());
+ return success();
+ }
+};
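
A minimal sketch of this pattern in isolation (names hypothetical; the second input is never read by the payload):

  %r = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
      ins(%a, %unused : tensor<?xf32>, tensor<?xf32>) outs(%init : tensor<?xf32>) {
  ^bb0(%x: f32, %dead: f32, %out: f32):
    %y = arith.addf %x, %cst : f32
    linalg.yield %y : f32
  } -> tensor<?xf32>

becomes

  %r = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
      ins(%a : tensor<?xf32>) outs(%init : tensor<?xf32>) {
  ^bb0(%x: f32, %out: f32):
    %y = arith.addf %x, %cst : f32
    linalg.yield %y : f32
  } -> tensor<?xf32>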
+
+/// Fold a linalg.fill producer into the payload of a linalg.generic consumer.
+struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ if (!genericOp.hasTensorSemantics())
+ return failure();
+ bool fillFound = false;
+ Block &payload = genericOp.region().front();
+ for (OpOperand *opOperand : genericOp.getInputOperands()) {
+ FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
+ if (fillOp) {
+ fillFound = true;
+ payload.getArgument(opOperand->getOperandNumber())
+ .replaceAllUsesWith(fillOp.value());
+ }
+ }
+ // Signal failure if there were no FillOps to fold.
+ return success(fillFound);
+ }
+};
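
Note that this pattern only rewires the payload; the filled tensor stays in place as a now-unused input. For the tests below, the payload changes along these lines (%arg2 being the argument fed by the fill of %cst):

  %5 = arith.addf %arg1, %arg2 : f32

becomes

  %5 = arith.addf %arg1, %cst : f32

and DeadArgsGenericOpInputs above then removes the dead operand on a later canonicalization iteration, which is why the CHECK lines see no fill input at all.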
} // namespace
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
+ results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp,
+ DeadArgsGenericOpInputs, FoldFillWithGenericOp>(context);
}
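
Both patterns run as part of ordinary canonicalization, so the tests below need no dedicated pass. The RUN line of canonicalize.mlir is outside this diff's context, but it is along the lines of:

  // RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s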
LogicalResult GenericOp::fold(ArrayRef<Attribute>,
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 28655b96e5df6..0e0faab56f6c9 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -325,6 +325,106 @@ func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
// -----
+// CHECK-LABEL: func @fold_fill_generic_basic
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK-NOT: linalg.fill
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
+#map0 = affine_map<(d0) -> (d0)>
+func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 7.0 : f32
+ %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+ %1 = linalg.init_tensor [%0] : tensor<?xf32>
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf32>) -> tensor<?xf32>
+ %3 = linalg.init_tensor [%0] : tensor<?xf32>
+ %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf32>, tensor<?xf32>) outs (%3:tensor<?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+ %5 = arith.addf %arg1, %arg2 : f32
+ linalg.yield %5 : f32
+ } -> tensor<?xf32>
+ return %4 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_fill_generic_mixedaccess
+// CHECK-NOT: linalg.fill
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-NOT: ins
+// CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+func @fold_fill_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %cst1 = arith.constant 7.0 : f32
+ %cst2 = arith.constant 6.0 : f32
+ %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+ %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %4 = linalg.init_tensor [%1, %0] : tensor<?x?xf32>
+ %5 = linalg.fill ins(%cst2 : f32) outs(%4 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %6 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+ %7 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%3, %5 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%6:tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+ %8 = arith.divf %arg1, %arg2 : f32
+ linalg.yield %8 : f32
+ } -> tensor<?x?xf32>
+ return %7 : tensor<?x?xf32>
+}
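
In this test both inputs are produced by fills, so after folding the payload divides the two constants and both input operands go dead; the CHECK-NOT on `ins` above verifies that the canonicalized generic is left with no input operands at all.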
+
+// -----
+
+// CHECK-LABEL: func @remove_deadargs_generic_basic
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
+#map0 = affine_map<(d0) -> (d0)>
+func @remove_deadargs_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 7.0 : f32
+ %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+ %1 = linalg.init_tensor [%0] : tensor<?xf32>
+ %2 = linalg.init_tensor [%0] : tensor<?xf32>
+ %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %1 : tensor<?xf32>, tensor<?xf32>) outs (%2:tensor<?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+ %4 = arith.addf %arg1, %cst : f32
+ linalg.yield %4 : f32
+ } -> tensor<?xf32>
+ return %3 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @remove_deadargs_generic_mixedaccess
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-NOT: ins
+// CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+func @remove_deadargs_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %cst1 = arith.constant 7.0 : f32
+ %cst2 = arith.constant 6.0 : f32
+ %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+ %3 = linalg.init_tensor [%1, %0] : tensor<?x?xf32>
+ %4 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+ %5 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%2, %3 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%4:tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+ %6 = arith.divf %cst1, %cst2 : f32
+ linalg.yield %6 : f32
+ } -> tensor<?x?xf32>
+ return %5 : tensor<?x?xf32>
+}
+
+// -----
// CHECK-LABEL: func @fold_fill_reshape()
func @fold_fill_reshape() -> tensor<6x4xf32> {
%zero = arith.constant 0.0 : f32
diff --git a/mlir/test/Dialect/Linalg/fusion-indexed.mlir b/mlir/test/Dialect/Linalg/fusion-indexed.mlir
index 1b075cc5ac483..03ac767136f00 100644
--- a/mlir/test/Dialect/Linalg/fusion-indexed.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-indexed.mlir
@@ -46,7 +46,8 @@ func @fuse_indexed_consumer(%A: memref<?x?xf32>,
%10 = arith.index_cast %7 : index to i32
%11 = arith.sitofp %10 : i32 to f32
%12 = arith.addf %9, %11 : f32
- linalg.yield %12 : f32
+ %13 = arith.addf %12, %arg4 : f32
+ linalg.yield %13 : f32
}
}
}
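
The extra arith.addf keeps %arg4 live: without a use of that argument, the new dead-argument canonicalization would presumably strip the operand this fusion test depends on.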