[Mlir-commits] [mlir] f7137da - [mlir][linalg] Fix dim(iter_arg) canonicalization

Matthias Springer llvmlistbot at llvm.org
Wed Sep 8 20:18:17 PDT 2021


Author: Matthias Springer
Date: 2021-09-09T12:13:05+09:00
New Revision: f7137da174a43dd5460b72a21fbcaac7b7b74d7a

URL: https://github.com/llvm/llvm-project/commit/f7137da174a43dd5460b72a21fbcaac7b7b74d7a
DIFF: https://github.com/llvm/llvm-project/commit/f7137da174a43dd5460b72a21fbcaac7b7b74d7a.diff

LOG: [mlir][linalg] Fix dim(iter_arg) canonicalization

Run a small analysis to determine whether the runtime type (i.e., the dynamic shape) of the iter_arg can change across loop iterations. Fold the dim op only if the runtime type stays the same. (This mirrors `DimOfIterArgFolder` in the SCF dialect.)

Differential Revision: https://reviews.llvm.org/D109299
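
For illustration, a minimal sketch of the fold on an *input* iter_arg (hypothetical
function and value names, not taken from the patch). Dims of input block arguments
fold unconditionally because inputs are read-only; with this change, dims of output
block arguments fold only when the analysis proves the loop shape preserving. After
canonicalization, the tensor.dim inside the loop body is replaced by a tensor.dim of
%A hoisted above the loop:

func @fold_dim_of_input(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %d0 = tensor.dim %A, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %A, %c1 : tensor<?x?xf32>
  %r = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%d0, %d1) step (%c1, %c1)
      ins (%in = %A : tensor<?x?xf32>)
      outs (%out = %B : tensor<?x?xf32>) {
    // %in always has the runtime shape of the loop operand %A, so this dim
    // can be rewritten to "tensor.dim %A, %c0" above the loop.
    %dim = tensor.dim %in, %c0 : tensor<?x?xf32>
    %cast = std.index_cast %dim : index to i32
    %val = std.sitofp %cast : i32 to f32
    %fill = linalg.fill(%val, %out) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
    linalg.yield %fill : tensor<?x?xf32>
  }
  return %r : tensor<?x?xf32>
}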

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e2b9436f4a10e..a1c0c996e332c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -31,6 +31,7 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
@@ -2299,10 +2300,47 @@ struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
 /// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
 ///   tensor.dim %y, %c0 : tensor<...>
 /// }
+///
+/// Note: Dim ops are folded only if it can be proven that the runtime type of
+/// the yielded value (in case of outputs) does not change with loop iterations.
 template <typename OpTy>
 struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
 
+  /// A simple, conservative analysis to determine if the loop is shape
+  /// preserving, i.e., whether the type of the `arg`-th yielded value is the
+  /// same as the type of the corresponding basic block argument of the loop.
+  /// Note: This function handles only simple cases. Expand as needed.
+  static bool isShapePreserving(TiledLoopOp loopOp, int64_t arg) {
+    auto yieldOp = cast<YieldOp>(loopOp.getLoopBody().front().getTerminator());
+    if (yieldOp.values().empty())
+      // The tiled loop either has no outputs or is a "memref-based version".
+      // In either case, the loop is shape preserving.
+      return true;
+    assert(arg < static_cast<int64_t>(yieldOp.values().size()) &&
+           "arg is out of bounds");
+    Value value = yieldOp.values()[arg];
+    while (value) {
+      if (value == loopOp.getRegionOutputArgs()[arg])
+        return true;
+      OpResult opResult = value.dyn_cast<OpResult>();
+      if (!opResult)
+        return false;
+
+      using tensor::InsertSliceOp;
+      value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
+                  .template Case<InsertSliceOp>(
+                      [&](InsertSliceOp op) { return op.dest(); })
+                  .template Case<TiledLoopOp>([&](TiledLoopOp loopOp) {
+                    return isShapePreserving(loopOp, opResult.getResultNumber())
+                               ? loopOp.outputs()[opResult.getResultNumber()]
+                               : Value();
+                  })
+                  .Default([&](auto op) { return Value(); });
+    }
+    return false;
+  }
+
   LogicalResult matchAndRewrite(OpTy dimOp,
                                 PatternRewriter &rewriter) const final {
     auto src = dimOp.source().template dyn_cast<BlockArgument>();
@@ -2312,6 +2350,12 @@ struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
         dyn_cast<TiledLoopOp>(src.getOwner()->getParent()->getParentOp());
     if (!loopOp)
       return failure();
+    unsigned numLoops = loopOp.getNumLoops();
+    unsigned numInputArgs = loopOp.getRegionInputArgs().size();
+    if (src.getArgNumber() >= numInputArgs + numLoops &&
+        !isShapePreserving(loopOp,
+                           src.getArgNumber() - numInputArgs - numLoops))
+      return failure();
 
     auto inputArgs = loopOp.getRegionInputArgs();
     auto it1 = llvm::find(inputArgs, src);

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 9b6be7920a999..14c7a1ac639df 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -904,6 +904,34 @@ func @rank_reducing_init_extract(%sz : index, %idx : index) -> tensor<2xf32> {
 
 // -----
 
+// CHECK-LABEL: func @dim_of_tiled_loop_input_no_canonicalize(
+//  CHECK-SAME:     %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
+//       CHECK:   %[[c0:.*]] = constant 0 : index
+//       CHECK:   linalg.tiled_loop {{.*}} outs (%[[o:.*]] =
+//       CHECK:     %[[dim:.*]] = tensor.dim %[[o]], %[[c0]]
+//       CHECK:     index_cast %[[dim]]
+func @dim_of_tiled_loop_input_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %s: index)
+    -> tensor<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
+      to (%d0, %d1) step (%c1, %c1)
+      ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
+      outs (%out1 = %arg2 : tensor<?x?xf32>) {
+    %inner_dim = tensor.dim %out1, %c0 : tensor<?x?xf32>
+    %cast1 = std.index_cast %inner_dim : index to i32
+    %cast2 = std.sitofp %cast1 : i32 to f32
+    %fill = linalg.fill(%cast2, %out1) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+    %slice = tensor.extract_slice %fill[0, 0][%s, %s][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+    linalg.yield %slice : tensor<?x?xf32>
+  }
+  return %r : tensor<?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @dim_of_tiled_loop_input(
 //  CHECK-SAME:     %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
 //       CHECK:   %[[c0:.*]] = constant 0 : index

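For contrast with the negative test in this patch, here is a hypothetical
shape-preserving variant (illustrative only, not part of the commit): the loop
yields a tensor.insert_slice whose destination is the output block argument, so
the new analysis should treat the loop as shape preserving, and the tensor.dim
of %out1 folds to a tensor.dim of %arg2 hoisted above the loop:

func @dim_of_tiled_loop_output_canonicalize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
                                            %arg2: tensor<?x?xf32>, %s: index)
    -> tensor<?x?xf32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
      to (%d0, %d1) step (%c1, %c1)
      ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
      outs (%out1 = %arg2 : tensor<?x?xf32>) {
    // The analysis walks from the yielded value through the insert_slice to
    // its destination %out1, so the loop is considered shape preserving and
    // this dim folds to "tensor.dim %arg2, %c0" above the loop.
    %inner_dim = tensor.dim %out1, %c0 : tensor<?x?xf32>
    %cast1 = std.index_cast %inner_dim : index to i32
    %cast2 = std.sitofp %cast1 : i32 to f32
    %fill = linalg.fill(%cast2, %out1) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
    %slice = tensor.extract_slice %fill[0, 0][%s, %s][1, 1]
        : tensor<?x?xf32> to tensor<?x?xf32>
    %updated = tensor.insert_slice %slice into %out1[0, 0][%s, %s][1, 1]
        : tensor<?x?xf32> into tensor<?x?xf32>
    linalg.yield %updated : tensor<?x?xf32>
  }
  return %r : tensor<?x?xf32>
}

Note that the analysis is intentionally conservative: any yielded value it
cannot trace back to the output block argument (such as the plain extract_slice
result in the negative test) blocks the fold.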

        

