[Mlir-commits] [mlir] [MLIR][SCF] Fold dim ops of iter_args to respective init_args (PR #109973)
Prashant Kumar
llvmlistbot at llvm.org
Thu Sep 26 02:54:40 PDT 2024
https://github.com/pashu123 updated https://github.com/llvm/llvm-project/pull/109973
From f73b26ad8687287722e687fd884acf912f3a3f16 Mon Sep 17 00:00:00 2001
From: Prashant Kumar <pk5561 at gmail.com>
Date: Wed, 25 Sep 2024 18:05:31 +0530
Subject: [PATCH 1/3] [MLIR][SCF] Remove trailing whitespace.
---
mlir/test/Dialect/SCF/canonicalize.mlir | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c68369a8e4fce7..2eab04363d38f3 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1787,7 +1787,7 @@ module {
}
// CHECK-LABEL: @fold_iter_args_not_being_modified_within_scfforall
// CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
-// CHECK: %[[RESULT:.*]] = scf.forall
+// CHECK: %[[RESULT:.*]] = scf.forall
// CHECK-SAME: shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]]
@@ -1830,7 +1830,7 @@ module {
}
// CHECK-LABEL: @fold_iter_args_with_no_use_of_result_scfforall
// CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>, %[[ARG3:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-// CHECK: %[[RESULT:.*]] = scf.forall
+// CHECK: %[[RESULT:.*]] = scf.forall
// CHECK-SAME: shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]]
@@ -1854,7 +1854,7 @@ func.func @index_switch_fold() -> (f32, f32) {
%y = arith.constant 42.0 : f32
scf.yield %y : f32
}
-
+
%switch_cst_2 = arith.constant 2: index
%1 = scf.index_switch %switch_cst_2 -> f32
case 0 {
@@ -1865,7 +1865,7 @@ func.func @index_switch_fold() -> (f32, f32) {
%y = arith.constant 42.0 : f32
scf.yield %y : f32
}
-
+
return %0, %1 : f32, f32
}
@@ -1891,3 +1891,5 @@ func.func @index_switch_fold_no_res() {
// CHECK-LABEL: func.func @index_switch_fold_no_res()
// CHECK-NEXT: "test.op"() : () -> ()
+
+// -----
From 25480550ebad8b29e2f38637ce22a7e2c4ddb297 Mon Sep 17 00:00:00 2001
From: Prashant Kumar <pk5561 at gmail.com>
Date: Wed, 25 Sep 2024 18:07:10 +0530
Subject: [PATCH 2/3] [MLIR][SCF] Fold dim ops of iter_args to respective
init_args
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 46 +++++++++++++++++++++++--
mlir/test/Dialect/SCF/canonicalize.mlir | 27 +++++++++++++++
2 files changed, 70 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 6d47ff3890977a..9705a9cd1741b9 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1478,6 +1478,45 @@ struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
}
};
+/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
+/// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+/// ...
+/// }
+/// ```
+///
+/// is folded to:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
+/// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
+/// ...
+/// }
+/// ```
+struct DimOfForallIterArg : public OpRewritePattern<tensor::DimOp> {
+ using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::DimOp dimOp,
+ PatternRewriter &rewriter) const final {
+ auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
+ if (!blockArg)
+ return failure();
+ auto forallOp =
+ dyn_cast<ForallOp>(blockArg.getParentBlock()->getParentOp());
+ if (!forallOp)
+ return failure();
+ Value initArg = forallOp.getTiedLoopInit(blockArg)->get();
+ rewriter.modifyOpInPlace(
+ dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
+
+ return success();
+ }
+};
+
class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
public:
using OpRewritePattern<ForallOp>::OpRewritePattern;
@@ -1851,9 +1890,10 @@ struct FoldTensorCastOfOutputIntoForallOp
void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
- ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
- ForallOpSingleOrZeroIterationDimsFolder>(context);
+ results.add<DimOfForallOp, DimOfForallIterArg,
+ FoldTensorCastOfOutputIntoForallOp, ForallOpControlOperandsFolder,
+ ForallOpIterArgsFolder, ForallOpSingleOrZeroIterationDimsFolder>(
+ context);
}
/// Given the region at `index`, or the parent operation if `index` is None,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 2eab04363d38f3..e203a517932237 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1893,3 +1893,30 @@ func.func @index_switch_fold_no_res() {
// CHECK-NEXT: "test.op"() : () -> ()
// -----
+
+func.func @forall_iter_to_init_arg(
+ %arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+
+ %result = scf.forall (%i) = (%c0) to (%dim0)
+ step (%c1) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
+
+ %dim1 = tensor.dim %o, %c1 : tensor<?x?xf32>
+ %slice = tensor.extract_slice %arg1[%i, 0] [1, %dim1] [1, 1]
+ : tensor<?x?xf32> to tensor<1x?xf32>
+
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %slice into %o[%i, 0] [1, %dim1] [1, 1]
+ : tensor<1x?xf32> into tensor<?x?xf32>
+ }
+ }
+
+ return %result : tensor<?x?xf32>
+}
+// CHECK-LABEL: @forall_iter_to_init_arg
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[RESULT:.*]] = scf.forall
+// CHECK-SAME: shared_outs(%[[OUTS:.*]] = %[[ARG1]]) -> (tensor<?x?xf32>) {
+// CHECK-NEXT: %{{.*}} = tensor.dim %[[ARG1]], %{{.*}} : tensor<?x?xf32>
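For reference, with patch 2 applied (the follow-up patch below relocates the pattern), canonicalization is expected to leave the loop body of the new test roughly in the following form. This is a hand-written sketch matching the CHECK lines above, with illustrative SSA names, not verbatim pass output:

  %result = scf.forall (%i) = (%c0) to (%dim0) step (%c1)
      shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
    // The dim now reads from the init operand %arg1 instead of the iter_arg %o.
    %dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
    %slice = tensor.extract_slice %arg1[%i, 0] [1, %dim1] [1, 1]
        : tensor<?x?xf32> to tensor<1x?xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %slice into %o[%i, 0] [1, %dim1] [1, 1]
          : tensor<1x?xf32> into tensor<?x?xf32>
    }
  }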
From a6ca06fc6aec6ea8df32e3b5d247cfb1740e1968 Mon Sep 17 00:00:00 2001
From: Prashant Kumar <pk5561 at gmail.com>
Date: Thu, 26 Sep 2024 15:24:16 +0530
Subject: [PATCH 3/3] Move pattern to resolveShapedTypeResultDims
---
.../ResolveShapedTypeResultDims.cpp | 44 ++++++++++++++++++-
mlir/lib/Dialect/SCF/IR/SCF.cpp | 39 ----------------
mlir/test/Dialect/MemRef/resolve-dim-ops.mlir | 28 ++++++++++++
mlir/test/Dialect/SCF/canonicalize.mlir | 29 ------------
4 files changed, 70 insertions(+), 70 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 0cb5931ce6bf9b..79466c26b5cab6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -103,6 +103,45 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
return success();
}
};
+
+/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
+/// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+/// ...
+/// }
+/// ```
+///
+/// is folded to:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
+/// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
+/// ...
+/// }
+/// ```
+template <typename OpTy>
+struct IterArgsToInitArgs : public OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy dimOp,
+ PatternRewriter &rewriter) const final {
+ auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
+ if (!blockArg)
+ return failure();
+ auto loopLikeOp =
+ dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
+ if (!loopLikeOp)
+ return failure();
+ Value initArg = loopLikeOp.getTiedLoopInit(blockArg)->get();
+ rewriter.modifyOpInPlace(
+ dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
+ return success();
+ }
+};
} // namespace
//===----------------------------------------------------------------------===//
@@ -127,8 +166,9 @@ struct ResolveShapedTypeResultDimsPass final
void memref::populateResolveRankedShapedTypeResultDimsPatterns(
RewritePatternSet &patterns) {
patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
- DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
- patterns.getContext());
+ DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>,
+ IterArgsToInitArgs<memref::DimOp>,
+ IterArgsToInitArgs<tensor::DimOp>>(patterns.getContext());
}
void memref::populateResolveShapedTypeResultDimsPatterns(
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9705a9cd1741b9..2041ce88152f33 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1478,45 +1478,6 @@ struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
}
};
-/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
-///
-/// ```
-/// %0 = ... : tensor<?x?xf32>
-/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
-/// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-/// ...
-/// }
-/// ```
-///
-/// is folded to:
-///
-/// ```
-/// %0 = ... : tensor<?x?xf32>
-/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
-/// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
-/// ...
-/// }
-/// ```
-struct DimOfForallIterArg : public OpRewritePattern<tensor::DimOp> {
- using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tensor::DimOp dimOp,
- PatternRewriter &rewriter) const final {
- auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
- if (!blockArg)
- return failure();
- auto forallOp =
- dyn_cast<ForallOp>(blockArg.getParentBlock()->getParentOp());
- if (!forallOp)
- return failure();
- Value initArg = forallOp.getTiedLoopInit(blockArg)->get();
- rewriter.modifyOpInPlace(
- dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
-
- return success();
- }
-};
-
class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
public:
using OpRewritePattern<ForallOp>::OpRewritePattern;
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index 85a4853972457c..ef8b80f6b5c22a 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -71,3 +71,31 @@ func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
%1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
return %1 : index
}
+
+// -----
+
+// CHECK-LABEL: @iter_to_init_arg_loop_like
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[RESULT:.*]] = scf.forall
+// CHECK-SAME: shared_outs(%[[OUTS:.*]] = %[[ARG1]]) -> (tensor<?x?xf32>) {
+// CHECK-NEXT: %{{.*}} = tensor.dim %[[ARG1]], %{{.*}} : tensor<?x?xf32>
+func.func @iter_to_init_arg_loop_like(
+ %arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+
+ %result = scf.forall (%i) = (%c0) to (%dim0)
+ step (%c1) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
+
+ %dim1 = tensor.dim %o, %c1 : tensor<?x?xf32>
+ %slice = tensor.extract_slice %arg1[%i, 0] [1, %dim1] [1, 1]
+ : tensor<?x?xf32> to tensor<1x?xf32>
+
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %slice into %o[%i, 0] [1, %dim1] [1, 1]
+ : tensor<1x?xf32> into tensor<?x?xf32>
+ }
+ }
+ return %result : tensor<?x?xf32>
+}
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index e203a517932237..284e141d87bcc3 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1891,32 +1891,3 @@ func.func @index_switch_fold_no_res() {
// CHECK-LABEL: func.func @index_switch_fold_no_res()
// CHECK-NEXT: "test.op"() : () -> ()
-
-// -----
-
-func.func @forall_iter_to_init_arg(
- %arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-
- %result = scf.forall (%i) = (%c0) to (%dim0)
- step (%c1) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
-
- %dim1 = tensor.dim %o, %c1 : tensor<?x?xf32>
- %slice = tensor.extract_slice %arg1[%i, 0] [1, %dim1] [1, 1]
- : tensor<?x?xf32> to tensor<1x?xf32>
-
- scf.forall.in_parallel {
- tensor.parallel_insert_slice %slice into %o[%i, 0] [1, %dim1] [1, 1]
- : tensor<1x?xf32> into tensor<?x?xf32>
- }
- }
-
- return %result : tensor<?x?xf32>
-}
-// CHECK-LABEL: @forall_iter_to_init_arg
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK: %[[RESULT:.*]] = scf.forall
-// CHECK-SAME: shared_outs(%[[OUTS:.*]] = %[[ARG1]]) -> (tensor<?x?xf32>) {
-// CHECK-NEXT: %{{.*}} = tensor.dim %[[ARG1]], %{{.*}} : tensor<?x?xf32>
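For context, a minimal sketch (not part of the patch) of how the relocated patterns can be exercised programmatically through the populate hook touched above; the helper name and surrounding setup are illustrative only. In-tree, these patterns back the ResolveShapedTypeResultDims passes tested by mlir/test/Dialect/MemRef/resolve-dim-ops.mlir.

  #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  /// Illustrative helper (not part of the patch): populates the resolve-dim
  /// patterns, which after this change also include IterArgsToInitArgs for
  /// both memref.dim and tensor.dim, and drives them to a fixed point.
  static mlir::LogicalResult resolveDimsOfIterArgs(mlir::Operation *op) {
    mlir::RewritePatternSet patterns(op->getContext());
    mlir::memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
    // dim ops whose source is a loop iter_arg are rewritten to read from the
    // tied init_arg, which other patterns in the set may resolve further.
    return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
  }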