[Mlir-commits] [mlir] 6700a26 - [mlir][linalg][bufferize] Fix insertion point InitTensorElimination

Matthias Springer llvmlistbot at llvm.org
Sun Jan 30 05:25:53 PST 2022


Author: Matthias Springer
Date: 2022-01-30T22:25:39+09:00
New Revision: 6700a26d5f349c05a38d47a555ba2b24b5b3fcec

URL: https://github.com/llvm/llvm-project/commit/6700a26d5f349c05a38d47a555ba2b24b5b3fcec
DIFF: https://github.com/llvm/llvm-project/commit/6700a26d5f349c05a38d47a555ba2b24b5b3fcec.diff

LOG: [mlir][linalg][bufferize] Fix insertion point InitTensorElimination

There was a bug where some of the OpOperands needed in the replacement op were not in scope.

It does not matter where the replacement op is inserted. Any insertion point is OK as long as there are no dominance errors. In the worst case, the newly inserted op will bufferize out-of-place. This is no worse than not eliminating the InitTensorOp at all.

Differential Revision: https://reviews.llvm.org/D117685

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
index 05f2257b972d8..06145d028d4d1 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
@@ -20,7 +20,10 @@ namespace linalg_ext {
 
 struct InitTensorEliminationStep : public bufferization::PostAnalysisStep {
   /// A function that matches anchor OpOperands for InitTensorOp elimination.
-  using AnchorMatchFn = std::function<bool(OpOperand &)>;
+  /// If an OpOperand is matched, the function should populate the SmallVector
+  /// with all values that are needed during `RewriteFn` to produce the
+  /// replacement value.
+  using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;
 
   /// A function that rewrites matched anchors.
   using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index b01100cc9e08f..493044aa53aaa 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Operation.h"
 
 using namespace mlir;
@@ -444,6 +445,79 @@ struct LinalgOpInterfaceHelper<> {
 
 } // namespace
 
+/// Return true if all `neededValues` are in scope at the given
+/// `insertionPoint`.
+static bool
+neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
+                                   Operation *insertionPoint,
+                                   const SmallVector<Value> &neededValues) {
+  for (Value val : neededValues) {
+    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
+      Block *owner = bbArg.getOwner();
+      if (!owner->findAncestorOpInBlock(*insertionPoint))
+        return false;
+    } else {
+      auto opResult = val.cast<OpResult>();
+      if (!domInfo.dominates(opResult.getOwner(), insertionPoint))
+        return false;
+    }
+  }
+  return true;
+}
+
+/// Return true if the given `insertionPoint` dominates all uses of
+/// `initTensorOp`.
+static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
+                                        Operation *insertionPoint,
+                                        Operation *initTensorOp) {
+  for (Operation *user : initTensorOp->getUsers())
+    if (!domInfo.dominates(insertionPoint, user))
+      return false;
+  return true;
+}
+
+/// Find a valid insertion point for a replacement of `initTensorOp`, assuming
+/// that the replacement may use any value from `neededValues`.
+static Operation *
+findValidInsertionPoint(Operation *initTensorOp,
+                        const SmallVector<Value> &neededValues) {
+  DominanceInfo domInfo;
+
+  // Gather all possible insertion points: the location of `initTensorOp` and
+  // right after the definition of each value in `neededValues`.
+  SmallVector<Operation *> insertionPointCandidates;
+  insertionPointCandidates.push_back(initTensorOp);
+  for (Value val : neededValues) {
+    // Note: The anchor op is using all of `neededValues`, so:
+    // * in case of a block argument: There must be at least one op in the block
+    //                                (the anchor op or one of its parents).
+    // * in case of an OpResult: There must be at least one op right after the
+    //                           defining op (the anchor op or one of its
+    //                           parents).
+    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
+      insertionPointCandidates.push_back(
+          &bbArg.getOwner()->getOperations().front());
+    } else {
+      insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
+    }
+  }
+
+  // Select first matching insertion point.
+  for (Operation *insertionPoint : insertionPointCandidates) {
+    // Check if all needed values are in scope.
+    if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
+                                            neededValues))
+      continue;
+    // Check if the insertion point is before all uses.
+    if (!insertionPointDominatesUses(domInfo, insertionPoint, initTensorOp))
+      continue;
+    return insertionPoint;
+  }
+
+  // No suitable insertion point was found.
+  return nullptr;
+}
+
 /// Try to eliminate InitTensorOps inside `op`. An InitTensorOp is replaced
 with the result of `rewriteFunc` if it is anchored on a matching
 /// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
@@ -462,8 +536,10 @@ mlir::linalg::comprehensive_bufferize::linalg_ext::InitTensorEliminationStep::
       // Skip operands that do not bufferize inplace.
       if (!aliasInfo.isInPlace(operand))
         continue;
+      // All values that are needed to create the replacement op.
+      SmallVector<Value> neededValues;
       // Is this a matching OpOperand?
-      if (!anchorMatchFunc(operand))
+      if (!anchorMatchFunc(operand, neededValues))
         continue;
       SetVector<Value> maybeInitTensor =
           state.findValueInReverseUseDefChain(operand.get(), [&](Value val) {
@@ -492,8 +568,14 @@ mlir::linalg::comprehensive_bufferize::linalg_ext::InitTensorEliminationStep::
         return WalkResult::skip();
       Value initTensor = maybeInitTensor.front();
 
+      // Find a suitable insertion point.
+      Operation *insertionPoint =
+          findValidInsertionPoint(initTensor.getDefiningOp(), neededValues);
+      if (!insertionPoint)
+        continue;
+
       // Create a replacement for the InitTensorOp.
-      b.setInsertionPoint(initTensor.getDefiningOp());
+      b.setInsertionPoint(insertionPoint);
       Value replacement = rewriteFunc(b, initTensor.getLoc(), operand);
       if (!replacement)
         continue;
@@ -552,7 +634,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
   return eliminateInitTensors(
       op, state, aliasInfo,
       /*anchorMatchFunc=*/
-      [&](OpOperand &operand) {
+      [&](OpOperand &operand, SmallVector<Value> &neededValues) {
         auto insertSliceOp =
             dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
         if (!insertSliceOp)
@@ -560,7 +642,19 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
         // Only inplace bufferized InsertSliceOps are eligible.
         if (!aliasInfo.isInPlace(insertSliceOp->getOpOperand(1) /*dest*/))
           return false;
-        return &operand == &insertSliceOp->getOpOperand(0) /*source*/;
+        if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
+          return false;
+
+        // Collect all values that are needed to construct the replacement op.
+        neededValues.append(insertSliceOp.offsets().begin(),
+                            insertSliceOp.offsets().end());
+        neededValues.append(insertSliceOp.sizes().begin(),
+                            insertSliceOp.sizes().end());
+        neededValues.append(insertSliceOp.strides().begin(),
+                            insertSliceOp.strides().end());
+        neededValues.push_back(insertSliceOp.dest());
+
+        return true;
       },
       /*rewriteFunc=*/
       [](OpBuilder &b, Location loc, OpOperand &operand) {

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
index de1b8321ed5eb..a96deea161cdf 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref init-tensor-elimination" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref init-tensor-elimination" -canonicalize -split-input-file | FileCheck %s
 
 // -----
 
@@ -62,3 +62,62 @@ func @buffer_forwarding_no_conflict(
 
   return %r1: tensor<?xf32>
 }
+
+// -----
+
+//      CHECK: func @insertion_point_inside_loop(
+// CHECK-SAME:     %[[t:.*]]: memref<?xf32, #{{.*}}>, %[[sz:.*]]: index)
+func @insertion_point_inside_loop(%t : tensor<?xf32>, %sz : index) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c5 = arith.constant 5 : index
+
+  // CHECK-NOT: memref.alloc
+  %blank = linalg.init_tensor [5] : tensor<5xf32>
+
+  // CHECK: scf.for %[[iv:.*]] = %{{.*}} to %[[sz]] step %{{.*}} {
+  %r = scf.for %iv = %c0 to %sz step %c5 iter_args(%bb = %t) -> (tensor<?xf32>) {
+    // CHECK: %[[subview:.*]] = memref.subview %[[t]][%[[iv]]] [5] [1]
+    %iv_i32 = arith.index_cast %iv : index to i32
+    %f = arith.sitofp %iv_i32 : i32 to f32
+
+    // CHECK: linalg.fill(%{{.*}}, %[[subview]])
+    %filled = linalg.fill(%f, %blank) : f32, tensor<5xf32> -> tensor<5xf32>
+
+    // CHECK-NOT: memref.copy
+    %inserted = tensor.insert_slice %filled into %bb[%iv][5][1] : tensor<5xf32> into tensor<?xf32>
+    scf.yield %inserted : tensor<?xf32>
+  }
+
+  return %r : tensor<?xf32>
+}
+
+// -----
+
+//      CHECK: func @insertion_point_outside_loop(
+// CHECK-SAME:     %[[t:.*]]: memref<?xf32, #{{.*}}>, %[[sz:.*]]: index, %[[idx:.*]]: index)
+func @insertion_point_outside_loop(%t : tensor<?xf32>, %sz : index,
+                                   %idx : index) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c5 = arith.constant 5 : index
+
+  // CHECK-NOT: memref.alloc
+  // CHECK: %[[subview:.*]] = memref.subview %[[t]][%[[idx]]] [5] [1]
+  %blank = linalg.init_tensor [5] : tensor<5xf32>
+
+  // CHECK: scf.for %[[iv:.*]] = %{{.*}} to %[[sz]] step %{{.*}} {
+  %r = scf.for %iv = %c0 to %sz step %c5 iter_args(%bb = %t) -> (tensor<?xf32>) {
+    %iv_i32 = arith.index_cast %iv : index to i32
+    %f = arith.sitofp %iv_i32 : i32 to f32
+
+    // CHECK: linalg.fill(%{{.*}}, %[[subview]])
+    %filled = linalg.fill(%f, %blank) : f32, tensor<5xf32> -> tensor<5xf32>
+
+    // CHECK-NOT: memref.copy
+    %inserted = tensor.insert_slice %filled into %bb[%idx][5][1] : tensor<5xf32> into tensor<?xf32>
+    scf.yield %inserted : tensor<?xf32>
+  }
+
+  return %r : tensor<?xf32>
+}


        


More information about the Mlir-commits mailing list