[Mlir-commits] [mlir] aba0ef7 - [mlir][bufferization] Support casts in EmptyTensorElimination

Matthias Springer llvmlistbot at llvm.org
Mon Jul 31 06:20:13 PDT 2023

Author: Matthias Springer
Date: 2023-07-31T15:20:00+02:00
New Revision: aba0ef70597980d84e31f41e09cdbd00ea65d9fd

URL: https://github.com/llvm/llvm-project/commit/aba0ef70597980d84e31f41e09cdbd00ea65d9fd
DIFF: https://github.com/llvm/llvm-project/commit/aba0ef70597980d84e31f41e09cdbd00ea65d9fd.diff

LOG: [mlir][bufferization] Support casts in EmptyTensorElimination

EmptyTensorElimination is a pre-bufferization transformation that replaces "tensor.empty" ops with "tensor.extract_slice" ops. This revision adds support for cases where the input IR contains "tensor.cast" ops.

Differential Revision: https://reviews.llvm.org/D156167




diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index d1faaf56c6afc4..9cde7740066e59 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -408,6 +408,10 @@ struct TraversalConfig {
   /// Specifies whether unknown/non-bufferizable/ops not included in the
   /// OpFilter of BufferizationOptions should be followed.
   bool followUnknownOps = false;
+  /// Specifies whether OpOperands with a different type that are not the result
+  /// of a CastOpInterface op should be followed.
+  bool followSameTypeOrCastsOnly = false;
 /// AnalysisState provides a variety of helper functions for dealing with

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index d9c334983ad814..5eb345df5fe236 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -580,6 +580,16 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
+      if (config.followSameTypeOrCastsOnly &&
+          a.opOperand->get().getType() != value.getType() &&
+          !opResult.getDefiningOp<CastOpInterface>()) {
+        // Stop iterating if `followSameTypeOrCastsOnly` is set but the alias
+        // has a different type and the op is not a cast.
+        if (config.alwaysIncludeLeaves)
+          result.insert(value);
+        continue;
+      }

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 6aa256881b5c66..4e0781dae0c252 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -135,6 +135,14 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
       TraversalConfig config;
       config.followEquivalentOnly = true;
       config.alwaysIncludeLeaves = false;
+      // Replace only if the types match or are static <-> dynamic casts. We do
+      // not support slices or reshapes.
+      // TODO: This could be extended to support IR such as:
+      // %0 = tensor.empty() : tensor<128xf32>
+      // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
+      // %2 = tensor.expand_shape %1 ...
+      // %3 = tensor.insert_slice %2 into ...
+      config.followSameTypeOrCastsOnly = true;
       SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
           operand.get(), /*condition=*/
           [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
@@ -143,15 +151,6 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
       for (Value v : emptyTensors) {
         Operation *emptyTensorOp = v.getDefiningOp();
-        // Replace only if the types match. We do not support slices or casts.
-        // TODO: This could be extended to support IR such as:
-        // %0 = tensor.empty() : tensor<128xf32>
-        // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
-        // %2 = tensor.expand_shape %1 ...
-        // %3 = tensor.insert_slice %2 into ...
-        if (v.getType() != operand.get().getType())
-          continue;
         // Find a suitable insertion point. If no suitable insertion point for
         // the replacement can be found, skip this replacement.
         Operation *insertionPoint =
@@ -164,7 +163,11 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
             rewriteFunc(rewriter, emptyTensorOp->getLoc(), operand);
         if (!replacement)
+        if (replacement.getType() != v.getType()) {
+          rewriter.setInsertionPointAfterValue(replacement);
+          replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
+                                                        replacement);
+        }
         // Replace the tensor::EmptyOp.
         rewriter.replaceOp(emptyTensorOp, replacement);

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 7cb57cb206723b..3d15599915f0cf 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -123,8 +123,8 @@ func.func @insertion_point_outside_loop(%t : tensor<?xf32>, %sz : index,
 // -----
 // EmptyTensorElimination does currently not apply to chains where the type is
-// changing. This test just ensures that we do not crash or generate IR that
-// does not verify.
+// changing. (Casts are supported.) This test just ensures that we do not crash
+// or generate IR that does not verify.
 // CHECK-LABEL: func @shape_mismatch
 func.func @shape_mismatch(%t: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> {
@@ -140,6 +140,24 @@ func.func @shape_mismatch(%t: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> {
 // -----
+// CHECK-LABEL: func @cast(
+//  CHECK-SAME:     %[[t:.*]]: memref<256xf32,
+//       CHECK:   %[[sv:.*]] = memref.subview %[[t]]
+//       CHECK:   linalg.fill {{.*}} outs(%[[sv]]
+//       CHECK:   return %[[t]]
+func.func @cast(%t: tensor<256xf32>) -> tensor<256xf32> {
+  %cst = arith.constant 8.0 : f32
+  %c128 = arith.constant 128 : index
+  %0 = tensor.empty(%c128) : tensor<?xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?xf32>) -> tensor<?xf32>
+  %2 = tensor.cast %1 : tensor<?xf32> to tensor<128xf32>
+  %3 = tensor.insert_slice %2 into %t[2][128][1]
+      : tensor<128xf32> into tensor<256xf32>
+  return %3 : tensor<256xf32>
+// -----
 //      CHECK: func @parallel_insert_slice(
 // CHECK-SAME:   %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>
 // CHECK-SAME:   %[[sz:[0-9a-zA-Z]*]]: index


More information about the Mlir-commits mailing list