[Mlir-commits] [mlir] 08dbed8 - [mlir][linalg] Canonicalize dim ops of tiled_loop block args

Matthias Springer llvmlistbot at llvm.org
Wed Aug 18 19:27:26 PDT 2021


Author: Matthias Springer
Date: 2021-08-19T11:24:33+09:00
New Revision: 08dbed8a5725087a7be71c293e4d86ca0c4c0510

URL: https://github.com/llvm/llvm-project/commit/08dbed8a5725087a7be71c293e4d86ca0c4c0510
DIFF: https://github.com/llvm/llvm-project/commit/08dbed8a5725087a7be71c293e4d86ca0c4c0510.diff

LOG: [mlir][linalg] Canonicalize dim ops of tiled_loop block args

E.g.:
```
%y = ... : tensor<...>
linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
  tensor.dim %x, %c0 : tensor<...>
}
```

is rewritten to:
```
%y = ... : tensor<...>
linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
  tensor.dim %y, %c0 : tensor<...>
}
```

Differential Revision: https://reviews.llvm.org/D108272

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a58d91e98b71..6b65a9ecd9e5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2064,6 +2064,57 @@ struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
   }
 };
 
+/// Fold dim(x), where `x` is an input/output block argument of a TiledLoopOp,
+/// to dim(y), where `y` is the loop operand that `x` is bound to.
+///
+/// E.g.:
+/// %y = ... : tensor<...>
+/// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
+///   tensor.dim %x, %c0 : tensor<...>
+/// }
+///
+/// is folded to:
+/// %y = ... : tensor<...>
+/// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
+///   tensor.dim %y, %c0 : tensor<...>
+/// }
+template <typename OpTy>
+struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy dimOp,
+                                PatternRewriter &rewriter) const final {
+    auto src = dimOp.source().template dyn_cast<BlockArgument>();
+    if (!src)
+      return failure();
+    auto loopOp = dyn_cast<TiledLoopOp>(
+        src.getOwner()->getParent()->getParentOp());
+    if (!loopOp)
+      return failure();
+
+    auto inputArgs = loopOp.getRegionInputArgs();
+    auto it1 = llvm::find(inputArgs, src);
+    if (it1 != inputArgs.end()) {
+      rewriter.updateRootInPlace(dimOp, [&] {
+        dimOp.sourceMutable().assign(loopOp.inputs()[it1 - inputArgs.begin()]);
+      });
+      return success();
+    }
+
+    auto outputArgs = loopOp.getRegionOutputArgs();
+    auto it2 = llvm::find(outputArgs, src);
+    if (it2 != outputArgs.end()) {
+      rewriter.updateRootInPlace(dimOp, [&] {
+        dimOp.sourceMutable().assign(
+            loopOp.outputs()[it2 - outputArgs.begin()]);
+      });
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 // Folds away TiledLoopOp output tensors when the following conditions are met:
 // * result of `linalg.tiled_loop` has no uses
 // * output tensor is the argument of `linalg.yield`
@@ -2167,7 +2218,9 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
 
 void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
-  results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder>(context);
+  results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder,
+                 DimOfTiledLoopInsOutsFolder<tensor::DimOp>,
+                 DimOfTiledLoopInsOutsFolder<memref::DimOp>>(context);
 }
 
 LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 34dacd58a51e..41a4bfe9c980 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -919,3 +919,30 @@ func @dim_of_pad_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
   %r = tensor.dim %0, %c0 : tensor<?x?xf32>
   return %r : index
 }
+
+// -----
+
+// CHECK-LABEL: func @dim_of_tiled_loop_input(
+//  CHECK-SAME:     %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
+//       CHECK:   %[[c0:.*]] = constant 0 : index
+//       CHECK:   linalg.tiled_loop
+//       CHECK:     %[[dim:.*]] = tensor.dim %[[arg1]], %[[c0]]
+//       CHECK:     index_cast %[[dim]]
+func @dim_of_tiled_loop_input(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
+      to (%d0, %d1) step (%c1, %c1)
+      ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
+      outs (%out1 = %arg2 : tensor<?x?xf32>) {
+    %inner_dim = tensor.dim %in1, %c0 : tensor<?x?xf32>
+    %cast1 = std.index_cast %inner_dim : index to i32
+    %cast2 = std.sitofp %cast1 : i32 to f32
+    %fill = linalg.fill(%cast2, %out1) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+    linalg.yield %fill : tensor<?x?xf32>
+  }
+  return %r : tensor<?x?xf32>
+}

diff  --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
index de7c0b7c4820..01dab480d89c 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -121,8 +121,8 @@ module {
 // TLOOP:    %[[AB_SUB:.*]] = linalg.matmul
 // TLOOP-SAME:  ins(%[[A_SUB]], %[[B_]] : {{.*}}) outs(%[[AB_INIT_SUB]]
 
-// TLOOP:    %[[DIM_B_1:.*]] = tensor.dim %[[B_]], %[[C1]] : [[TY]]
-// TLOOP:    %[[DIM_C_1:.*]] = tensor.dim %[[C_]], %[[C1]] : [[TY]]
+// TLOOP:    %[[DIM_B_1:.*]] = tensor.dim %[[B]], %[[C1]] : [[TY]]
+// TLOOP:    %[[DIM_C_1:.*]] = tensor.dim %[[C]], %[[C1]] : [[TY]]
 
 // TLOOP:    %[[ABC_SUB_:.*]] = linalg.tiled_loop (%[[IV1:.*]], %[[IV2:.*]]) = 
 // TLOOP-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_C_1]], %[[DIM_B_1]])
@@ -300,7 +300,7 @@ module {
 // TLOOP-SAME:      %[[C0_F32_:.*]] = %[[C0_F32]]
 // TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
 
-// TLOOP:    %[[DIM_A__1:.*]] = tensor.dim %[[A_]], %[[C1]] : [[TY]]
+// TLOOP:    %[[DIM_A__1:.*]] = tensor.dim %[[A]], %[[C1]] : [[TY]]
 // TLOOP:    %[[A_SUB:.*]] = tensor.extract_slice %[[A_]][%[[I]], 0]
 // TLOOP:    %[[B_SUB:.*]] = tensor.extract_slice %[[B_]][0, %[[J]]]
 // TLOOP:    %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]]
@@ -371,7 +371,7 @@ module {
 // TLOOP-SAME:      %[[C0_F32_:.*]] = %[[C0_F32]]
 // TLOOP-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: [[TY]]) {
 
-// TLOOP:    %[[DIM_A__1:.*]] = tensor.dim %[[A_]], %[[C1]] : [[TY]]
+// TLOOP:    %[[DIM_A__1:.*]] = tensor.dim %[[A]], %[[C1]] : [[TY]]
 // TLOOP:    %[[A_SUB:.*]] = tensor.extract_slice %[[A_]][%[[I]], 0]
 // TLOOP:    %[[B_SUB:.*]] = tensor.extract_slice %[[B_]][0, %[[J]]]
 // TLOOP:    %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]]


        


More information about the Mlir-commits mailing list