[Mlir-commits] [mlir] f7137da - [mlir][linalg] Fix dim(iter_arg) canonicalization
Matthias Springer
llvmlistbot at llvm.org
Wed Sep 8 20:18:17 PDT 2021
Author: Matthias Springer
Date: 2021-09-09T12:13:05+09:00
New Revision: f7137da174a43dd5460b72a21fbcaac7b7b74d7a
URL: https://github.com/llvm/llvm-project/commit/f7137da174a43dd5460b72a21fbcaac7b7b74d7a
DIFF: https://github.com/llvm/llvm-project/commit/f7137da174a43dd5460b72a21fbcaac7b7b74d7a.diff
LOG: [mlir][linalg] Fix dim(iter_arg) canonicalization
Run a small analysis to determine whether the runtime type of the iter_arg changes across loop iterations. Fold the dim op only if the runtime type provably stays the same. (Same approach as `DimOfIterArgFolder` in SCF.)
Differential Revision: https://reviews.llvm.org/D109299
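
For illustration, a minimal sketch of the intended behavior (value names such as %arg2, %out, and %c0 are hypothetical and the loop bounds are elided):

    %r = linalg.tiled_loop ... outs (%out = %arg2 : tensor<?x?xf32>) {
      // Folds to "tensor.dim %arg2, %c0" because the loop yields %out
      // unchanged, so the runtime shape of the iter_arg cannot change.
      %d = tensor.dim %out, %c0 : tensor<?x?xf32>
      ...
      linalg.yield %out : tensor<?x?xf32>
    }

If the yielded value may have a different runtime shape, e.g. a tensor.extract_slice of %out with dynamic sizes, the dim op is left untouched (see the new @dim_of_tiled_loop_input_no_canonicalize test below).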
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e2b9436f4a10e..a1c0c996e332c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -31,6 +31,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
@@ -2299,10 +2300,47 @@ struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
/// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
/// tensor.dim %y, %c0 : tensor<...>
/// }
+///
+/// Note: Dim ops are folded only if it can be proven that the runtime type of
+/// the yielded value (in case of outputs) does not change with loop iterations.
template <typename OpTy>
struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
+ /// A simple, conservative analysis to determine if the loop is shape
+ /// conserving. I.e., the type of the arg-th yielded value is the same as the
+ /// type of the corresponding basic block argument of the loop.
+ /// Note: This function handles only simple cases. Expand as needed.
+ static bool isShapePreserving(TiledLoopOp loopOp, int64_t arg) {
+ auto yieldOp = cast<YieldOp>(loopOp.getLoopBody().front().getTerminator());
+ if (yieldOp.values().empty())
+ // Tiled loop either has no outputs or is a "memref-based version". In
+ // either case, the loop is shape conserving.
+ return true;
+ assert(arg < static_cast<int64_t>(yieldOp.values().size()) &&
+ "arg is out of bounds");
+ Value value = yieldOp.values()[arg];
+ while (value) {
+ if (value == loopOp.getRegionOutputArgs()[arg])
+ return true;
+ OpResult opResult = value.dyn_cast<OpResult>();
+ if (!opResult)
+ return false;
+
+ using tensor::InsertSliceOp;
+ value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
+ .template Case<InsertSliceOp>(
+ [&](InsertSliceOp op) { return op.dest(); })
+ .template Case<TiledLoopOp>([&](TiledLoopOp loopOp) {
+ return isShapePreserving(loopOp, opResult.getResultNumber())
+ ? loopOp.outputs()[opResult.getResultNumber()]
+ : Value();
+ })
+ .Default([&](auto op) { return Value(); });
+ }
+ return false;
+ }
+
LogicalResult matchAndRewrite(OpTy dimOp,
PatternRewriter &rewriter) const final {
auto src = dimOp.source().template dyn_cast<BlockArgument>();
@@ -2312,6 +2350,12 @@ struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
dyn_cast<TiledLoopOp>(src.getOwner()->getParent()->getParentOp());
if (!loopOp)
return failure();
+ unsigned numLoops = loopOp.getNumLoops();
+ unsigned numInputArgs = loopOp.getRegionInputArgs().size();
+ if (src.getArgNumber() >= numInputArgs + numLoops &&
+ !isShapePreserving(loopOp,
+ src.getArgNumber() - numInputArgs - numLoops))
+ return failure();
auto inputArgs = loopOp.getRegionInputArgs();
auto it1 = llvm::find(inputArgs, src);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 9b6be7920a999..14c7a1ac639df 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -904,6 +904,34 @@ func @rank_reducing_init_extract(%sz : index, %idx : index) -> tensor<2xf32> {
// -----
+// CHECK-LABEL: func @dim_of_tiled_loop_input_no_canonicalize(
+// CHECK-SAME: %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
+// CHECK: %[[c0:.*]] = constant 0 : index
+// CHECK: linalg.tiled_loop {{.*}} outs (%[[o:.*]] =
+// CHECK: %[[dim:.*]] = tensor.dim %[[o]], %[[c0]]
+// CHECK: index_cast %[[dim]]
+func @dim_of_tiled_loop_input_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>, %s: index)
+ -> tensor<?x?xf32> {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %r = linalg.tiled_loop (%iv0, %iv1) = (%c0, %c0)
+ to (%d0, %d1) step (%c1, %c1)
+ ins (%in0 = %arg0 : tensor<?x?xf32>, %in1 = %arg1 : tensor<?x?xf32>)
+ outs (%out1 = %arg2 : tensor<?x?xf32>) {
+ %inner_dim = tensor.dim %out1, %c0 : tensor<?x?xf32>
+ %cast1 = std.index_cast %inner_dim : index to i32
+ %cast2 = std.sitofp %cast1 : i32 to f32
+ %fill = linalg.fill(%cast2, %out1) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+ %slice = tensor.extract_slice %fill[0, 0][%s, %s][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ linalg.yield %slice : tensor<?x?xf32>
+ }
+ return %r : tensor<?x?xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @dim_of_tiled_loop_input(
// CHECK-SAME: %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
// CHECK: %[[c0:.*]] = constant 0 : index