[Mlir-commits] [mlir] 2441c07 - [mlir][bufferization] Support multiple leaves in EmptyTensorElimination

Matthias Springer llvmlistbot at llvm.org
Fri Feb 10 00:38:55 PST 2023

Author: Matthias Springer
Date: 2023-02-10T09:38:47+01:00
New Revision: 2441c0730603eb6a543dae56c14c3e1ccb08fb55

URL: https://github.com/llvm/llvm-project/commit/2441c0730603eb6a543dae56c14c3e1ccb08fb55
DIFF: https://github.com/llvm/llvm-project/commit/2441c0730603eb6a543dae56c14c3e1ccb08fb55.diff

LOG: [mlir][bufferization] Support multiple leaves in EmptyTensorElimination

Support cases where a source tensor can be traced back to multiple possible tensor.empty ops.

Differential Revision: https://reviews.llvm.org/D142130




diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 1579cfd04c79b..5fc12573912f3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -103,64 +103,73 @@ findValidInsertionPoint(Operation *emptyTensorOp,
 /// with the result of `rewriteFunc` if it is anchored on a matching
 /// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
 /// chain, starting from the OpOperand and always following the aliasing
-/// OpOperand, that eventually ends at a single tensor::EmptyOp.
+/// OpOperand, that eventually ends at the tensor::EmptyOp.
+/// E.g.:
+/// %0 = tensor.empty() : tensor<10xf32>
+/// %1 = linalg.fill ... outs(%0 : tensor<10xf32>)
+/// %2 = tensor.insert_slice %1 into %t ...
+/// In the above example, the anchor is the source operand of the insert_slice
+/// op. When tracing back the reverse use-def chain, we end up at a
+/// tensor.empty op.
 LogicalResult mlir::bufferization::eliminateEmptyTensors(
     RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
     AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
   OpBuilder::InsertionGuard g(rewriter);
-  WalkResult status = op->walk([&](Operation *op) {
+  op->walk([&](Operation *op) {
     for (OpOperand &operand : op->getOpOperands()) {
       // Skip operands that do not bufferize inplace.
       if (!state.isInPlace(operand))
       // All values that are needed to create the replacement op.
       SmallVector<Value> neededValues;
-      // Is this a matching OpOperand?
+      // Is this an anchor?
       if (!anchorMatchFunc(operand, neededValues))
-      SetVector<Value> maybeEmptyTensor = state.findValueInReverseUseDefChain(
-          operand.get(), /*condition=*/[&](Value val) { return false; },
-          /*followEquivalentOnly=*/true);
-      // Replace only if the reverse use-def chain ends at exactly one
-      // tensor::EmptyOp.
-      if (maybeEmptyTensor.size() != 1 ||
-          !maybeEmptyTensor.front().getDefiningOp<tensor::EmptyOp>())
-        continue;
-      Value emptyTensor = maybeEmptyTensor.front();
+      // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
+      // equivalent tensors. I.e., stop when there are ops such as extract_slice
+      // on the path.
+      SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
+          operand.get(), /*condition=*/
+          [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
+          /*followEquivalentOnly=*/true, /*alwaysIncludeLeaves=*/false);
-      // Replace only if the types match.
-      // TODO: This could be extended to support IR such as:
-      // %0 = tensor.empty() : tensor<128xf32>
-      // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
-      // %2 = tensor.expand_shape %1 ...
-      // %3 = tensor.insert_slice %2 into ...
-      if (emptyTensor.getType() != operand.get().getType())
-        continue;
+      for (Value v : emptyTensors) {
+        Operation *emptyTensorOp = v.getDefiningOp();
-      // Find a suitable insertion point.
-      Operation *insertionPoint =
-          findValidInsertionPoint(emptyTensor.getDefiningOp(), neededValues);
-      if (!insertionPoint)
-        continue;
+        // Replace only if the types match. We do not support slices or casts.
+        // TODO: This could be extended to support IR such as:
+        // %0 = tensor.empty() : tensor<128xf32>
+        // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
+        // %2 = tensor.expand_shape %1 ...
+        // %3 = tensor.insert_slice %2 into ...
+        if (v.getType() != operand.get().getType())
+          continue;
-      // Create a replacement for the tensor::EmptyOp.
-      rewriter.setInsertionPoint(insertionPoint);
-      Value replacement = rewriteFunc(rewriter, emptyTensor.getLoc(), operand);
-      if (!replacement)
-        continue;
+        // Find a suitable insertion point. If no suitable insertion point for
+        // the replacement can be found, skip this replacement.
+        Operation *insertionPoint =
+            findValidInsertionPoint(emptyTensorOp, neededValues);
+        if (!insertionPoint)
+          continue;
-      // Replace the tensor::EmptyOp.
-      rewriter.replaceOp(emptyTensor.getDefiningOp(), replacement);
-      state.resetCache();
-    }
+        rewriter.setInsertionPoint(insertionPoint);
+        Value replacement =
+            rewriteFunc(rewriter, emptyTensorOp->getLoc(), operand);
+        if (!replacement)
+          continue;
-    // Advance to the next operation.
-    return WalkResult::advance();
+        // Replace the tensor::EmptyOp.
+        rewriter.replaceOp(emptyTensorOp, replacement);
+        state.resetCache();
+      }
+    }
-  return failure(status.wasInterrupted());
+  return success();
 /// Try to eliminate tensor::EmptyOps inside `op`. A tensor::EmptyOp can be
@@ -253,6 +262,7 @@ struct EmptyTensorElimination
 void EmptyTensorElimination::runOnOperation() {
   Operation *op = getOperation();
   OneShotBufferizationOptions options;
+  options.allowReturnAllocs = true;
   OneShotAnalysisState state(op, options);
   if (failed(analyzeOp(op, state))) {

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 1c0860ffafbef..753840572f4b3 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -169,3 +169,38 @@ func.func @parallel_insert_slice(
   return %r1: tensor<?xf32>
+// -----
+// CHECK-LABEL: func @eleminate_multiple_ops(
+//  CHECK-SAME:   %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>
+//  CHECK-SAME:   %[[sz:[0-9a-zA-Z]*]]: index
+func.func @eleminate_multiple_ops(%t: tensor<?xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>}, %sz: index, %c: i1)
+    -> (tensor<?xf32>)
+  %cst1 = arith.constant 0.0: f32
+  %cst2 = arith.constant 1.0: f32
+  // CHECK: %[[r:.*]] = scf.if %{{.*}} -> (memref
+  %if = scf.if %c -> tensor<?xf32> {
+    // CHECK: %[[T_SUBVIEW_1:.*]] =  memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
+    %a1 = tensor.empty(%sz) : tensor<?xf32>
+    // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_1]] : memref<?xf32
+    %f1 = linalg.fill ins(%cst1 : f32) outs(%a1 : tensor<?xf32>) -> tensor<?xf32>
+    // CHECK: scf.yield %[[T_SUBVIEW_1]]
+    scf.yield %f1 : tensor<?xf32>
+  } else {
+      // CHECK: %[[T_SUBVIEW_2:.*]] =  memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
+    %a2 = tensor.empty(%sz) : tensor<?xf32>
+    // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_2]] : memref<?xf32
+    %f2 = linalg.fill ins(%cst2 : f32) outs(%a2 : tensor<?xf32>) -> tensor<?xf32>
+    // CHECK: scf.yield %[[T_SUBVIEW_2]]
+    scf.yield %f2 : tensor<?xf32>
+  }
+  // The self-copy could be canonicalized away later.
+  // CHECK: %[[T_SUBVIEW_3:.*]] =  memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
+  // CHECK: memref.copy %[[r]], %[[T_SUBVIEW_3]]
+  %r1 = tensor.insert_slice %if into %t[42][%sz][1]: tensor<?xf32> into tensor<?xf32>
+  return %r1: tensor<?xf32>


More information about the Mlir-commits mailing list