[Mlir-commits] [mlir] 0cf8447 - [MLIR][SCF] Fold dim ops of iter_args to respective init_args (#109973)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 26 12:07:12 PDT 2024


Author: Prashant Kumar
Date: 2024-09-27T00:37:08+05:30
New Revision: 0cf844759add057f76ca72a611e692eea191c7b7

URL: https://github.com/llvm/llvm-project/commit/0cf844759add057f76ca72a611e692eea191c7b7
DIFF: https://github.com/llvm/llvm-project/commit/0cf844759add057f76ca72a611e692eea191c7b7.diff

LOG: [MLIR][SCF] Fold dim ops of iter_args to respective init_args (#109973)

Fold dim ops of iter_args to dim ops of their respective init args.
E.g.:

```
 %0 = ... : tensor<?x?xf32>
 scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
   ...
 }
```
 is folded to:

```

 %0 = ... : tensor<?x?xf32>
 scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
   ...
 }
```

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
    mlir/test/Dialect/MemRef/resolve-dim-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 0cb5931ce6bf9b..fb2921fec9f79d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -103,6 +103,44 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
     return success();
   }
 };
+
+/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+///   ...
+/// }
+/// ```
+///
+/// is folded to:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
+///   ...
+/// }
+/// ```
+struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
+  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
+                                PatternRewriter &rewriter) const final {
+    auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
+    if (!blockArg)
+      return failure();
+    auto loopLikeOp =
+        dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
+    if (!loopLikeOp)
+      return failure();
+    Value initArg = loopLikeOp.getTiedLoopInit(blockArg)->get();
+    rewriter.modifyOpInPlace(
+        dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
+    return success();
+  }
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -127,8 +165,8 @@ struct ResolveShapedTypeResultDimsPass final
 void memref::populateResolveRankedShapedTypeResultDimsPatterns(
     RewritePatternSet &patterns) {
   patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
-               DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
-      patterns.getContext());
+               DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>,
+               IterArgsToInitArgs>(patterns.getContext());
 }
 
 void memref::populateResolveShapedTypeResultDimsPatterns(

diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index 85a4853972457c..ef8b80f6b5c22a 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -71,3 +71,31 @@ func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
   %1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
   return %1 : index
 }
+
+// -----
+
+// CHECK-LABEL: @iter_to_init_arg_loop_like
+//  CHECK-SAME:   (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+//       CHECK:    %[[RESULT:.*]] = scf.forall
+//  CHECK-SAME:                       shared_outs(%[[OUTS:.*]] = %[[ARG1]]) -> (tensor<?x?xf32>) {
+//  CHECK-NEXT:       %{{.*}} = tensor.dim %[[ARG1]], %{{.*}} : tensor<?x?xf32>
+func.func @iter_to_init_arg_loop_like(
+  %arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+
+  %result = scf.forall (%i) = (%c0) to (%dim0)
+      step (%c1) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
+
+    %dim1 = tensor.dim %o, %c1 : tensor<?x?xf32>
+    %slice = tensor.extract_slice %arg1[%i, 0] [1, %dim1] [1, 1]
+      : tensor<?x?xf32> to tensor<1x?xf32>
+
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %slice into %o[%i, 0] [1, %dim1] [1, 1]
+        : tensor<1x?xf32> into tensor<?x?xf32>
+    }
+  }
+  return %result : tensor<?x?xf32>
+}


        


More information about the Mlir-commits mailing list