[Mlir-commits] [mlir] e2c8fcb - [mlir][linalg] Fold dim(linalg.tiled_loop) to dim(output_arg)

Matthias Springer llvmlistbot at llvm.org
Wed Sep 8 21:44:44 PDT 2021


Author: Matthias Springer
Date: 2021-09-09T13:37:28+09:00
New Revision: e2c8fcb9d0bd33fda481f7e27cf0d6ebdde2b5b0

URL: https://github.com/llvm/llvm-project/commit/e2c8fcb9d0bd33fda481f7e27cf0d6ebdde2b5b0
DIFF: https://github.com/llvm/llvm-project/commit/e2c8fcb9d0bd33fda481f7e27cf0d6ebdde2b5b0.diff

LOG: [mlir][linalg] Fold dim(linalg.tiled_loop) to dim(output_arg)

Fold dim ops of linalg.tiled_loop results to dim ops of the corresponding output operands (the initial `outs` values) if the loop is shape preserving.

Differential Revision: https://reviews.llvm.org/D109431

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a1c0c996e332c..688f241c36f5a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2286,6 +2286,44 @@ struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
   }
 };
 
+} // namespace
+
+/// A simple, conservative analysis to determine if the loop is shape
+/// preserving, i.e. the type of the arg-th yielded value is the same as the
+/// type of the corresponding region output argument of the loop. The yielded
+/// value is traced back through `tensor.insert_slice` destinations and nested
+/// shape-preserving tiled loops; only these simple cases are handled.
+static bool isShapePreserving(TiledLoopOp loopOp, int64_t arg) {
+  auto yieldOp = cast<YieldOp>(loopOp.getLoopBody().front().getTerminator());
+  // No yielded values (no outputs or a memref-based loop): trivially true.
+  if (yieldOp.values().empty())
+    return true;
+  assert(arg >= 0 && arg < static_cast<int64_t>(yieldOp.values().size()) &&
+         "arg is out of bounds");
+  Value regionOutputArg = loopOp.getRegionOutputArgs()[arg];
+  Value value = yieldOp.values()[arg];
+  while (value) {
+    if (value == regionOutputArg)
+      return true;
+    OpResult opResult = value.dyn_cast<OpResult>();
+    if (!opResult)
+      return false;
+    unsigned resultNumber = opResult.getResultNumber();
+    value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
+                .Case<tensor::InsertSliceOp>(
+                    [](tensor::InsertSliceOp op) { return op.dest(); })
+                .Case<TiledLoopOp>([&](TiledLoopOp nestedLoop) {
+                  return isShapePreserving(nestedLoop, resultNumber)
+                             ? nestedLoop.outputs()[resultNumber]
+                             : Value();
+                })
+                .Default([](Operation *) { return Value(); });
+  }
+  return false;
+}
+
+namespace {
+
 /// Fold dim(x) where `x` is an input/output argument of a TiledLoopOp block
 /// to dim(y) where `y` is the initial input/output value of the argument.
 ///
@@ -2307,40 +2345,6 @@ template <typename OpTy>
 struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
 
-  /// A simple, conservative analysis to determine if the loop is shape
-  /// conserving. I.e., the type of the arg-th yielded value is the same as the
-  /// type of the corresponding basic block argument of the loop.
-  /// Note: This function handles only simple cases. Expand as needed.
-  static bool isShapePreserving(TiledLoopOp loopOp, int64_t arg) {
-    auto yieldOp = cast<YieldOp>(loopOp.getLoopBody().front().getTerminator());
-    if (yieldOp.values().empty())
-      // Tiled loop either has no outputs or is a "memref-based version". In
-      // either case, the loop is shape conserving.
-      return true;
-    assert(arg < static_cast<int64_t>(yieldOp.values().size()) &&
-           "arg is out of bounds");
-    Value value = yieldOp.values()[arg];
-    while (value) {
-      if (value == loopOp.getRegionOutputArgs()[arg])
-        return true;
-      OpResult opResult = value.dyn_cast<OpResult>();
-      if (!opResult)
-        return false;
-
-      using tensor::InsertSliceOp;
-      value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
-                  .template Case<InsertSliceOp>(
-                      [&](InsertSliceOp op) { return op.dest(); })
-                  .template Case<TiledLoopOp>([&](TiledLoopOp loopOp) {
-                    return isShapePreserving(loopOp, opResult.getResultNumber())
-                               ? loopOp.outputs()[opResult.getResultNumber()]
-                               : Value();
-                  })
-                  .Default([&](auto op) { return Value(); });
-    }
-    return false;
-  }
-
   LogicalResult matchAndRewrite(OpTy dimOp,
                                 PatternRewriter &rewriter) const final {
     auto src = dimOp.source().template dyn_cast<BlockArgument>();
@@ -2380,6 +2384,45 @@ struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
   }
 };
 
+/// Rewrites dim(r), where `r` is a result of a TiledLoopOp, into dim(y),
+/// where `y` is the loop's initial value for the corresponding output.
+/// The rewrite only fires when `isShapePreserving` proves that the value
+/// yielded for that output keeps its runtime shape across loop iterations.
+/// NOTE(review): result #i is assumed to correspond to `outputs()[i]`;
+/// confirm this holds for loops that mix memref and tensor outputs.
+///
+/// Example:
+/// %y = ... : tensor<...>
+/// %r = linalg.tiled_loop ... outs(%i = %y : tensor<...>) {
+///   ...
+/// }
+/// %0 = tensor.dim %r, %c0 : tensor<...>
+/// becomes:
+/// %y = ... : tensor<...>
+/// %r = linalg.tiled_loop ... outs(%i = %y : tensor<...>) {
+///   ...
+/// }
+/// %0 = tensor.dim %y, %c0 : tensor<...>
+template <typename OpTy>
+struct DimOfTiledLoopResultFolder : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy dimOp,
+                                PatternRewriter &rewriter) const final {
+    Value source = dimOp.source();
+    auto tiledLoop = source.getDefiningOp<TiledLoopOp>();
+    if (!tiledLoop)
+      return failure();
+    unsigned idx = source.cast<OpResult>().getResultNumber();
+    if (!isShapePreserving(tiledLoop, idx))
+      return failure();
+    Value initValue = tiledLoop.outputs()[idx];
+    rewriter.updateRootInPlace(
+        dimOp, [&]() { dimOp.sourceMutable().assign(initValue); });
+    return success();
+  }
+};
+
 // Folds away TiledLoopOp output tensors when the following conditions are met:
 // * result of `linalg.tiled_loop` has no uses
 // * output tensor is the argument of `linalg.yield`
@@ -2485,7 +2528,9 @@ void TiledLoopOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
   results.insert<TiledLoopInputsFolder, TiledLoopResultsFolder,
                  DimOfTiledLoopInsOutsFolder<tensor::DimOp>,
-                 DimOfTiledLoopInsOutsFolder<memref::DimOp>>(context);
+                 DimOfTiledLoopInsOutsFolder<memref::DimOp>,
+                 DimOfTiledLoopResultFolder<tensor::DimOp>,
+                 DimOfTiledLoopResultFolder<memref::DimOp>>(context);
 }
 
 LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 14c7a1ac639df..db915f10e7dde 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -956,3 +956,51 @@ func @dim_of_tiled_loop_input(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %a
   }
   return %r : tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @dim_of_tiled_loop_result(
+//  CHECK-SAME:     %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
+//       CHECK:   %[[c0:.*]] = constant 0 : index
+//       CHECK:   tensor.dim %[[arg2]], %[[c0]]
+func @dim_of_tiled_loop_result(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %s: index)
+    -> index {  // Folds: the yielded %1 chains back to %out1 via insert_slice's dest.
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
+      to (%d0, %d1) step (%c1, %c1)
+      ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
+      outs (%out1 = %arg2 : tensor<?x?xf32>) {
+    %1 = tensor.insert_slice %arg0 into %out1 [0, 0] [%s, %s] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+    linalg.yield %1 : tensor<?x?xf32>
+  }
+  %r2 = tensor.dim %r, %c0 : tensor<?x?xf32>
+  return %r2 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @dim_of_tiled_loop_result_no_canonicalize(
+//  CHECK-SAME:     %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
+//       CHECK:   %[[c0:.*]] = constant 0 : index
+//       CHECK:   %[[r:.*]] = linalg.tiled_loop
+//       CHECK:   tensor.dim %[[r]], %[[c0]]
+func @dim_of_tiled_loop_result_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %s: index)
+    -> index {  // No fold: insert_slice writes into %arg1, not into %out1.
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
+      to (%d0, %d1) step (%c1, %c1)
+      ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
+      outs (%out1 = %arg2 : tensor<?x?xf32>) {
+    %1 = tensor.insert_slice %arg0 into %arg1 [0, 0] [%s, %s] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+    linalg.yield %1 : tensor<?x?xf32>
+  }
+  %r2 = tensor.dim %r, %c0 : tensor<?x?xf32>
+  return %r2 : index
+}
+


        


More information about the Mlir-commits mailing list