[Mlir-commits] [mlir] 0b05207 - [MLIR][LinAlg] Detensoring CF cost-model: look forward.

llvmlistbot at llvm.org
Tue Apr 20 00:02:06 PDT 2021


Author: KareemErgawy-TomTom
Date: 2021-04-20T09:01:43+02:00
New Revision: 0b05207e45ef01899f5bdae5e35a1a93fa2f564f

URL: https://github.com/llvm/llvm-project/commit/0b05207e45ef01899f5bdae5e35a1a93fa2f564f
DIFF: https://github.com/llvm/llvm-project/commit/0b05207e45ef01899f5bdae5e35a1a93fa2f564f.diff

LOG: [MLIR][LinAlg] Detensoring CF cost-model: look forward.

This patch extends the control-flow cost-model for detensoring by
implementing a forward-looking pass over block arguments that should be
detensored. It ensures that if a (to-be-detensored) block argument
"escapes" its block through the terminator, the corresponding successor
arguments are detensored as well.
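
In code terms, the new check boils down to mapping each terminator
operand that forwards a value to the block argument it feeds in the
successor. A minimal C++ sketch of that helper, condensed from the
lambda introduced in the diff below (the free-function form and the
name enqueueSuccessorArgs are illustrative; in the patch this is the
lambda updateWorkListWithSuccessorArguments, and workList and
blockArgsToDetensor are the pass's existing bookkeeping containers):

    // Sketch condensed from the patch: if `value` is forwarded by its
    // block's terminator, enqueue the matching successor block argument
    // so it is detensored together with `value`.
    void enqueueSuccessorArgs(Value value, SmallVectorImpl<Value> &workList,
                              const DenseSet<BlockArgument> &blockArgsToDetensor) {
      auto terminator =
          dyn_cast<BranchOpInterface>(value.getParentBlock()->getTerminator());
      if (!terminator)
        return;
      for (unsigned idx = 0, e = terminator->getNumOperands(); idx != e; ++idx) {
        if (terminator->getOperand(idx) != value)
          continue;
        // Maps the operand index to the successor's block argument, if any.
        if (auto succArg = terminator.getSuccessorBlockArgument(idx))
          if (!blockArgsToDetensor.count(*succArg))
            workList.push_back(*succArg);
      }
    }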

Reviewed By: silvas

Differential Revision: https://reviews.llvm.org/D100457

Added: 
    mlir/test/Dialect/Linalg/detensorize_0d.mlir

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
    mlir/test/Dialect/Linalg/detensorize_if.mlir
    mlir/test/Dialect/Linalg/detensorize_while.mlir

Removed: 
    mlir/test/Dialect/Linalg/detensorized_0d.mlir


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index c23b84cf1f62..4ca7da6fd22b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -300,14 +300,56 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
       DenseSet<Value> visitedValues;
       DenseSet<Operation *> visitedOps;
 
+      // For a (to-be-detensored) value, check if it "escapes" the block by
+      // being passed to the terminator. If it does, then workList is updated
+      // with the corresponding argument of the successor block.
+      auto updateWorkListWithSuccessorArguments =
+          [&](Value value, BranchOpInterface terminator) {
+            if (!terminator)
+              return;
+
+            for (auto operandIdx :
+                 llvm::seq<unsigned>(0, terminator->getOperands().size())) {
+              Value operand = terminator->getOperand(operandIdx);
+
+              if (operand == value) {
+                auto succBlockArg =
+                    terminator.getSuccessorBlockArgument(operandIdx);
+
+                if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
+                  workList.push_back(*succBlockArg);
+              }
+            }
+          };
+
       while (!workList.empty()) {
         Value currentItem = workList.pop_back_val();
 
         if (!visitedValues.insert(currentItem).second)
           continue;
 
-        // The current item is defined by a block argument.
-        if (auto bbarg = currentItem.dyn_cast<BlockArgument>()) {
+        // 1   - Look forward:
+        // 1.1 - If currentItem escapes to one or more successors, add
+        // the corresponding successor arguments to workList.
+        updateWorkListWithSuccessorArguments(
+            currentItem, dyn_cast<BranchOpInterface>(
+                             currentItem.getParentBlock()->getTerminator()));
+
+        // 1.2 - For each user of currentItem, add the values it defines to
+        // workList. This way, user ops can be inspected later to check whether
+        // they are detensorable; if so, their operands are added to workList
+        // to potentially discover other parts of the detensorable component.
+        for (auto *user : currentItem.getUsers())
+          for (Value result : user->getResults())
+            workList.push_back(result);
+
+        // 2   - Look backward:
+        // 2.1 - The current item is defined by a block argument. If the owner
+        // block is a non-entry one, then:
+        //       * Add the argument to blockArgsToDetensor.
+        //       * Walk the use-def chain backwards to add each predecessor's
+        //       terminator operand corresponding to currentItem to workList.
+        if (currentItem.dyn_cast<BlockArgument>()) {
           BlockArgument currentItemBlockArgument =
               currentItem.cast<BlockArgument>();
           Block *ownerBlock = currentItemBlockArgument.getOwner();
@@ -354,7 +396,11 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
         if (!visitedOps.insert(currentItemDefiningOp).second)
           continue;
 
-        // The current item is computed by a GenericOp.
+        // 2.2 - The current item is computed by a GenericOp. If the op should
+        // be detensored, then:
+        //       * Add it to opsToDetensor.
+        //       * Add its operands to workList to discover other parts of the
+        //       potentially detensorable component.
         if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
           // The op was encountered already, no need to inspect it again.
           if (opsToDetensor.count(genericOp))
@@ -376,7 +422,7 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
           continue;
         }
 
-        // The current item is the result of a FromElemntsOp, it will be
+        // 2.3 - The current item is the result of a FromElementsOp, it will be
         // trivially detensored later as part of canonicalization patterns
         // applied at the end of detensoring.
         //
@@ -386,8 +432,8 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
         if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp))
           continue;
 
-        // The current item is the result of a scalar op, add all its operands
-        // to the work list.
+        // 2.4 - The current item is the result of a scalar op, add all its
+        // operands to the work list.
         if (llvm::all_of(
                 currentItemDefiningOp->getResultTypes(),
                 [&](Type resultType) { return resultType.isIntOrFloat(); }))
@@ -442,8 +488,8 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
 
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
       // A function is legal if all of its non-entry blocks are legal. We
-      // don't legalize the entry block (i.e. the function's signature) since
-      // detensoring can't happen along external calling convention
+      // don't legalize the entry block (i.e. the function's signature)
+      // since detensoring can't happen along external calling convention
       // boundaries, which we conservatively approximate as all function
       // signatures.
       return llvm::all_of(llvm::drop_begin(op.getBody(), 1), [&](Block &block) {

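Taken together, each work-list item is now processed in two phases. A
condensed sketch of the loop (the backward cases 2.1-2.4 are elided
here; the full logic is in the hunk above):

    // Condensed sketch of the traversal after this patch.
    while (!workList.empty()) {
      Value currentItem = workList.pop_back_val();
      if (!visitedValues.insert(currentItem).second)
        continue;

      // Phase 1 - look forward: follow currentItem into successor block
      // arguments and into the results of its users.
      updateWorkListWithSuccessorArguments(
          currentItem, dyn_cast<BranchOpInterface>(
                           currentItem.getParentBlock()->getTerminator()));
      for (Operation *user : currentItem.getUsers())
        for (Value result : user->getResults())
          workList.push_back(result);

      // Phase 2 - look backward: classify the definition of currentItem
      // (non-entry block argument, detensorable GenericOp,
      // tensor.FromElementsOp, or scalar op) and enqueue its sources;
      // cases 2.1-2.4 in the hunk above.
    }
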
diff  --git a/mlir/test/Dialect/Linalg/detensorized_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
similarity index 100%
rename from mlir/test/Dialect/Linalg/detensorized_0d.mlir
rename to mlir/test/Dialect/Linalg/detensorize_0d.mlir

diff  --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
index 05a3720b5089..c5f7eaebbba4 100644
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
 
 #map0 = affine_map<() -> ()>
 
@@ -48,18 +48,149 @@ func @main() -> (tensor<i32>) attributes {} {
 // CHECK-NEXT:     constant 10
 // CHECK-NEXT:     br ^[[bb1:.*]](%{{.*}}: i32)
 // CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT:     tensor.from_elements %{{.*}}
-// CHECK-NEXT:     linalg.tensor_reshape %{{.*}}
 // CHECK-NEXT:     cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
-// CHECK-NEXT:   ^[[bb2]](%{{.*}}: tensor<i32>)
-// CHECK-NEXT:     linalg.init_tensor
-// CHECK-NEXT:     linalg.generic
-// CHECK-NEXT:     ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32)
-// CHECK-NEXT:       addi %{{.*}}, %{{.*}}
-// CHECK-NEXT:       linalg.yield %{{.*}}
-// CHECK-NEXT:     } -> tensor<i32>
-// CHECK-NEXT:     br ^[[bb3:.*]](%{{.*}} : tensor<i32>)
-// CHECK-NEXT:   ^[[bb3]](%{{.*}}: tensor<i32>)
+// CHECK-NEXT:     cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT:     addi %{{.*}}, %{{.*}}
+// CHECK-NEXT:     br ^[[bb3:.*]](%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
+// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<1xi32>
+// CHECK-NEXT:     linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// CHECK-NEXT:     return %{{.*}}
+// CHECK-NEXT:   }
+
+// -----
+
+// Similar to the above test with one change: one of the blocks after the
+// if-condition passes/forwards its tensor argument to another block.
+
+#map0 = affine_map<() -> ()>
+
+#attrs = {
+  indexing_maps = [#map0, #map0, #map0],
+  iterator_types = []
+}
+
+func @main() -> (tensor<i32>) attributes {} {
+  %c0 = constant 0 : i32
+  %0 = tensor.from_elements %c0 : tensor<1xi32>
+  %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
+  %c10 = constant 10 : i32
+  %1 = tensor.from_elements %c10 : tensor<1xi32>
+  %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
+  br ^bb1(%reshaped0 : tensor<i32>)
+
+^bb1(%2: tensor<i32>):  // pred: ^bb0
+  %3 = linalg.init_tensor [] : tensor<i1>
+  %4 = linalg.generic #attrs
+    ins(%2, %reshaped1 : tensor<i32>, tensor<i32>)
+    outs(%3 : tensor<i1>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
+      %8 = cmpi slt, %arg0, %arg1 : i32
+      linalg.yield %8 : i1
+  } -> tensor<i1>
+  %5 = tensor.extract %4[] : tensor<i1>
+  cond_br %5, ^bb2(%2 : tensor<i32>), ^bb3(%2 : tensor<i32>)
+
+^bb2(%6: tensor<i32>):  // pred: ^bb1
+  %7 = linalg.init_tensor [] : tensor<i32>
+  %8 = linalg.generic #attrs
+    ins(%6, %6 : tensor<i32>, tensor<i32>)
+    outs(%7 : tensor<i32>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):  // no predecessors
+      %9 = addi %arg0, %arg1 : i32
+      linalg.yield %9 : i32
+  } -> tensor<i32>
+  br ^bb3(%8 : tensor<i32>)
+
+^bb3(%10: tensor<i32>):  // 2 preds: ^bb1, ^bb2
+  br ^bb4(%10 : tensor<i32>)
+
+^bb4(%11: tensor<i32>):  // pred: ^bb3
+  return %11 : tensor<i32>
+}
+
+// CHECK-LABEL:  func @main()
+// CHECK-NEXT:     constant 0
+// CHECK-NEXT:     constant 10
+// CHECK-NEXT:     br ^[[bb1:.*]](%{{.*}}: i32)
+// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
+// CHECK-NEXT:     cmpi slt, %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT:     addi %{{.*}}, %{{.*}}
+// CHECK-NEXT:     br ^[[bb3:.*]](%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
+// CHECK-NEXT:     br ^[[bb4:.*]](%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb4]](%{{.*}}: i32)
+// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<1xi32>
+// CHECK-NEXT:     linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// CHECK-NEXT:     return %{{.*}}
+// CHECK-NEXT:   }
+
+// -----
+
+#map0 = affine_map<() -> ()>
+
+#attrs = {
+  indexing_maps = [#map0, #map0, #map0],
+  iterator_types = []
+}
+
+func @main() -> (tensor<i32>) attributes {} {
+  %c0 = constant 0 : i32
+  %0 = tensor.from_elements %c0 : tensor<1xi32>
+  %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
+  %c10 = constant 10 : i32
+  %1 = tensor.from_elements %c10 : tensor<1xi32>
+  %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
+  br ^bb1(%reshaped0 : tensor<i32>)
+
+^bb1(%2: tensor<i32>):  // pred: ^bb0
+  %3 = linalg.init_tensor [] : tensor<i1>
+  %4 = linalg.generic #attrs
+    ins(%2, %reshaped1 : tensor<i32>, tensor<i32>)
+    outs(%3 : tensor<i1>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
+      %8 = cmpi slt, %arg0, %arg1 : i32
+      linalg.yield %8 : i1
+  } -> tensor<i1>
+  %5 = tensor.extract %4[] : tensor<i1>
+  // This cond_br intentionally has bb2 as its target for both branches. This
+  // is to make sure that the "forward phase" of the cost-model correctly adds
+  // the users of a block argument (in this case bb2's argument) to the work
+  // list.
+  cond_br %5, ^bb2(%2 : tensor<i32>), ^bb2(%2 : tensor<i32>)
+
+^bb2(%6: tensor<i32>):  // pred: ^bb1
+  %12 = tensor.from_elements %c10 : tensor<1xi32>
+  %reshaped12 = linalg.tensor_reshape %12 [] : tensor<1xi32> into tensor<i32>
+  %7 = linalg.init_tensor [] : tensor<i32>
+  %8 = linalg.generic #attrs
+    ins(%6, %reshaped12 : tensor<i32>, tensor<i32>)
+    outs(%7 : tensor<i32>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):  // no predecessors
+      %9 = addi %arg0, %arg1 : i32
+      linalg.yield %9 : i32
+  } -> tensor<i32>
+  br ^bb3(%8 : tensor<i32>)
+
+^bb3(%10: tensor<i32>):  // pred: ^bb2
+  return %10 : tensor<i32>
+}
+
+// CHECK-LABEL:  func @main()
+// CHECK-NEXT:     constant 0
+// CHECK-NEXT:     constant 10
+// CHECK-NEXT:     br ^[[bb1:.*]](%{{.*}}: i32)
+// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
+// CHECK-NEXT:     cmpi slt, %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT:     addi %{{.*}}, %{{.*}}
+// CHECK-NEXT:     br ^[[bb3:.*]](%{{.*}} : i32)
+// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
+// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<1xi32>
+// CHECK-NEXT:     linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
 // CHECK-NEXT:     return %{{.*}}
 // CHECK-NEXT:   }

diff  --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
index 72390f0d7608..4b4cf3aa98bf 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -62,12 +62,12 @@ func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attributes {
 // DET-CF:         tensor.extract {{.*}}
 // DET-CF:         br ^[[bb1:.*]](%{{.*}} : i32)
 // DET-CF:       ^[[bb1]](%{{.*}}: i32)
-// DET-CF-DAG      tensor.from_elements {{.*}}
-// DET-CF-DAG:     linalg.tensor_reshape {{.*}}
-// DET-CF-DAG:     cmpi slt, {{.*}}
-// DET-CF:         cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : tensor<i32>)
+// DET-CF:         cmpi slt, {{.*}}
+// DET-CF:         cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
 // DET-CF:       ^[[bb2]](%{{.*}}: i32)
 // DET-CF:         addi {{.*}}
 // DET-CF:         br ^[[bb1]](%{{.*}} : i32)
-// DET-CF:       ^[[bb3]](%{{.*}}: tensor<i32>)
+// DET-CF:       ^[[bb3]](%{{.*}}: i32)
+// DET-CF:         tensor.from_elements %{{.*}} : tensor<1xi32>
+// DET-CF:         linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
 // DET-CF:         return %{{.*}} : tensor<i32>


        

