[Mlir-commits] [mlir] 37ffbbb - [mlir][tensor][sparse] don't drop encoding when infer result type (#91817)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 13 09:53:19 PDT 2024


Author: Peiming Liu
Date: 2024-05-13T09:53:15-07:00
New Revision: 37ffbbb19576a884c5bb93b9ac0ae97f89523b6b

URL: https://github.com/llvm/llvm-project/commit/37ffbbb19576a884c5bb93b9ac0ae97f89523b6b
DIFF: https://github.com/llvm/llvm-project/commit/37ffbbb19576a884c5bb93b9ac0ae97f89523b6b.diff

LOG: [mlir][tensor][sparse] don't drop encoding when infer result type (#91817)

A general question is: is it possible to support hooks here to infer the
encoding? E.g., when the extracted tensor slice is rank-reduced, the
encoding needs to be updated accordingly as well.
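The one-line change in `ExtractSliceOp::inferResultType` (in the diff of this message) passes the source encoding through to `RankedTensorType::get`. The sketch below is for illustration only. It models that before/after behavior in plain C++. `TensorTypeModel`, `inferResultTypeOld`, `inferResultTypeNew`, and `encodingMatchesRank` are hypothetical names, not MLIR API. The encoding is also simplified to one level-type string per dimension, whereas real `#sparse_tensor.encoding` attributes use a dimension-to-level map. In this simplified model, forwarding the encoding unchanged is only consistent when the result keeps the source rank, which is the rank-reduction question raised above.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for mlir::RankedTensorType: a shape
// (-1 marks a dynamic size), an element type name, and an optional encoding
// modeled as one level-type string per dimension.
struct TensorTypeModel {
  std::vector<int64_t> shape;
  std::string elementType;
  std::optional<std::vector<std::string>> encoding;
};

// Models the behavior before the fix: the encoding is silently dropped.
TensorTypeModel inferResultTypeOld(const TensorTypeModel &source,
                                   const std::vector<int64_t> &staticSizes) {
  assert(staticSizes.size() == source.shape.size() &&
         "unexpected staticSizes not equal to rank of source");
  return {staticSizes, source.elementType, std::nullopt};
}

// Models the behavior after the fix: the source encoding is forwarded as-is.
TensorTypeModel inferResultTypeNew(const TensorTypeModel &source,
                                   const std::vector<int64_t> &staticSizes) {
  assert(staticSizes.size() == source.shape.size() &&
         "unexpected staticSizes not equal to rank of source");
  return {staticSizes, source.elementType, source.encoding};
}

// Hypothetical consistency check for this model: an encoding with one entry
// per dimension only fits a type of the same rank. A rank-reduced result
// carrying the unchanged source encoding fails this check.
bool encodingMatchesRank(const TensorTypeModel &t) {
  return !t.encoding || t.encoding->size() == t.shape.size();
}
```

For example, with a rank-3 source shaped like the `#BCOO` tensor in the test below and sizes `{4, 1, -1}`, `inferResultTypeOld` returns a type with no encoding, while `inferResultTypeNew` keeps `{"dense", "loose_compressed(nonunique)", "singleton"}`. A rank-2 type carrying that same encoding fails `encodingMatchesRank`, which is why a rank-reducing slice would need some hook to rewrite the encoding.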

Added: 
    mlir/test/Dialect/SparseTensor/canonicalize.mlir

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 1f94397e823f7..e41d59a0e0b94 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2020,7 +2020,8 @@ RankedTensorType ExtractSliceOp::inferResultType(
   assert(static_cast<int64_t>(staticSizes.size()) ==
              sourceTensorType.getRank() &&
          "unexpected staticSizes not equal to rank of source");
-  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType());
+  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
+                               sourceTensorType.getEncoding());
 }
 
 RankedTensorType ExtractSliceOp::inferResultType(

diff  --git a/mlir/test/Dialect/SparseTensor/canonicalize.mlir b/mlir/test/Dialect/SparseTensor/canonicalize.mlir
new file mode 100644
index 0000000000000..b1d3d7916c142
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/canonicalize.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
+
+#BCOO = #sparse_tensor.encoding<{
+  map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton)
+}>
+
+// CHECK-DAG: #[[$BCOO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton) }>
+// CHECK-LABEL: func @sparse_slice_canonicalize
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32, #[[$BCOO]]>
+//       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+//  CHECK-SAME:      [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+//  CHECK-SAME:      : tensor<?x?x?xf32, #[[$BCOO]]> to tensor<4x1x?xf32, #[[$BCOO]]>
+//       CHECK:   %[[RESULT:.+]] = tensor.cast %[[SLICE]]
+//       CHECK:   return %[[RESULT]]
+func.func @sparse_slice_canonicalize(%arg0 : tensor<?x?x?xf32, #BCOO>, %arg1 : index,
+    %arg2 : index) -> tensor<?x?x?xf32, #BCOO>
+{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32, #BCOO> to tensor<?x?x?xf32, #BCOO>
+  return %0 : tensor<?x?x?xf32, #BCOO>
+}
