[Mlir-commits] [mlir] [mlir][bufferization] Fix invalid IR from eliminate-empty-tensors (PR #95978)

Spenser Bauman llvmlistbot at llvm.org
Tue Jun 18 12:44:59 PDT 2024


https://github.com/sabauma created https://github.com/llvm/llvm-project/pull/95978

EmptyTensorElimination can construct IR that violates dominance requirements when a single tensor.empty operation is the root of multiple operands of the same operation, such as the tensor.insert_slice operation below:

  %0 = tensor.empty() : tensor<1x7x1xf32>
  %inserted_slice = tensor.insert_slice %0 into %0[0, 0, 0] [1, 7, 1] [1, 1, 1]
      : tensor<1x7x1xf32> into tensor<1x7x1xf32>

From this IR, EmptyTensorElimination would construct the following tensor.extract_slice operation, which consumes its own result:

  %0 = tensor.extract_slice %0[0, 0, 0] [1, 7, 1] [1, 1, 1]
      : tensor<1x7x1xf32> to tensor<1x7x1xf32>
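This IR is invalid because an operation's results only dominate the operations that follow it, so an operation in an SSA region can never use its own result as an operand. The following is a rough C++ sketch of that condition over a hypothetical `ToyOp` type (not MLIR's `Operation` API or its verifier); `usesOwnResult` is an illustrative name.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy operation (not mlir::Operation): SSA names of the values
// the op defines and of the values it consumes.
struct ToyOp {
  std::string opName;
  std::vector<std::string> results;
  std::vector<std::string> operands;
};

// Returns true if any operand of `op` is one of its own results. Such a use
// cannot be dominated by its definition in an SSA (non-graph) region.
bool usesOwnResult(const ToyOp &op) {
  for (const std::string &operand : op.operands)
    for (const std::string &result : op.results)
      if (operand == result)
        return true;
  return false;
}
```

For example, `ToyOp{"tensor.extract_slice", {"%0"}, {"%0"}}` models the extract_slice above and is flagged, while the same op with result `%1` is not.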

This change avoids constructing such a self-referencing tensor.extract_slice by examining the values needed to build the new operation and bailing out if the tensor.empty candidate is among them.
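The guard is a plain membership test: `llvm::is_contained(range, x)` is equivalent to searching the range with `std::find`. A minimal C++ sketch of the same check follows, assuming a hypothetical `ToyValue` type in place of `mlir::Value` (identity is the pointer) and an illustrative `canEliminate` helper that does not exist in the patch.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical stand-in for mlir::Value; two values are the same value iff
// they have the same address.
struct ToyValue {
  int id;
};

// Mirrors the guard added in this patch: if the tensor.empty result is one of
// the values the replacement tensor.extract_slice needs, replacing it would
// make the new op consume its own result, so the candidate must be skipped.
bool canEliminate(const ToyValue *emptyResult,
                  const std::vector<const ToyValue *> &neededValues) {
  assert(emptyResult != nullptr);
  return std::find(neededValues.begin(), neededValues.end(), emptyResult) ==
         neededValues.end();
}
```

Doing this check before `findValidInsertionPoint` means the pass never builds the extraction for such a candidate; the candidate is simply left in place, which is still valid IR.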

From fdb5d58c39c22c736b54d88c6dc7e90f3ff0f91a Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Mon, 17 Jun 2024 20:33:37 -0400
Subject: [PATCH] [mlir][bufferization] Fix invalid IR from
 eliminate-empty-tensors

EmptyTensorElimination can construct IR which violates dominance
requirements when a tensor.empty operation serves as the root for
multiple inputs, such as the tensor.insert_slice operation below:

  %0 = tensor.empty() : tensor<1x7x1xf32>
  %inserted_slice = tensor.insert_slice %0 into %0[0, 0, 0] [1, 7, 1] [1, 1, 1]
      : tensor<1x7x1xf32> into tensor<1x7x1xf32>

From this IR, EmptyTensorElimination would construct the following
tensor.extract_slice operation:

  %0 = tensor.extract_slice %0[0, 0, 0] [1, 7, 1] [1, 1, 1]
      : tensor<1x7x1xf32> to tensor<1x7x1xf32>

This change avoids constructing the above IR by examining the values
needed to construct the new tensor.extract_slice operation and bails out
if the tensor.empty candidate is in the set of needed values.
---
 .../Transforms/EmptyTensorElimination.cpp     |  7 ++++++-
 ...ot-bufferize-empty-tensor-elimination.mlir | 21 +++++++++++++++++++
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index eba1273b36e24..67ee95bf25816 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -137,13 +137,18 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
     for (Value v : emptyTensors) {
       Operation *emptyTensorOp = v.getDefiningOp();
 
+      // The empty tensor op is the operation that we are trying to replace.
+      // If it is one of the values needed as input to the new operation, then
+      // it cannot be eliminated.
+      if (llvm::is_contained(neededValues, emptyTensorOp->getResult(0)))
+        continue;
+
       // Find a suitable insertion point. If no suitable insertion point for
       // the replacement can be found, skip this replacement.
       Operation *insertionPoint =
           findValidInsertionPoint(emptyTensorOp, neededValues);
       if (!insertionPoint)
         continue;
-
       rewriter.setInsertionPoint(insertionPoint);
       Value replacement =
           op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index efe59af97d964..920c16d57709e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -365,3 +365,24 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32
   bufferization.materialize_in_destination %selected in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
   return
 }
+
+// -----
+
+// This is a regression test to ensure tensor.extract_slice operations
+// which consume their own results are not created.
+// This would occur in the following example when trying to replace the
+// tensor.empty with a tensor.extract_slice of %0 itself.
+// CHECK-ELIM-LABEL: func @multiple_uses_of_empty_by_same_op
+//       CHECK-ELIM: tensor.empty
+//       CHECK-ELIM: linalg.fill
+//       CHECK-ELIM: tensor.insert_slice
+func.func @multiple_uses_of_empty_by_same_op() -> tensor<1x7x1xf32> {
+  // Single empty tensor which is the root of both inputs to tensor.insert_slice
+  %0 = tensor.empty() : tensor<1x7x1xf32>
+
+  %zero = arith.constant 0.0 : f32
+  %filled = linalg.fill ins(%zero : f32) outs(%0 : tensor<1x7x1xf32>) -> tensor<1x7x1xf32>
+
+  %inserted_slice = tensor.insert_slice %filled into %0[0, 0, 0] [1, 7, 1] [1, 1, 1] : tensor<1x7x1xf32> into tensor<1x7x1xf32>
+  return %inserted_slice : tensor<1x7x1xf32>
+}
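In the test, both operands of tensor.insert_slice lead back to the same tensor.empty: the source `%filled` reaches `%0` through the `outs` operand of linalg.fill, and the destination is `%0` directly. The following is a simplified C++ model of that reverse walk, assuming a hypothetical `ToyValue` type where each destination-style result records the operand it writes into; it is not the pass's actual use-def traversal.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy value (not mlir::Value). For destination-style ops such as
// linalg.fill, `dest` points at the "outs" operand the result is written into;
// for roots such as tensor.empty it is nullptr.
struct ToyValue {
  std::string name;
  std::string definingOp;
  const ToyValue *dest;
};

// Follow destination operands backwards until reaching a value with none.
const ToyValue *findRoot(const ToyValue *v) {
  assert(v != nullptr);
  while (v->dest != nullptr)
    v = v->dest;
  return v;
}

// True if every operand traces back to one and the same root value.
bool shareRoot(const std::vector<const ToyValue *> &operands) {
  if (operands.empty())
    return false;
  const ToyValue *root = findRoot(operands.front());
  for (const ToyValue *operand : operands)
    if (findRoot(operand) != root)
      return false;
  return true;
}
```

Modeling `%0` as `{"%0", "tensor.empty", nullptr}` and `%filled` as `{"%filled", "linalg.fill", &zero}` (with `zero` being that `%0`), `shareRoot({&filled, &zero})` holds, which is the situation the new guard must reject.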