[Mlir-commits] [mlir] [mlir][bufferization] skip empty tensor elimination if they have different element type (PR #96998)

zhicong zhong llvmlistbot at llvm.org
Fri Jun 28 00:56:54 PDT 2024


https://github.com/zhczhong updated https://github.com/llvm/llvm-project/pull/96998

From 32635f59d68329b2125f9dff88870cd4d3818dc5 Mon Sep 17 00:00:00 2001
From: "Zhong, Zhicong" <zhicong.zhong at intel.com>
Date: Mon, 24 Jun 2024 23:44:00 -0700
Subject: [PATCH] skip empty tensor elimination if they have different element
 type

---
 .../Transforms/EmptyTensorElimination.cpp     |  3 ++
 ...ize-analysis-empty-tensor-elimination.mlir | 29 +++++++++++++++++++
 2 files changed, 32 insertions(+)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index eba1273b36e24..cb2efef5c038b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -152,6 +152,9 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
       if (emptyTensorOp == replacement.getDefiningOp())
         continue;
       if (replacement.getType() != v.getType()) {
+        if (cast<ShapedType>(replacement.getType()).getElementType() !=
+            cast<ShapedType>(v.getType()).getElementType())
+          continue;
         rewriter.setInsertionPointAfterValue(replacement);
         replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
                                                       replacement);
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
index 47ede793e9eab..2ba8246a8d525 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
@@ -47,3 +47,32 @@ func.func @buffer_forwarding_no_conflict(%arg0: tensor<?xf32> {bufferization.wri
   // CHECK-SAME: __equivalent_func_args__ = [0, 0]
   return %2, %2 : tensor<?xf32>, tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @buffer_forwarding_conflict_with_different_element_type
+func.func @buffer_forwarding_conflict_with_different_element_type(%arg0: tensor<?xf32> {bufferization.writable = true}, %arg1: index) -> (tensor<?xf32>, tensor<?xf32>) {
+  //      CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]
+  %cst = arith.constant 0.000000e+00 : f32  // NOTE(review): no use of %cst is visible in this function; confirm it is needed.
+  %0 = tensor.empty(%arg1) : tensor<?xf32>  // f32 empty tensor: input of the first copy, output of the second.
+
+  //      CHECK: bufferization.alloc_tensor(%arg1)
+  %1 = tensor.empty(%arg1) : tensor<?xbf16>  // bf16 empty tensor: element type differs from the f32 insert_slice destination below.
+
+  //      CHECK: linalg.copy
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]
+  %2 = linalg.copy ins(%0 : tensor<?xf32>) outs(%1 : tensor<?xbf16>) -> tensor<?xbf16>  // f32 -> bf16 copy into %1.
+
+  //      CHECK: linalg.copy
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]
+  %3 = linalg.copy ins(%2 : tensor<?xbf16>) outs(%0 : tensor<?xf32>) -> tensor<?xf32>  // bf16 -> f32 copy back into %0.
+
+  //      CHECK: tensor.insert_slice
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none"]
+  %4 = tensor.insert_slice %3 into %arg0[42] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>  // inserts the f32 result into the writable argument %arg0.
+
+  //      CHECK: return
+  // CHECK-SAME: __equivalent_func_args__ = [0, 0]
+  return %4, %4 : tensor<?xf32>, tensor<?xf32>
+}



More information about the Mlir-commits mailing list