[Mlir-commits] [mlir] [MLIR][SCF] Fold dim ops of iter_args to respective init_args (PR #109973)
Prashant Kumar
llvmlistbot at llvm.org
Wed Sep 25 06:15:57 PDT 2024
https://github.com/pashu123 updated https://github.com/llvm/llvm-project/pull/109973
>From f73b26ad8687287722e687fd884acf912f3a3f16 Mon Sep 17 00:00:00 2001
From: Prashant Kumar <pk5561 at gmail.com>
Date: Wed, 25 Sep 2024 18:05:31 +0530
Subject: [PATCH 1/2] [MLIR][SCF] Remove trailing whitespace.
---
mlir/test/Dialect/SCF/canonicalize.mlir | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c68369a8e4fce7..2eab04363d38f3 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1787,7 +1787,7 @@ module {
}
// CHECK-LABEL: @fold_iter_args_not_being_modified_within_scfforall
// CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
-// CHECK: %[[RESULT:.*]] = scf.forall
+// CHECK: %[[RESULT:.*]] = scf.forall
// CHECK-SAME: shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]]
@@ -1830,7 +1830,7 @@ module {
}
// CHECK-LABEL: @fold_iter_args_with_no_use_of_result_scfforall
// CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>, %[[ARG3:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-// CHECK: %[[RESULT:.*]] = scf.forall
+// CHECK: %[[RESULT:.*]] = scf.forall
// CHECK-SAME: shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]]
@@ -1854,7 +1854,7 @@ func.func @index_switch_fold() -> (f32, f32) {
%y = arith.constant 42.0 : f32
scf.yield %y : f32
}
-
+
%switch_cst_2 = arith.constant 2: index
%1 = scf.index_switch %switch_cst_2 -> f32
case 0 {
@@ -1865,7 +1865,7 @@ func.func @index_switch_fold() -> (f32, f32) {
%y = arith.constant 42.0 : f32
scf.yield %y : f32
}
-
+
return %0, %1 : f32, f32
}
@@ -1891,3 +1891,5 @@ func.func @index_switch_fold_no_res() {
// CHECK-LABEL: func.func @index_switch_fold_no_res()
// CHECK-NEXT: "test.op"() : () -> ()
+
+// -----
>From 25480550ebad8b29e2f38637ce22a7e2c4ddb297 Mon Sep 17 00:00:00 2001
From: Prashant Kumar <pk5561 at gmail.com>
Date: Wed, 25 Sep 2024 18:07:10 +0530
Subject: [PATCH 2/2] [MLIR][SCF] Fold dim ops of iter_args to respective
init_args
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 46 +++++++++++++++++++++++--
mlir/test/Dialect/SCF/canonicalize.mlir | 27 +++++++++++++++
2 files changed, 70 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 6d47ff3890977a..9705a9cd1741b9 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1478,6 +1478,45 @@ struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
}
};
+/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+///   ...
+/// }
+/// ```
+///
+/// is folded to:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
+///   ...
+/// }
+/// ```
+///
+/// The rewrite relies on a shared_outs block argument having the same shape
+/// as the init operand it is tied to. Dim ops of any other value, including
+/// forall block arguments that are not region iter_args, are left untouched.
+struct DimOfForallIterArg : public OpRewritePattern<tensor::DimOp> {
+  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
+                                PatternRewriter &rewriter) const final {
+    // Only block arguments can be loop iter_args.
+    auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
+    if (!blockArg)
+      return failure();
+    auto forallOp =
+        dyn_cast<ForallOp>(blockArg.getParentBlock()->getParentOp());
+    if (!forallOp)
+      return failure();
+    // Per the LoopLikeOpInterface contract, getTiedLoopInit returns nullptr
+    // when `blockArg` is not a region iter_arg (e.g. an induction variable),
+    // so check it before dereferencing instead of relying on the caller.
+    OpOperand *initOperand = forallOp.getTiedLoopInit(blockArg);
+    if (!initOperand)
+      return failure();
+    Value initArg = initOperand->get();
+    rewriter.modifyOpInPlace(
+        dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
+    return success();
+  }
+};
+
class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
public:
using OpRewritePattern<ForallOp>::OpRewritePattern;
@@ -1851,9 +1890,10 @@ struct FoldTensorCastOfOutputIntoForallOp
void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
- ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
- ForallOpSingleOrZeroIterationDimsFolder>(context);
+ // DimOfForallOp and DimOfForallIterArg are rooted on tensor::DimOp rather
+ // than ForallOp; registering them here makes them part of scf.forall
+ // canonicalization.
+ results.add<DimOfForallOp, DimOfForallIterArg,
+ FoldTensorCastOfOutputIntoForallOp, ForallOpControlOperandsFolder,
+ ForallOpIterArgsFolder, ForallOpSingleOrZeroIterationDimsFolder>(
+ context);
}
/// Given the region at `index`, or the parent operation if `index` is None,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 2eab04363d38f3..e203a517932237 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1893,3 +1893,30 @@ func.func @index_switch_fold_no_res() {
// CHECK-NEXT: "test.op"() : () -> ()
// -----
+
+// Verify that a tensor.dim on the scf.forall shared_outs block argument %o is
+// rewritten to read the size from its tied init %arg1 instead.
+func.func @forall_iter_to_init_arg(
+ %arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+
+ %result = scf.forall (%i) = (%c0) to (%dim0)
+ step (%c1) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
+
+ %dim1 = tensor.dim %o, %c1 : tensor<?x?xf32>
+ %slice = tensor.extract_slice %arg1[%i, 0] [1, %dim1] [1, 1]
+ : tensor<?x?xf32> to tensor<1x?xf32>
+
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %slice into %o[%i, 0] [1, %dim1] [1, 1]
+ : tensor<1x?xf32> into tensor<?x?xf32>
+ }
+ }
+
+ return %result : tensor<?x?xf32>
+}
+// CHECK-LABEL: @forall_iter_to_init_arg
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[RESULT:.*]] = scf.forall
+// CHECK-SAME: shared_outs(%[[OUTS:.*]] = %[[ARG1]]) -> (tensor<?x?xf32>) {
+// CHECK-NEXT: %{{.*}} = tensor.dim %[[ARG1]], %{{.*}} : tensor<?x?xf32>
More information about the Mlir-commits
mailing list