[Mlir-commits] [mlir] 0b05207 - [MLIR][LinAlg] Detensoring CF cost-model: look forward.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 20 00:02:06 PDT 2021
Author: KareemErgawy-TomTom
Date: 2021-04-20T09:01:43+02:00
New Revision: 0b05207e45ef01899f5bdae5e35a1a93fa2f564f
URL: https://github.com/llvm/llvm-project/commit/0b05207e45ef01899f5bdae5e35a1a93fa2f564f
DIFF: https://github.com/llvm/llvm-project/commit/0b05207e45ef01899f5bdae5e35a1a93fa2f564f.diff
LOG: [MLIR][LinAlg] Detensoring CF cost-model: look forward.
This patch extends the control-flow cost-model for detensoring by
implementing a forward-looking pass on block arguments that should be
detensored. This makes sure that if a (to-be-detensored) block argument
"escapes" its block through the terminator, then the successor arguments
are also detensored.
Reviewed By: silvas
Differential Revision: https://reviews.llvm.org/D100457
Added:
mlir/test/Dialect/Linalg/detensorize_0d.mlir
Modified:
mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
mlir/test/Dialect/Linalg/detensorize_if.mlir
mlir/test/Dialect/Linalg/detensorize_while.mlir
Removed:
mlir/test/Dialect/Linalg/detensorized_0d.mlir
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index c23b84cf1f62..4ca7da6fd22b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -300,14 +300,56 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
DenseSet<Value> visitedValues;
DenseSet<Operation *> visitedOps;
+ // For a (to-be-detensored) value, check if it "escapes" the block by being
+ // passed to the terminator. If it does, then workList is updated with the
+ // corresponding argument to the successor block.
+ auto updateWorkListWithSuccessorArguments =
+ [&](Value value, BranchOpInterface terminator) {
+ if (!terminator)
+ return;
+
+ for (auto operandIdx :
+ llvm::seq<unsigned>(0, terminator->getOperands().size())) {
+ Value operand = terminator->getOperand(operandIdx);
+
+ if (operand == value) {
+ auto succBlockArg =
+ terminator.getSuccessorBlockArgument(operandIdx);
+
+ if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
+ workList.push_back(*succBlockArg);
+ }
+ }
+ };
+
while (!workList.empty()) {
Value currentItem = workList.pop_back_val();
if (!visitedValues.insert(currentItem).second)
continue;
- // The current item is defined by a block argument.
- if (auto bbarg = currentItem.dyn_cast<BlockArgument>()) {
+ // 1 - Look forward:
+ // 1.1 - If currentItem escapes to one or more successors, add
+ // the corresponding successor arguments to workList.
+ updateWorkListWithSuccessorArguments(
+ currentItem, dyn_cast<BranchOpInterface>(
+ currentItem.getParentBlock()->getTerminator()));
+
+ // 1.2 - For each user of currentItem, add the defined values to
+ // workList. This way, the user ops can be inspected later if they are
+ // detensorable and if so, their operands will be added to workList to
+ // potentially discover other parts of the detensorable component.
+ for (auto *user : currentItem.getUsers())
+ for (Value result : user->getResults())
+ workList.push_back(result);
+
+ // 2 - Look backward:
+ // 2.1 - The current item is defined by a block argument. If the owner
+ // block is a non-entry one, then:
+ // * Add the argument to blockArgsToDetensor.
+ // * Walk the use-def chain backwards to add each predecessor's
+ // terminator-operands corresponding to currentItem to workList.
+ if (currentItem.dyn_cast<BlockArgument>()) {
BlockArgument currentItemBlockArgument =
currentItem.cast<BlockArgument>();
Block *ownerBlock = currentItemBlockArgument.getOwner();
@@ -354,7 +396,11 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
if (!visitedOps.insert(currentItemDefiningOp).second)
continue;
- // The current item is computed by a GenericOp.
+ // 2.2 - The current item is computed by a GenericOp. If the op should
+ // be detensored, then:
+ // * Add it to opsToDetensor.
+ // * Add its operands to workList to discover other parts of the
+ // potentially detensorable component.
if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
// The op was encountered already, no need to inspect it again.
if (opsToDetensor.count(genericOp))
@@ -376,7 +422,7 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
continue;
}
- // The current item is the result of a FromElemntsOp, it will be
+ // 2.3 - The current item is the result of a FromElementsOp, it will be
// trivially detensored later as part of canonicalization patterns
// applied at the end of detensoring.
//
@@ -386,8 +432,8 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp))
continue;
- // The current item is the result of a scalar op, add all its operands
- // to the work list.
+ // 2.4 - The current item is the result of a scalar op, add all its
+ // operands to the work list.
if (llvm::all_of(
currentItemDefiningOp->getResultTypes(),
[&](Type resultType) { return resultType.isIntOrFloat(); }))
@@ -442,8 +488,8 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
// A function is legal if all of its non-entry blocks are legal. We
- // don't legalize the entry block (i.e. the function's signature) since
- // detensoring can't happen along external calling convention
+ // don't legalize the entry block (i.e. the function's signature)
+ // since detensoring can't happen along external calling convention
// boundaries, which we conservatively approximate as all function
// signatures.
return llvm::all_of(llvm::drop_begin(op.getBody(), 1), [&](Block &block) {
diff --git a/mlir/test/Dialect/Linalg/detensorized_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
similarity index 100%
rename from mlir/test/Dialect/Linalg/detensorized_0d.mlir
rename to mlir/test/Dialect/Linalg/detensorize_0d.mlir
diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
index 05a3720b5089..c5f7eaebbba4 100644
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
#map0 = affine_map<() -> ()>
@@ -48,18 +48,149 @@ func @main() -> (tensor<i32>) attributes {} {
// CHECK-NEXT: constant 10
// CHECK-NEXT: br ^[[bb1:.*]](%{{.*}}: i32)
// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT: tensor.from_elements %{{.*}}
-// CHECK-NEXT: linalg.tensor_reshape %{{.*}}
// CHECK-NEXT: cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
-// CHECK-NEXT: ^[[bb2]](%{{.*}}: tensor<i32>)
-// CHECK-NEXT: linalg.init_tensor
-// CHECK-NEXT: linalg.generic
-// CHECK-NEXT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32)
-// CHECK-NEXT: addi %{{.*}}, %{{.*}}
-// CHECK-NEXT: linalg.yield %{{.*}}
-// CHECK-NEXT: } -> tensor<i32>
-// CHECK-NEXT: br ^[[bb3:.*]](%{{.*}} : tensor<i32>)
-// CHECK-NEXT: ^[[bb3]](%{{.*}}: tensor<i32>)
+// CHECK-NEXT: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
+// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT: addi %{{.*}}, %{{.*}}
+// CHECK-NEXT: br ^[[bb3:.*]](%{{.*}} : i32)
+// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
+// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<1xi32>
+// CHECK-NEXT: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// CHECK-NEXT: return %{{.*}}
+// CHECK-NEXT: }
+
+// -----
+
+// Similar to the above test with one change: one of the blocks after the
+// if-condition passes/forwards its tensor argument to another block.
+
+#map0 = affine_map<() -> ()>
+
+#attrs = {
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = []
+}
+
+func @main() -> (tensor<i32>) attributes {} {
+ %c0 = constant 0 : i32
+ %0 = tensor.from_elements %c0 : tensor<1xi32>
+ %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
+ %c10 = constant 10 : i32
+ %1 = tensor.from_elements %c10 : tensor<1xi32>
+ %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
+ br ^bb1(%reshaped0 : tensor<i32>)
+
+^bb1(%2: tensor<i32>): // 2 preds: ^bb0, ^bb2
+ %3 = linalg.init_tensor [] : tensor<i1>
+ %4 = linalg.generic #attrs
+ ins(%2, %reshaped1 : tensor<i32>, tensor<i32>)
+ outs(%3 : tensor<i1>) {
+ ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): // no predecessors
+ %8 = cmpi slt, %arg0, %arg1 : i32
+ linalg.yield %8 : i1
+ } -> tensor<i1>
+ %5 = tensor.extract %4[] : tensor<i1>
+ cond_br %5, ^bb2(%2 : tensor<i32>), ^bb3(%2 : tensor<i32>)
+
+^bb2(%6: tensor<i32>): // pred: ^bb1
+ %7 = linalg.init_tensor [] : tensor<i32>
+ %8 = linalg.generic #attrs
+ ins(%6, %6 : tensor<i32>, tensor<i32>)
+ outs(%7 : tensor<i32>) {
+ ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): // no predecessors
+ %9 = addi %arg0, %arg1 : i32
+ linalg.yield %9 : i32
+ } -> tensor<i32>
+ br ^bb3(%8 : tensor<i32>)
+
+^bb3(%10: tensor<i32>): // 2 preds: ^bb1, ^bb2
+ br ^bb4(%10 : tensor<i32>)
+
+^bb4(%11: tensor<i32>): // pred: ^bb3
+ return %11 : tensor<i32>
+}
+
+// CHECK-LABEL: func @main()
+// CHECK-NEXT: constant 0
+// CHECK-NEXT: constant 10
+// CHECK-NEXT: br ^[[bb1:.*]](%{{.*}}: i32)
+// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
+// CHECK-NEXT: cmpi slt, %{{.*}}, %{{.*}}
+// CHECK-NEXT: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
+// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT: addi %{{.*}}, %{{.*}}
+// CHECK-NEXT: br ^[[bb3:.*]](%{{.*}} : i32)
+// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
+// CHECK-NEXT: br ^[[bb4:.*]](%{{.*}} : i32)
+// CHECK-NEXT: ^[[bb4]](%{{.*}}: i32)
+// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<1xi32>
+// CHECK-NEXT: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// CHECK-NEXT: return %{{.*}}
+// CHECK-NEXT: }
+
+// -----
+
+#map0 = affine_map<() -> ()>
+
+#attrs = {
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = []
+}
+
+func @main() -> (tensor<i32>) attributes {} {
+ %c0 = constant 0 : i32
+ %0 = tensor.from_elements %c0 : tensor<1xi32>
+ %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
+ %c10 = constant 10 : i32
+ %1 = tensor.from_elements %c10 : tensor<1xi32>
+ %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
+ br ^bb1(%reshaped0 : tensor<i32>)
+
+^bb1(%2: tensor<i32>): // 2 preds: ^bb0, ^bb2
+ %3 = linalg.init_tensor [] : tensor<i1>
+ %4 = linalg.generic #attrs
+ ins(%2, %reshaped1 : tensor<i32>, tensor<i32>)
+ outs(%3 : tensor<i1>) {
+ ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): // no predecessors
+ %8 = cmpi slt, %arg0, %arg1 : i32
+ linalg.yield %8 : i1
+ } -> tensor<i1>
+ %5 = tensor.extract %4[] : tensor<i1>
+ // This cond_br intentionally has bb2 as its target for both branches. This
+ // is to make sure that the "forward phase" of the cost-model correctly adds
+ // the users of a block argument (in this case bb2's argument) to the work
+ // list.
+ cond_br %5, ^bb2(%2 : tensor<i32>), ^bb2(%2 : tensor<i32>)
+
+^bb2(%6: tensor<i32>): // pred: ^bb1
+ %12 = tensor.from_elements %c10 : tensor<1xi32>
+ %reshaped12 = linalg.tensor_reshape %12 [] : tensor<1xi32> into tensor<i32>
+ %7 = linalg.init_tensor [] : tensor<i32>
+ %8 = linalg.generic #attrs
+ ins(%6, %reshaped12 : tensor<i32>, tensor<i32>)
+ outs(%7 : tensor<i32>) {
+ ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): // no predecessors
+ %9 = addi %arg0, %arg1 : i32
+ linalg.yield %9 : i32
+ } -> tensor<i32>
+ br ^bb3(%8 : tensor<i32>)
+
+^bb3(%10: tensor<i32>): // pred: ^bb2
+ return %10 : tensor<i32>
+}
+
+// CHECK-LABEL: func @main()
+// CHECK-NEXT: constant 0
+// CHECK-NEXT: constant 10
+// CHECK-NEXT: br ^[[bb1:.*]](%{{.*}}: i32)
+// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
+// CHECK-NEXT: cmpi slt, %{{.*}}, %{{.*}}
+// CHECK-NEXT: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32)
+// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT: addi %{{.*}}, %{{.*}}
+// CHECK-NEXT: br ^[[bb3:.*]](%{{.*}} : i32)
+// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
+// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<1xi32>
+// CHECK-NEXT: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
// CHECK-NEXT: return %{{.*}}
// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
index 72390f0d7608..4b4cf3aa98bf 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -62,12 +62,12 @@ func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attributes {
// DET-CF: tensor.extract {{.*}}
// DET-CF: br ^[[bb1:.*]](%{{.*}} : i32)
// DET-CF: ^[[bb1]](%{{.*}}: i32)
-// DET-CF-DAG tensor.from_elements {{.*}}
-// DET-CF-DAG: linalg.tensor_reshape {{.*}}
-// DET-CF-DAG: cmpi slt, {{.*}}
-// DET-CF: cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : tensor<i32>)
+// DET-CF: cmpi slt, {{.*}}
+// DET-CF: cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
// DET-CF: ^[[bb2]](%{{.*}}: i32)
// DET-CF: addi {{.*}}
// DET-CF: br ^[[bb1]](%{{.*}} : i32)
-// DET-CF: ^[[bb3]](%{{.*}}: tensor<i32>)
+// DET-CF: ^[[bb3]](%{{.*}}: i32)
+// DET-CF: tensor.from_elements %{{.*}} : tensor<1xi32>
+// DET-CF: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
// DET-CF: return %{{.*}} : tensor<i32>
More information about the Mlir-commits
mailing list