[Mlir-commits] [mlir] [mlir][bufferization] skip empty tensor elimination if they have different element types (PR #96998)
zhicong zhong
llvmlistbot at llvm.org
Fri Jun 28 00:56:54 PDT 2024
https://github.com/zhczhong updated https://github.com/llvm/llvm-project/pull/96998
>From 32635f59d68329b2125f9dff88870cd4d3818dc5 Mon Sep 17 00:00:00 2001
From: "Zhong, Zhicong" <zhicong.zhong at intel.com>
Date: Mon, 24 Jun 2024 23:44:00 -0700
Subject: [PATCH] skip empty tensor elimination if they have different element
types
---
.../Transforms/EmptyTensorElimination.cpp | 3 ++
...ize-analysis-empty-tensor-elimination.mlir | 29 +++++++++++++++++++
2 files changed, 32 insertions(+)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index eba1273b36e24..cb2efef5c038b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -152,6 +152,9 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
if (emptyTensorOp == replacement.getDefiningOp())
continue;
if (replacement.getType() != v.getType()) {
+ if (cast<ShapedType>(replacement.getType()).getElementType() !=
+ cast<ShapedType>(v.getType()).getElementType())
+ continue;
rewriter.setInsertionPointAfterValue(replacement);
replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
replacement);
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
index 47ede793e9eab..2ba8246a8d525 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
@@ -47,3 +47,32 @@ func.func @buffer_forwarding_no_conflict(%arg0: tensor<?xf32> {bufferization.wri
// CHECK-SAME: __equivalent_func_args__ = [0, 0]
return %2, %2 : tensor<?xf32>, tensor<?xf32>
}
+
+// -----
+// Empty tensors must not be eliminated when the replacement value has a different element type.
+// CHECK-LABEL: func @buffer_forwarding_conflict_with_different_element_type
+func.func @buffer_forwarding_conflict_with_different_element_type(%arg0: tensor<?xf32> {bufferization.writable = true}, %arg1: index) -> (tensor<?xf32>, tensor<?xf32>) {
+ // CHECK: tensor.extract_slice
+ // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty(%arg1) : tensor<?xf32>
+ // %1 is bf16, so an f32 replacement cannot be cast to it; it must stay an allocation.
+ // CHECK: bufferization.alloc_tensor(%arg1)
+ %1 = tensor.empty(%arg1) : tensor<?xbf16>
+
+ // CHECK: linalg.copy
+ // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]
+ %2 = linalg.copy ins(%0 : tensor<?xf32>) outs(%1 : tensor<?xbf16>) -> tensor<?xbf16>
+
+ // CHECK: linalg.copy
+ // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]
+ %3 = linalg.copy ins(%2 : tensor<?xbf16>) outs(%0 : tensor<?xf32>) -> tensor<?xf32>
+
+ // CHECK: tensor.insert_slice
+ // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none"]
+ %4 = tensor.insert_slice %3 into %arg0[42] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>
+
+ // CHECK: return
+ // CHECK-SAME: __equivalent_func_args__ = [0, 0]
+ return %4, %4 : tensor<?xf32>, tensor<?xf32>
+}
More information about the Mlir-commits
mailing list