[Mlir-commits] [mlir] bdcf4b9 - [MLIR][Linalg] Make detensoring cost-model more flexible.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 20 01:21:37 PDT 2021
Author: KareemErgawy-TomTom
Date: 2021-09-20T10:21:31+02:00
New Revision: bdcf4b9b9620afe24d17132027a7d12e2f1a598b
URL: https://github.com/llvm/llvm-project/commit/bdcf4b9b9620afe24d17132027a7d12e2f1a598b
DIFF: https://github.com/llvm/llvm-project/commit/bdcf4b9b9620afe24d17132027a7d12e2f1a598b.diff
LOG: [MLIR][Linalg] Make detensoring cost-model more flexible.
So far, the control-flow (CF) cost model for detensoring was limited to
discovering pure CF structures. This means that if, while discovering the
CF component, the cost model found any op that is not detensorable, it gave
up on detensoring altogether. This patch makes it a bit more flexible by
removing non-detensorable ops from the detensorable component without
giving up entirely.
Reviewed By: silvas
Differential Revision: https://reviews.llvm.org/D109965
Added:
mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
Modified:
mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
Removed:
mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 8fe6ac5be9806..08278bc017b7c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -272,25 +272,16 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
/// Detensorize linalg ops involved in control-flow within a function.
///
- /// This model starts from CondBranchOps within a function. For each cond_br,
- /// the model then walks the use-def chain for the branch's condition
- /// backwards in order to understand where the condition's value comes from.
- /// If the condition value is (indirectly) computed by a linalg op that can be
- /// detensored, the model then continues walking the use-def chain in order to
- /// understand where the linalg op's operands come from. This leads to
- /// discovering a "detensoring component". A detensoring component is the set
- /// of operations + block arguments that are involved in control-flow AND can
- /// be detensored.
- ///
- /// For examples where this model succeeds to discover a detensoring
- /// component, see:
- /// - test/Dialect/Linalg/detensorize_while.mlir
- /// - test/Dialect/Linalg/detesorize_while_pure_cf.mlir.
- ///
- /// For an example where this model marks control-flow as "non-detensorable",
- /// see:
- /// - test/Dialect/Linalg/detensorize_while_failure.mlir
- class PureControlFlowDetectionModel : public CostModel {
+ /// This model starts from BranchOps and CondBranchOps within a function. For
+ /// each such branch, the model then walks the use-def chain for the branch's
+ /// condition backwards in order to understand where the condition's value
+ /// comes from. If the condition value is (indirectly) computed by a linalg op
+ /// that can be detensored, the model then continues walking the use-def chain
+ /// in order to understand where the linalg op's operands come from. This
+ /// leads to discovering a "detensoring component". A detensoring component is
+ /// the set of operations + block arguments that are involved in control-flow
+ /// AND can be detensored.
+ class ControlFlowDetectionModel : public CostModel {
public:
void compute(FuncOp func, DetensorizeTypeConverter typeConverter,
DenseSet<Operation *> &opsToDetensor,
@@ -376,19 +367,19 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
for (PredecessorIterator pred = ownerBlock->pred_begin();
pred != ownerBlock->pred_end(); ++pred) {
- BranchOpInterface terminator =
+ BranchOpInterface predTerminator =
dyn_cast<BranchOpInterface>((*pred)->getTerminator());
// TODO: For now, we give up if any of the control-flow components
// in a function is not detensorable. Fix that.
- if (!terminator) {
+ if (!predTerminator) {
opsToDetensor.clear();
blockArgsToDetensor.clear();
return;
}
auto ownerBlockOperands =
- terminator.getSuccessorOperands(pred.getSuccessorIndex());
+ predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
if (!ownerBlockOperands || ownerBlockOperands->empty())
continue;
@@ -418,12 +409,10 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
if (opsToDetensor.count(genericOp))
continue;
- // TODO: For now, we give up if any of the control-flow components
- // in a function is not detensorable. Fix that.
+ // The op should not be detensored, give up on it but continue with
+ // discovering the rest of the control-flow component.
if (!shouldBeDetensored(genericOp, typeConverter)) {
- opsToDetensor.clear();
- blockArgsToDetensor.clear();
- return;
+ continue;
}
opsToDetensor.insert(genericOp);
@@ -452,6 +441,47 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
for (Value scalarOpOperand : currentItemDefiningOp->getOperands())
workList.push_back(scalarOpOperand);
}
+
+ // Since the cost model gives up on some ops (see the details of step 2.2
+ // above), block arguments that correspond to the values produced by those
+ // ops should not be detensored as well.
+
+ DenseSet<BlockArgument> blockArgsToRemove;
+
+ for (auto &blockArg : blockArgsToDetensor) {
+ Block *block = blockArg.getParentBlock();
+
+ // For the potentially detensorable block argument, find the
+ // correpsonding operands in predecessor blocks.
+ for (PredecessorIterator pred = block->pred_begin();
+ pred != block->pred_end(); ++pred) {
+ BranchOpInterface terminator =
+ dyn_cast<BranchOpInterface>((*pred)->getTerminator());
+ auto blockOperands =
+ terminator.getSuccessorOperands(pred.getSuccessorIndex());
+
+ if (!blockOperands || blockOperands->empty())
+ continue;
+
+ Operation *definingOp =
+ terminator
+ ->getOperand(blockOperands->getBeginOperandIndex() +
+ blockArg.getArgNumber())
+ .getDefiningOp();
+
+ // If the operand is defined by a GenericOp that will not be
+ // detensored, then do not detensor the corresponding block argument.
+ if (dyn_cast_or_null<GenericOp>(definingOp) &&
+ opsToDetensor.count(definingOp) == 0) {
+ blockArgsToRemove.insert(blockArg);
+ break;
+ }
+ }
+ }
+
+ for (auto &blockArg : blockArgsToRemove) {
+ blockArgsToDetensor.erase(blockArg);
+ }
}
};
@@ -487,7 +517,7 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
blockArgsToDetensor);
} else {
- PureControlFlowDetectionModel costModel;
+ ControlFlowDetectionModel costModel;
costModel.compute(getFunction(), typeConverter, opsToDetensor,
blockArgsToDetensor);
}
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
similarity index 93%
rename from mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
rename to mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
index b7db6adfbfe49..5b8bd7e578ab4 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
@@ -93,15 +93,14 @@ func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attribute
// DET-ALL: return %{{.*}} : tensor<i32>
// DET-ALL: }
-// Try to detensor pure control-flow. However, that fails since the potential
-// detensorable component contains some ops that cannot be detensored.
-//
// DET-CF-LABEL: func @main
// DET-CF-SAME: (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor<i32>)
// DET-CF: br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>)
// DET-CF: ^bb1(%{{.*}}: tensor<10xi32>)
// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
-// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<i32>, tensor<i32>) outs(%{{.*}} : tensor<i1>) {
+// DET-CF: tensor.extract %{{.*}}[] : tensor<i32>
+// DET-CF: tensor.extract %{{.*}}[] : tensor<i32>
+// DET-CF: cmpi slt, %{{.*}}, %{{.*}} : i32
// DET-CF: cond_br %{{.*}}, ^bb2(%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
// DET-CF: ^bb2(%{{.*}}: tensor<i32>)
// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
More information about the Mlir-commits
mailing list