[Mlir-commits] [mlir] 0cf8447 - [MLIR][SCF] Fold dim ops of iter_args to respective init_args (#109973)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 26 12:07:12 PDT 2024
Author: Prashant Kumar
Date: 2024-09-27T00:37:08+05:30
New Revision: 0cf844759add057f76ca72a611e692eea191c7b7
URL: https://github.com/llvm/llvm-project/commit/0cf844759add057f76ca72a611e692eea191c7b7
DIFF: https://github.com/llvm/llvm-project/commit/0cf844759add057f76ca72a611e692eea191c7b7.diff
LOG: [MLIR][SCF] Fold dim ops of iter_args to respective init_args (#109973)
Fold dim ops of iter_args to dim ops of their respective init args.
E.g.:
```
%0 = ... : tensor<?x?xf32>
scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
%1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
...
}
```
is folded to:
```
%0 = ... : tensor<?x?xf32>
scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
%1 = tensor.dim %0, %c0 : tensor<?x?xf32>
...
}
```
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 0cb5931ce6bf9b..fb2921fec9f79d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -103,6 +103,44 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
return success();
}
};
+
+/// Fold dim ops of loop iter_args to dim ops of their respective init args.
+/// E.g.:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+///   ...
+/// }
+/// ```
+///
+/// is folded to:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
+///   ...
+/// }
+/// ```
+///
+/// The rewrite is only valid when the iter_arg cannot change shape between
+/// iterations, so it is applied only when
+///   - the block argument is an iter_arg of a LoopLikeOpInterface op (i.e. it
+///     has a tied init value), and
+///   - the loop yields no value back into that iter_arg (as for the
+///     shared_outs of `scf.forall`, which are written through the
+///     `scf.forall.in_parallel` terminator instead), or the yielded value is
+///     the iter_arg itself.
+/// A general `scf.for` may yield a tensor of a different dynamic size, in
+/// which case the fold would be wrong:
+///
+/// ```
+/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?xf32>) {
+///   %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
+///   %2 = arith.addi %1, %c1 : index
+///   %3 = tensor.empty(%2) : tensor<?xf32>
+///   scf.yield %3 : tensor<?xf32>
+/// }
+/// ```
+struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
+  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
+                                PatternRewriter &rewriter) const final {
+    auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
+    if (!blockArg)
+      return failure();
+    auto loopLikeOp =
+        dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
+    if (!loopLikeOp)
+      return failure();
+
+    // Not every block argument of a loop is an iter_arg (e.g. induction
+    // variables, or the "after"-region arguments of scf.while). Those have no
+    // tied init value and getTiedLoopInit returns nullptr for them.
+    OpOperand *initOperand = loopLikeOp.getTiedLoopInit(blockArg);
+    if (!initOperand)
+      return failure();
+
+    // If some other value is yielded back into the iter_arg, its shape may
+    // differ from the init value's shape on later iterations.
+    OpOperand *yieldedOperand = loopLikeOp.getTiedLoopYieldedValue(blockArg);
+    if (yieldedOperand && yieldedOperand->get() != blockArg)
+      return failure();
+
+    Value initArg = initOperand->get();
+    rewriter.modifyOpInPlace(
+        dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
+    return success();
+  }
+};
} // namespace
//===----------------------------------------------------------------------===//
@@ -127,8 +165,8 @@ struct ResolveShapedTypeResultDimsPass final
void memref::populateResolveRankedShapedTypeResultDimsPatterns(
RewritePatternSet &patterns) {
+ // Registers the DimOfReifyRankedShapedTypeOpInterface patterns for memref.dim
+ // and tensor.dim, plus IterArgsToInitArgs for tensor.dim of loop iter_args.
patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
- DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
- patterns.getContext());
+ DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>,
+ IterArgsToInitArgs>(patterns.getContext());
}
void memref::populateResolveShapedTypeResultDimsPatterns(
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index 85a4853972457c..ef8b80f6b5c22a 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -71,3 +71,31 @@ func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
%1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
return %1 : index
}
+
+// -----
+
+// A tensor.dim of an scf.forall shared_outs block argument is rewritten to
+// query the corresponding init value instead.
+
+// CHECK-LABEL: @iter_to_init_arg_loop_like
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[RESULT:.*]] = scf.forall
+// CHECK-SAME: shared_outs(%[[OUTS:.*]] = %[[ARG1]]) -> (tensor<?x?xf32>) {
+// CHECK-NEXT: %{{.*}} = tensor.dim %[[ARG1]], %{{.*}} : tensor<?x?xf32>
+func.func @iter_to_init_arg_loop_like(
+    %src : tensor<?x?xf32>, %init: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %zero = arith.constant 0 : index
+  %one = arith.constant 1 : index
+  %rows = tensor.dim %src, %zero : tensor<?x?xf32>
+  %result = scf.forall (%row) = (%zero) to (%rows) step (%one)
+      shared_outs(%out = %init) -> (tensor<?x?xf32>) {
+    %cols = tensor.dim %out, %one : tensor<?x?xf32>
+    %slice = tensor.extract_slice %init[%row, 0] [1, %cols] [1, 1]
+        : tensor<?x?xf32> to tensor<1x?xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %slice into %out[%row, 0] [1, %cols] [1, 1]
+          : tensor<1x?xf32> into tensor<?x?xf32>
+    }
+  }
+  return %result : tensor<?x?xf32>
+}
More information about the Mlir-commits
mailing list