[Mlir-commits] [mlir] bdcf4b9 - [MLIR][Linalg] Make detensoring cost-model more flexible.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 20 01:21:37 PDT 2021


Author: KareemErgawy-TomTom
Date: 2021-09-20T10:21:31+02:00
New Revision: bdcf4b9b9620afe24d17132027a7d12e2f1a598b

URL: https://github.com/llvm/llvm-project/commit/bdcf4b9b9620afe24d17132027a7d12e2f1a598b
DIFF: https://github.com/llvm/llvm-project/commit/bdcf4b9b9620afe24d17132027a7d12e2f1a598b.diff

LOG: [MLIR][Linalg] Make detensoring cost-model more flexible.

So far, the CF cost model for detensoring was limited to discovering
pure CF structures. This means that if, while discovering the CF component,
the cost model found any op that is not detensorable, it gave up on
detensoring altogether. This patch makes the model a bit more flexible by
removing non-detensorable ops from the detensorable component instead of
giving up entirely.

Reviewed By: silvas

Differential Revision: https://reviews.llvm.org/D109965

Added: 
    mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp

Removed: 
    mlir/test/Dialect/Linalg/detensorize_while_failure.mlir


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 8fe6ac5be9806..08278bc017b7c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -272,25 +272,16 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
 
   /// Detensorize linalg ops involved in control-flow within a function.
   ///
-  /// This model starts from CondBranchOps within a function. For each cond_br,
-  /// the model then walks the use-def chain for the branch's condition
-  /// backwards in order to understand where the condition's value comes from.
-  /// If the condition value is (indirectly) computed by a linalg op that can be
-  /// detensored, the model then continues walking the use-def chain in order to
-  /// understand where the linalg op's operands come from. This leads to
-  /// discovering a "detensoring component". A detensoring component is the set
-  /// of operations + block arguments that are involved in control-flow AND can
-  /// be detensored.
-  ///
-  /// For examples where this model succeeds to discover a detensoring
-  /// component, see:
-  /// - test/Dialect/Linalg/detensorize_while.mlir
-  /// - test/Dialect/Linalg/detesorize_while_pure_cf.mlir.
-  ///
-  /// For an example where this model marks control-flow as "non-detensorable",
-  /// see:
-  /// - test/Dialect/Linalg/detensorize_while_failure.mlir
-  class PureControlFlowDetectionModel : public CostModel {
+  /// This model starts from BranchOps and CondBranchOps within a function. For
+  /// each such branch, the model then walks the use-def chain for the branch's
+  /// condition backwards in order to understand where the condition's value
+  /// comes from. If the condition value is (indirectly) computed by a linalg op
+  /// that can be detensored, the model then continues walking the use-def chain
+  /// in order to understand where the linalg op's operands come from. This
+  /// leads to discovering a "detensoring component". A detensoring component is
+  /// the set of operations + block arguments that are involved in control-flow
+  /// AND can be detensored.
+  class ControlFlowDetectionModel : public CostModel {
   public:
     void compute(FuncOp func, DetensorizeTypeConverter typeConverter,
                  DenseSet<Operation *> &opsToDetensor,
@@ -376,19 +367,19 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
 
           for (PredecessorIterator pred = ownerBlock->pred_begin();
                pred != ownerBlock->pred_end(); ++pred) {
-            BranchOpInterface terminator =
+            BranchOpInterface predTerminator =
                 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
 
             // TODO: For now, we give up if any of the control-flow components
             // in a function is not detensorable. Fix that.
-            if (!terminator) {
+            if (!predTerminator) {
               opsToDetensor.clear();
               blockArgsToDetensor.clear();
               return;
             }
 
             auto ownerBlockOperands =
-                terminator.getSuccessorOperands(pred.getSuccessorIndex());
+                predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
 
             if (!ownerBlockOperands || ownerBlockOperands->empty())
               continue;
@@ -418,12 +409,10 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
           if (opsToDetensor.count(genericOp))
             continue;
 
-          // TODO: For now, we give up if any of the control-flow components
-          // in a function is not detensorable. Fix that.
+          // The op should not be detensored, give up on it but continue with
+          // discovering the rest of the control-flow component.
           if (!shouldBeDetensored(genericOp, typeConverter)) {
-            opsToDetensor.clear();
-            blockArgsToDetensor.clear();
-            return;
+            continue;
           }
 
           opsToDetensor.insert(genericOp);
@@ -452,6 +441,47 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
           for (Value scalarOpOperand : currentItemDefiningOp->getOperands())
             workList.push_back(scalarOpOperand);
       }
+
+      // Since the cost model gives up on some ops (see the details of step 2.2
+      // above), block arguments that correspond to the values produced by those
+      // ops should not be detensored either.
+
+      DenseSet<BlockArgument> blockArgsToRemove;
+
+      for (auto &blockArg : blockArgsToDetensor) {
+        Block *block = blockArg.getParentBlock();
+
+        // For the potentially detensorable block argument, find the
+        // corresponding operands in predecessor blocks.
+        for (PredecessorIterator pred = block->pred_begin();
+             pred != block->pred_end(); ++pred) {
+          BranchOpInterface terminator =
+              dyn_cast<BranchOpInterface>((*pred)->getTerminator());
+          auto blockOperands =
+              terminator.getSuccessorOperands(pred.getSuccessorIndex());
+
+          if (!blockOperands || blockOperands->empty())
+            continue;
+
+          Operation *definingOp =
+              terminator
+                  ->getOperand(blockOperands->getBeginOperandIndex() +
+                               blockArg.getArgNumber())
+                  .getDefiningOp();
+
+          // If the operand is defined by a GenericOp that will not be
+          // detensored, then do not detensor the corresponding block argument.
+          if (dyn_cast_or_null<GenericOp>(definingOp) &&
+              opsToDetensor.count(definingOp) == 0) {
+            blockArgsToRemove.insert(blockArg);
+            break;
+          }
+        }
+      }
+
+      for (auto &blockArg : blockArgsToRemove) {
+        blockArgsToDetensor.erase(blockArg);
+      }
     }
   };
 
@@ -487,7 +517,7 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
                         blockArgsToDetensor);
 
     } else {
-      PureControlFlowDetectionModel costModel;
+      ControlFlowDetectionModel costModel;
       costModel.compute(getFunction(), typeConverter, opsToDetensor,
                         blockArgsToDetensor);
     }

diff  --git a/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
similarity index 93%
rename from mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
rename to mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
index b7db6adfbfe49..5b8bd7e578ab4 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
@@ -93,15 +93,14 @@ func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attribute
 // DET-ALL:         return %{{.*}} : tensor<i32>
 // DET-ALL:       }
 
-// Try to detensor pure control-flow. However, that fails since the potential
-// detensorable component contains some ops that cannot be detensored.
-//
 // DET-CF-LABEL: func @main
 // DET-CF-SAME:    (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor<i32>)
 // DET-CF:         br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>)
 // DET-CF:       ^bb1(%{{.*}}: tensor<10xi32>)
 // DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
-// DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<i32>, tensor<i32>) outs(%{{.*}} : tensor<i1>) {
+// DET-CF:         tensor.extract %{{.*}}[] : tensor<i32>
+// DET-CF:         tensor.extract %{{.*}}[] : tensor<i32>
+// DET-CF:         cmpi slt, %{{.*}}, %{{.*}} : i32
 // DET-CF:         cond_br %{{.*}}, ^bb2(%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
 // DET-CF:       ^bb2(%{{.*}}: tensor<i32>)
 // DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {


        


More information about the Mlir-commits mailing list