[Mlir-commits] [mlir] 37ffbbb - [mlir][tensor][sparse] don't drop encoding when infer result type (#91817)
llvmlistbot at llvm.org
Mon May 13 09:53:19 PDT 2024
Author: Peiming Liu
Date: 2024-05-13T09:53:15-07:00
New Revision: 37ffbbb19576a884c5bb93b9ac0ae97f89523b6b
URL: https://github.com/llvm/llvm-project/commit/37ffbbb19576a884c5bb93b9ac0ae97f89523b6b
DIFF: https://github.com/llvm/llvm-project/commit/37ffbbb19576a884c5bb93b9ac0ae97f89523b6b.diff
LOG: [mlir][tensor][sparse] don't drop encoding when infer result type (#91817)
A general question: is it possible to support hooks here to infer the
encoding? E.g., when the extracted tensor slice is rank-reduced, the
encoding needs to be updated accordingly as well.
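To make the effect of the one-line change in `ExtractSliceOp::inferResultType` concrete, here is a minimal, self-contained C++ model. It is a sketch under stated assumptions. `TensorTypeModel`, `inferResultTypeOld`, and `inferResultTypeNew` are hypothetical names, not MLIR APIs. The encoding is reduced to a string alias such as `#BCOO`, and `kDynamic` only mirrors the sentinel MLIR uses for `ShapedType::kDynamic`. The "old" function forwards just the static sizes and element type. The "new" one also carries over the source encoding, as the diff does.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <string>
#include <vector>

// Mirrors the sentinel MLIR uses for dynamic dimensions (ShapedType::kDynamic).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical, simplified stand-in for a ranked tensor type: a shape, an
// element type name, and an encoding alias (empty string = no encoding).
struct TensorTypeModel {
  std::vector<int64_t> shape;
  std::string elementType;
  std::string encoding;
};

// Models the behavior before the fix: only sizes and element type survive.
TensorTypeModel inferResultTypeOld(const TensorTypeModel &source,
                                   const std::vector<int64_t> &staticSizes) {
  assert(staticSizes.size() == source.shape.size() &&
         "unexpected staticSizes not equal to rank of source");
  return {staticSizes, source.elementType, /*encoding=*/""};
}

// Models the behavior after the fix: the source encoding is kept as well.
TensorTypeModel inferResultTypeNew(const TensorTypeModel &source,
                                   const std::vector<int64_t> &staticSizes) {
  assert(staticSizes.size() == source.shape.size() &&
         "unexpected staticSizes not equal to rank of source");
  return {staticSizes, source.elementType, source.encoding};
}
```

For a `tensor<?x?x?xf32, #BCOO>` source sliced to sizes `[4, 1, ?]`, the old model returns an unencoded (dense) type. The new model keeps `#BCOO`.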
Added:
mlir/test/Dialect/SparseTensor/canonicalize.mlir
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 1f94397e823f7..e41d59a0e0b94 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2020,7 +2020,8 @@ RankedTensorType ExtractSliceOp::inferResultType(
assert(static_cast<int64_t>(staticSizes.size()) ==
sourceTensorType.getRank() &&
"unexpected staticSizes not equal to rank of source");
- return RankedTensorType::get(staticSizes, sourceTensorType.getElementType());
+ return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
+ sourceTensorType.getEncoding());
}
RankedTensorType ExtractSliceOp::inferResultType(
diff --git a/mlir/test/Dialect/SparseTensor/canonicalize.mlir b/mlir/test/Dialect/SparseTensor/canonicalize.mlir
new file mode 100644
index 0000000000000..b1d3d7916c142
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/canonicalize.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
+
+#BCOO = #sparse_tensor.encoding<{
+ map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton)
+}>
+
+// CHECK-DAG: #[[$BCOO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton) }>
+// CHECK-LABEL: func @sparse_slice_canonicalize
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32, #[[$BCOO]]>
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+// CHECK-SAME: : tensor<?x?x?xf32, #[[$BCOO]]> to tensor<4x1x?xf32, #[[$BCOO]]>
+// CHECK: %[[RESULT:.+]] = tensor.cast %[[SLICE]]
+// CHECK: return %[[RESULT]]
+func.func @sparse_slice_canonicalize(%arg0 : tensor<?x?x?xf32, #BCOO>, %arg1 : index,
+ %arg2 : index) -> tensor<?x?x?xf32, #BCOO>
+{
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32, #BCOO> to tensor<?x?x?xf32, #BCOO>
+ return %0 : tensor<?x?x?xf32, #BCOO>
+}
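The test exercises the canonicalization path. Constant operands such as `%c4` and `%c1` are folded into static values, the slice's result type is re-inferred, and a `tensor.cast` restores the function's original result type. The rough C++ model below shows why the CHECK lines expect `tensor<4x1x?xf32, #BCOO>` rather than an unencoded type. It assumes a simplified type representation, and `SliceTypeModel`, `foldAndInfer`, `print`, and `needsCast` are hypothetical helpers, not MLIR APIs. Only the sizes affect the result type, so offsets and strides are left out.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>
#include <string>
#include <vector>

// Mirrors the sentinel MLIR uses for dynamic dimensions (ShapedType::kDynamic).
constexpr int64_t kDynamicDim = std::numeric_limits<int64_t>::min();

// Hypothetical simplified tensor type: shape, element type, encoding alias.
struct SliceTypeModel {
  std::vector<int64_t> shape;
  std::string elementType;
  std::string encoding;  // empty string means "no encoding"
};

// A size operand is either a known constant (e.g. %c4) or an opaque SSA
// value (e.g. %arg2), modelled here as an empty optional.
using SizeOperand = std::optional<int64_t>;

// Fold constant size operands into static dimensions, then infer the result
// type while keeping the source encoding (the post-commit behavior).
SliceTypeModel foldAndInfer(const SliceTypeModel &source,
                            const std::vector<SizeOperand> &sizes) {
  std::vector<int64_t> staticSizes;
  for (const SizeOperand &size : sizes)
    staticSizes.push_back(size ? *size : kDynamicDim);
  return {staticSizes, source.elementType, source.encoding};
}

// Render in MLIR-like textual syntax, e.g. "tensor<4x1x?xf32, #BCOO>".
std::string print(const SliceTypeModel &type) {
  std::string out = "tensor<";
  for (int64_t dim : type.shape)
    out += (dim == kDynamicDim ? std::string("?") : std::to_string(dim)) + "x";
  out += type.elementType;
  if (!type.encoding.empty())
    out += ", " + type.encoding;
  return out + ">";
}

// In this model a cast back to the original result type is needed only when
// the shapes differ; because the encoding is preserved, the cast never has
// to bridge a sparse type and a dense one.
bool needsCast(const SliceTypeModel &folded, const SliceTypeModel &original) {
  return folded.shape != original.shape;
}
```

With the sizes from the test (`%c4`, `%c1`, `%arg2`), the folded type prints as `tensor<4x1x?xf32, #BCOO>`, matching the CHECK-SAME line. It still differs in shape from `tensor<?x?x?xf32, #BCOO>`, which is why a `tensor.cast` follows.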