[Mlir-commits] [mlir] [mlir][tensor] Fix bug in insert_slice canonical. with tensor encoding (PR #81045)

Alexey Z. llvmlistbot at llvm.org
Wed Feb 7 13:44:50 PST 2024


https://github.com/last5bits created https://github.com/llvm/llvm-project/pull/81045

Previously, `InsertSliceOpSourceCastInserter` was incorrectly applied in cases where the tensor types carry an encoding attribute. The newly built `newSrcType` dropped that attribute from the original `srcType`, so the check `srcType == newSrcType` evaluated to false, because `tensor<2x2xf32, "foo">` is not equal to `tensor<2x2xf32>`. That led to an endless back-and-forth: `InsertSliceOpSourceCastInserter` would introduce a cast, and `InsertSliceOpCastFolder` would remove it right after.
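To make the type-equality point concrete, here is a minimal C++ sketch. `ToyTensorType`, `buildNewSrcTypeBuggy`, and `buildNewSrcTypeFixed` are hypothetical names invented for illustration; they are not MLIR APIs. The toy type compares shape, element type, and encoding, assuming (as described above) that two tensor types differing only in their encoding are unequal.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for a ranked tensor type (NOT mlir::RankedTensorType):
// just enough structure to show why the equality check failed.
struct ToyTensorType {
  std::vector<std::int64_t> shape;
  std::string elementType;
  std::string encoding; // empty string models "no encoding attribute"

  // All three components take part in equality, so two types that differ
  // only in their encoding compare unequal.
  bool operator==(const ToyTensorType &other) const {
    return shape == other.shape && elementType == other.elementType &&
           encoding == other.encoding;
  }
};

// Pre-fix behavior: the new type is built from shape and element type only,
// so the source encoding is silently dropped.
ToyTensorType buildNewSrcTypeBuggy(const ToyTensorType &src,
                                   const std::vector<std::int64_t> &newShape) {
  // The toy only models rank-preserving shape refinement.
  assert(newShape.size() == src.shape.size());
  return ToyTensorType{newShape, src.elementType, /*encoding=*/""};
}

// Post-fix behavior: the encoding of the original type is carried over.
ToyTensorType buildNewSrcTypeFixed(const ToyTensorType &src,
                                   const std::vector<std::int64_t> &newShape) {
  assert(newShape.size() == src.shape.size());
  return ToyTensorType{newShape, src.elementType, src.encoding};
}
```

For a source modeled as `tensor<2x2xf32, "foo">`, whose shape is already fully static, `buildNewSrcTypeBuggy(src, {2, 2})` compares unequal to `src`, so the pattern wrongly believes there is a more precise type to cast to, while `buildNewSrcTypeFixed(src, {2, 2})` compares equal and the pattern bails out.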

From 39e0c80552a5680b78cb380806023e7b37559e4e Mon Sep 17 00:00:00 2001
From: Alexey Zhikhartsev <alexey.zhikhar at gmail.com>
Date: Wed, 7 Feb 2024 16:30:29 -0500
Subject: [PATCH] [mlir][tensor] Fix bug in insert_slice canonical. with tensor
 encoding

Previously, `InsertSliceOpSourceCastInserter` was incorrectly applied in
cases where the tensor types carry an encoding attribute. The newly
built `newSrcType` dropped that attribute from the original `srcType`,
so the check `srcType == newSrcType` evaluated to false, since
`tensor<2x2xf32, "foo">` is not equal to `tensor<2x2xf32>`. That led to
an endless back-and-forth: `InsertSliceOpSourceCastInserter` would
introduce a cast, and `InsertSliceOpCastFolder` would remove it right
after.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   |  4 ++--
 mlir/test/Dialect/Tensor/canonicalize.mlir | 18 ++++++++++++++++++
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index b21e89ae3a5713..8298cf102e28a3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2663,8 +2663,8 @@ struct InsertSliceOpSourceCastInserter final
     if (!hasValidSizesOffsets(newSrcShape))
       return failure();
 
-    RankedTensorType newSrcType =
-        RankedTensorType::get(newSrcShape, srcType.getElementType());
+    RankedTensorType newSrcType = RankedTensorType::get(
+        newSrcShape, srcType.getElementType(), srcType.getEncoding());
     if (srcType == newSrcType ||
         !preservesStaticInformation(srcType, newSrcType) ||
         !tensor::CastOp::areCastCompatible(srcType, newSrcType))
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 7192a719ceb13d..90c715bf2eb2da 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -555,6 +555,24 @@ func.func @insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
 
 // -----
 
+// Do not insert a cast for the following example. The new source type wouldn't be "more static" than the old one.
+func.func @insert_slice_canonicalize_encoding(%arg0 : tensor<2x2xf32, "foo">,
+                                              %arg1 : tensor<4x4xf32, "foo">) -> tensor<4x4xf32, "foo">
+{
+  %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [2, 2] [1, 1] : tensor<2x2xf32, "foo"> into tensor<4x4xf32, "foo">
+  return %0 : tensor<4x4xf32, "foo">
+}
+// CHECK-LABEL: func @insert_slice_canonicalize_encoding
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x2xf32, "foo">
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: tensor<4x4xf32, "foo">
+//       CHECK-NOT: tensor.cast
+//       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]] into %[[ARG1]]
+//  CHECK-SAME:      [0, 0] [2, 2] [1, 1]
+//  CHECK-SAME:      : tensor<2x2xf32, "foo"> into tensor<4x4xf32, "foo">
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
 func.func @slice_to_insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
     %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 {
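The guard in the patched pattern, and the non-terminating interaction it used to cause, can be modeled roughly as below. This is a hypothetical toy sketch, not MLIR code: `toyPreservesStaticInfo` is a simplified stand-in for `preservesStaticInformation` that ignores encodings, the `areCastCompatible` check is omitted, `toyCastFolder` simply removes any cast (mirroring only the reported behavior, not the real folder's checks), and the iteration cap in `runToFixedPoint` is an assumption of the toy.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical toy model of the two canonicalization patterns; none of these
// names are MLIR APIs.
constexpr std::int64_t kDynamic = -1; // stands in for a '?' dimension

struct ToyTensorType {
  std::vector<std::int64_t> shape;
  std::string elementType;
  std::string encoding; // empty string models "no encoding attribute"

  bool operator==(const ToyTensorType &o) const {
    return shape == o.shape && elementType == o.elementType &&
           encoding == o.encoding;
  }
};

// Simplified stand-in for preservesStaticInformation: a static source
// dimension must not become dynamic. Encodings are ignored here.
bool toyPreservesStaticInfo(const ToyTensorType &src, const ToyTensorType &dst) {
  if (src.elementType != dst.elementType || src.shape.size() != dst.shape.size())
    return false;
  for (std::size_t i = 0; i < src.shape.size(); ++i)
    if (src.shape[i] != kDynamic && dst.shape[i] == kDynamic)
      return false;
  return true;
}

// Toy IR: the insert_slice source type and whether a cast currently feeds it.
struct ToyIR {
  ToyTensorType srcType;
  bool hasCast;
};

// Models the inserter's guard from the patch (areCastCompatible omitted).
// keepEncoding=false reproduces the pre-fix type construction.
// Returns true if it rewrote the IR.
bool toyCastInserter(ToyIR &ir, const std::vector<std::int64_t> &newShape,
                     bool keepEncoding) {
  if (ir.hasCast)
    return false;
  ToyTensorType newSrc{newShape, ir.srcType.elementType,
                       keepEncoding ? ir.srcType.encoding : std::string()};
  if (ir.srcType == newSrc || !toyPreservesStaticInfo(ir.srcType, newSrc))
    return false; // corresponds to returning failure()
  ir.hasCast = true;
  return true;
}

// Models the folder as reported in the bug: the freshly inserted cast is
// removed right away. The real pattern performs checks this toy omits.
bool toyCastFolder(ToyIR &ir) {
  if (!ir.hasCast)
    return false;
  ir.hasCast = false;
  return true;
}

// Greedy driver: apply both patterns until neither fires or the cap is hit.
// Returns the number of productive rounds, or -1 if it never converged.
int runToFixedPoint(ToyIR ir, const std::vector<std::int64_t> &newShape,
                    bool keepEncoding, int maxIters) {
  assert(maxIters > 0);
  for (int iter = 0; iter < maxIters; ++iter) {
    bool changed = toyCastInserter(ir, newShape, keepEncoding);
    changed = toyCastFolder(ir) || changed; // always run the folder too
    if (!changed)
      return iter;
  }
  return -1;
}
```

With the pre-fix construction (`keepEncoding = false`) on a `tensor<2x2xf32, "foo">` source, every round inserts and then removes the cast, so the toy driver never reaches a fixed point. With the fix, the `srcType == newSrcType` guard fails immediately and no cast is ever created, which is what the new test's `CHECK-NOT: tensor.cast` expects.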