[Mlir-commits] [mlir] aba0ef7 - [mlir][bufferization] Support casts in EmptyTensorElimination
Matthias Springer
llvmlistbot at llvm.org
Mon Jul 31 06:20:13 PDT 2023
Author: Matthias Springer
Date: 2023-07-31T15:20:00+02:00
New Revision: aba0ef70597980d84e31f41e09cdbd00ea65d9fd
URL: https://github.com/llvm/llvm-project/commit/aba0ef70597980d84e31f41e09cdbd00ea65d9fd
DIFF: https://github.com/llvm/llvm-project/commit/aba0ef70597980d84e31f41e09cdbd00ea65d9fd.diff
LOG: [mlir][bufferization] Support casts in EmptyTensorElimination
EmptyTensorElimination is a pre-bufferization transformation that replaces "tensor.empty" ops with "tensor.extract_slice" ops. This revision adds support for cases where the input IR contains "tensor.cast" ops.
Differential Revision: https://reviews.llvm.org/D156167
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index d1faaf56c6afc4..9cde7740066e59 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -408,6 +408,10 @@ struct TraversalConfig {
/// Specifies whether unknown/non-bufferizable/ops not included in the
/// OpFilter of BufferizationOptions should be followed.
bool followUnknownOps = false;
+
+  /// Specifies whether OpOperands with a different type that are not the result
+  /// of a CastOpInterface op should be followed.
+ bool followSameTypeOrCastsOnly = false;
};
/// AnalysisState provides a variety of helper functions for dealing with
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index d9c334983ad814..5eb345df5fe236 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -580,6 +580,16 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
continue;
}
+ if (config.followSameTypeOrCastsOnly &&
+ a.opOperand->get().getType() != value.getType() &&
+ !opResult.getDefiningOp<CastOpInterface>()) {
+      // Stop iterating if `followSameTypeOrCastsOnly` is set but the alias
+      // has a different type and the op is not a cast.
+ if (config.alwaysIncludeLeaves)
+ result.insert(value);
+ continue;
+ }
+
workingSet.insert(a.opOperand->get());
}
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 6aa256881b5c66..4e0781dae0c252 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -135,6 +135,14 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
TraversalConfig config;
config.followEquivalentOnly = true;
config.alwaysIncludeLeaves = false;
+ // Replace only if the types match or are static <-> dynamic casts. We do
+ // not support slices or reshapes.
+ // TODO: This could be extended to support IR such as:
+ // %0 = tensor.empty() : tensor<128xf32>
+ // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
+ // %2 = tensor.expand_shape %1 ...
+ // %3 = tensor.insert_slice %2 into ...
+ config.followSameTypeOrCastsOnly = true;
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
operand.get(), /*condition=*/
[&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
@@ -143,15 +151,6 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
for (Value v : emptyTensors) {
Operation *emptyTensorOp = v.getDefiningOp();
- // Replace only if the types match. We do not support slices or casts.
- // TODO: This could be extended to support IR such as:
- // %0 = tensor.empty() : tensor<128xf32>
- // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
- // %2 = tensor.expand_shape %1 ...
- // %3 = tensor.insert_slice %2 into ...
- if (v.getType() != operand.get().getType())
- continue;
-
// Find a suitable insertion point. If no suitable insertion point for
// the replacement can be found, skip this replacement.
Operation *insertionPoint =
@@ -164,7 +163,11 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
rewriteFunc(rewriter, emptyTensorOp->getLoc(), operand);
if (!replacement)
continue;
-
+ if (replacement.getType() != v.getType()) {
+ rewriter.setInsertionPointAfterValue(replacement);
+ replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
+ replacement);
+ }
// Replace the tensor::EmptyOp.
rewriter.replaceOp(emptyTensorOp, replacement);
state.resetCache();
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 7cb57cb206723b..3d15599915f0cf 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -123,8 +123,8 @@ func.func @insertion_point_outside_loop(%t : tensor<?xf32>, %sz : index,
// -----
// EmptyTensorElimination does currently not apply to chains where the type is
-// changing. This test just ensures that we do not crash or generate IR that
-// does not verify.
+// changing. (Casts are supported.) This test just ensures that we do not crash
+// or generate IR that does not verify.
// CHECK-LABEL: func @shape_mismatch
func.func @shape_mismatch(%t: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> {
@@ -140,6 +140,24 @@ func.func @shape_mismatch(%t: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> {
// -----
+// CHECK-LABEL: func @cast(
+// CHECK-SAME: %[[t:.*]]: memref<256xf32,
+// CHECK: %[[sv:.*]] = memref.subview %[[t]]
+// CHECK: linalg.fill {{.*}} outs(%[[sv]]
+// CHECK: return %[[t]]
+func.func @cast(%t: tensor<256xf32>) -> tensor<256xf32> {
+ %cst = arith.constant 8.0 : f32
+ %c128 = arith.constant 128 : index
+ %0 = tensor.empty(%c128) : tensor<?xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?xf32>) -> tensor<?xf32>
+ %2 = tensor.cast %1 : tensor<?xf32> to tensor<128xf32>
+ %3 = tensor.insert_slice %2 into %t[2][128][1]
+ : tensor<128xf32> into tensor<256xf32>
+ return %3 : tensor<256xf32>
+}
+
+// -----
+
// CHECK: func @parallel_insert_slice(
// CHECK-SAME: %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>
// CHECK-SAME: %[[sz:[0-9a-zA-Z]*]]: index
More information about the Mlir-commits
mailing list