[Mlir-commits] [mlir] [mlir][tensor][sparse] don't drop encoding when infer result type (PR #91817)

Peiming Liu llvmlistbot at llvm.org
Fri May 10 15:13:33 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/91817

From 31c6a1a81dafa8b2af1d4d5a47c122ed27d16d21 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 10 May 2024 21:54:58 +0000
Subject: [PATCH 1/2] [mlir][tensor][sparse] don't drop encoding when infer
 result type

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 1f94397e823f7..e41d59a0e0b94 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2020,7 +2020,8 @@ RankedTensorType ExtractSliceOp::inferResultType(
   assert(static_cast<int64_t>(staticSizes.size()) ==
              sourceTensorType.getRank() &&
          "unexpected staticSizes not equal to rank of source");
-  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType());
+  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
+                               sourceTensorType.getEncoding());
 }
 
 RankedTensorType ExtractSliceOp::inferResultType(

From 01e3c0126be2860112ab904c19e718d0f2374cf4 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 10 May 2024 22:13:06 +0000
Subject: [PATCH 2/2] add check test.

---
 .../Dialect/SparseTensor/canonicalize.mlir    | 23 +++++++++++++++++++
 1 file changed, 23 insertions(+)
 create mode 100644 mlir/test/Dialect/SparseTensor/canonicalize.mlir

diff --git a/mlir/test/Dialect/SparseTensor/canonicalize.mlir b/mlir/test/Dialect/SparseTensor/canonicalize.mlir
new file mode 100644
index 0000000000000..b1d3d7916c142
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/canonicalize.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
+
+#BCOO = #sparse_tensor.encoding<{
+  map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton)
+}>
+
+// CHECK-DAG: #[[$BCOO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton) }>
+// CHECK-LABEL: func @sparse_slice_canonicalize
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32, #[[$BCOO]]>
+//       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+//  CHECK-SAME:      [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+//  CHECK-SAME:      : tensor<?x?x?xf32, #[[$BCOO]]> to tensor<4x1x?xf32, #[[$BCOO]]>
+//       CHECK:   %[[RESULT:.+]] = tensor.cast %[[SLICE]]
+//       CHECK:   return %[[RESULT]]
+func.func @sparse_slice_canonicalize(%arg0 : tensor<?x?x?xf32, #BCOO>, %arg1 : index,
+    %arg2 : index) -> tensor<?x?x?xf32, #BCOO>
+{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32, #BCOO> to tensor<?x?x?xf32, #BCOO>
+  return %0 : tensor<?x?x?xf32, #BCOO>
+}



More information about the Mlir-commits mailing list