[Mlir-commits] [mlir] [mlir][bufferization] skip empty tensor elimination if they have different element type (PR #96998)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 27 20:28:08 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-bufferization
Author: zhicong zhong (zhczhong)
<details>
<summary>Changes</summary>
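In the original implementation, empty tensor elimination inserts a `tensor.cast` and eliminates the empty tensor even when the replacement value has a different element type (e.g., f32 vs. bf16). Because `tensor.cast` requires the source and result types to have the same element type, this produces invalid IR. This change adds an element-type check and skips the elimination when the element types differ.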
---
Full diff: https://github.com/llvm/llvm-project/pull/96998.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp (+3)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir (+29)
``````````diff
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index eba1273b36e24..0c3245eecfda7 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -152,6 +152,9 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
if (emptyTensorOp == replacement.getDefiningOp())
continue;
if (replacement.getType() != v.getType()) {
+ if (ShapeAdaptor(replacement.getType()).getElementType() !=
+ ShapeAdaptor(v.getType()).getElementType())
+ continue;
rewriter.setInsertionPointAfterValue(replacement);
replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
replacement);
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
index 47ede793e9eab..2ba8246a8d525 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
@@ -47,3 +47,32 @@ func.func @buffer_forwarding_no_conflict(%arg0: tensor<?xf32> {bufferization.wri
// CHECK-SAME: __equivalent_func_args__ = [0, 0]
return %2, %2 : tensor<?xf32>, tensor<?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @buffer_forwarding_conflict_with_different_element_type
+func.func @buffer_forwarding_conflict_with_different_element_type(%arg0: tensor<?xf32> {bufferization.writable = true}, %arg1: index) -> (tensor<?xf32>, tensor<?xf32>) {
+ // CHECK: tensor.extract_slice
+ // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty(%arg1) : tensor<?xf32>
+
+ // CHECK: bufferization.alloc_tensor(%arg1)
+ %1 = tensor.empty(%arg1) : tensor<?xbf16>
+
+ // CHECK: linalg.copy
+ // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]
+ %2 = linalg.copy ins(%0 : tensor<?xf32>) outs(%1 : tensor<?xbf16>) -> tensor<?xbf16>
+
+ // CHECK: linalg.copy
+ // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]
+ %3 = linalg.copy ins(%2 : tensor<?xbf16>) outs(%0 : tensor<?xf32>) -> tensor<?xf32>
+
+ // CHECK: tensor.insert_slice
+ // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none"]
+ %4 = tensor.insert_slice %3 into %arg0[42] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>
+
+ // CHECK: return
+ // CHECK-SAME: __equivalent_func_args__ = [0, 0]
+ return %4, %4 : tensor<?xf32>, tensor<?xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/96998
More information about the Mlir-commits mailing list