[Mlir-commits] [mlir] [mlir][sparse] introduce sparse_tensor.lvl operation. (PR #69978)
Peiming Liu
llvmlistbot at llvm.org
Mon Oct 23 15:22:36 PDT 2023
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/69978
None
>From a2fd58dd6ed6363917b61eb9e7d38be88111ab4f Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 23 Oct 2023 22:21:22 +0000
Subject: [PATCH] [mlir][sparse] introduce sparse_tensor.lvl operation.
---
.../SparseTensor/IR/SparseTensorBase.td | 1 +
.../SparseTensor/IR/SparseTensorOps.td | 58 +++++++++++-
.../SparseTensor/IR/SparseTensorDialect.cpp | 88 +++++++++++++++++++
mlir/test/Dialect/SparseTensor/fold.mlir | 18 ++++
mlir/test/Dialect/SparseTensor/invalid.mlir | 18 ++++
mlir/test/Dialect/SparseTensor/roundtrip.mlir | 21 +++++
6 files changed, 203 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td
index cb4668c795b5d1c..74e6783e260faf9 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td
@@ -90,6 +90,7 @@ def SparseTensor_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
let useDefaultTypePrinterParser = 1;
+ let hasConstantMaterializer = 1;
}
#endif // SPARSETENSOR_BASE
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 7209f2ef8488bcc..ce71c5884a93d28 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -521,9 +521,65 @@ def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set"
}
//===----------------------------------------------------------------------===//
-// Sparse Tensor Coordinate Translation Operation.
+// Sparse Tensor Coordinate Operations.
//===----------------------------------------------------------------------===//
+def SparseTensor_LvlOp : SparseTensor_Op<"lvl", [ConditionallySpeculatable, NoMemoryEffect]>,
+    Arguments<(ins AnySparseTensor:$source, Index:$index)>,
+    Results<(outs Index:$result)> {
+  let summary = "level index operation";
+  let description = [{
+    The `sparse_tensor.lvl` operation behaves similarly to the `tensor.dim`
+    operation. It takes a sparse tensor and a level operand of type `index`
+    and returns the size of the requested level of the given sparse tensor.
+    If the level index is out of bounds, the behavior is undefined.
+
+    Example:
+
+    ```mlir
+    #BSR = #sparse_tensor.encoding<{
+      map = ( i, j ) ->
+        ( i floordiv 2 : dense,
+          j floordiv 3 : compressed,
+          i mod 2      : dense,
+          j mod 3      : dense
+        )
+    }>
+
+    // Always returns 2 (4 floordiv 2), can be constant folded:
+    %c0 = arith.constant 0 : index
+    %x = sparse_tensor.lvl %A, %c0 : tensor<4x?xf32, #BSR>
+
+    // Returns the dynamic level size of %A, computed from the dynamic
+    // dimension size by `j floordiv 3`:
+    %c1 = arith.constant 1 : index
+    %y = sparse_tensor.lvl %A, %c1 : tensor<4x?xf32, #BSR>
+
+    // Always returns 3 (since j mod 3 < 3), can be constant folded:
+    %c3 = arith.constant 3 : index
+    %z = sparse_tensor.lvl %A, %c3 : tensor<4x?xf32, #BSR>
+    ```
+  }];
+
+  let assemblyFormat = [{
+    attr-dict $source `,` $index `:` type($source)
+  }];
+
+  let builders = [
+    // Convenience builder taking the level as a plain integer; the level is
+    // materialized as an `arith.constant` of index type.
+    OpBuilder<(ins "Value":$source, "int64_t":$index), [{
+      Value indexValue =
+          $_builder.create<arith::ConstantIndexOp>($_state.location, index);
+      build($_builder, $_state, source, indexValue);
+    }]>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Helper function to get the index as a simple integer if it is constant.
+    std::optional<uint64_t> getConstantLvlIndex();
+
+    /// Interface method for ConditionallySpeculatable.
+    Speculation::Speculatability getSpeculatability();
+  }];
+
+  let hasVerifier = 1;
+  let hasFolder = 1;
+}
+
def SparseTensor_CrdTranslateOp : SparseTensor_Op<"crd_translate", [Pure]>,
Arguments<(ins Variadic<Index>:$in_crds,
SparseTensorCrdTransDirectionAttr:$direction,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 56214c2b41c387b..71eae5b078dddcc 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1208,6 +1208,84 @@ LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
return success();
}
+/// Verifies that a constant level index lies within the level rank of the
+/// source sparse tensor. Non-constant level indices are not checked here.
+LogicalResult LvlOp::verify() {
+  if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
+    auto stt = getSparseTensorType(getSource());
+    // Must propagate the failure: emitting the diagnostic alone would still
+    // let the invalid op pass verification.
+    if (*lvl >= stt.getLvlRank())
+      return emitError(
+          "Level index exceeds the rank of the input sparse tensor");
+  }
+  return success();
+}
+
+/// Returns the level operand as an unsigned integer when it is produced by a
+/// constant, and std::nullopt otherwise. A negative constant wraps around to
+/// a large unsigned value, which callers treat as out of bounds.
+std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
+  std::optional<int64_t> cst = getConstantIntValue(getIndex());
+  if (!cst)
+    return std::nullopt;
+  return static_cast<uint64_t>(*cst);
+}
+
+/// Interface method for ConditionallySpeculatable: the op is speculatable
+/// only when its level index is a constant within the *level* rank of the
+/// source tensor, since an out-of-bounds level has undefined behavior.
+Speculation::Speculatability LvlOp::getSpeculatability() {
+  std::optional<uint64_t> constantIndex = getConstantLvlIndex();
+  if (!constantIndex)
+    return Speculation::NotSpeculatable;
+
+  // Levels and dimensions differ for non-permutation maps (a 2-D BSR tensor
+  // has 4 levels), so the bound must be the level rank, not the dim rank.
+  // Stay conservative instead of asserting on an out-of-bounds constant.
+  if (*constantIndex >= getSparseTensorType(getSource()).getLvlRank())
+    return Speculation::NotSpeculatable;
+  return Speculation::Speculatable;
+}
+
+/// Folds `sparse_tensor.lvl` to a constant level size when the level index is
+/// a constant and the size can be derived statically from the dimension shape
+/// and the dimToLvl map. Returns a null OpFoldResult (no fold) otherwise.
+OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
+  auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
+  if (!lvlIndex)
+    return {};
+
+  Level lvl = lvlIndex.getValue().getZExtValue();
+  auto stt = getSparseTensorType(getSource());
+  if (lvl >= stt.getLvlRank()) {
+    // Follows the same convention used by the tensor.dim folder: out-of-bounds
+    // indices produce undefined behavior, so leave them alone rather than
+    // choking on them.
+    return {};
+  }
+
+  // Helper lambda to build an index-typed IntegerAttr.
+  auto getIndexAttr = [this](int64_t lvlSz) {
+    return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
+  };
+
+  // TODO: we can remove this after SparseTensorEncoding always returns non-null
+  // dimToLvl map.
+  ArrayRef<DynSize> shape = stt.getDimShape();
+  if (stt.isPermutation()) {
+    Dimension dim = toOrigDim(stt, lvl);
+    if (ShapedType::isDynamic(shape[dim]))
+      return {};
+    return getIndexAttr(shape[dim]);
+  }
+
+  // Non-permutation dim2lvl/lvl2dim maps. Only the forms `d`, `d floordiv c`
+  // and `x mod c` with a positive constant `c` are folded; any other level
+  // expression is left unfolded instead of tripping an assertion in cast<>.
+  AffineExpr lvlExpr = stt.getDimToLvl().getResult(lvl);
+  if (auto binExpr = lvlExpr.dyn_cast<AffineBinaryOpExpr>()) {
+    auto blockExpr = binExpr.getRHS().dyn_cast<AffineConstantExpr>();
+    if (!blockExpr || blockExpr.getValue() <= 0)
+      return {};
+    int64_t blockSz = blockExpr.getValue();
+    if (lvlExpr.getKind() == AffineExprKind::Mod) {
+      // j mod block_sz: the level size equals the block size.
+      return getIndexAttr(blockSz);
+    }
+    if (lvlExpr.getKind() == AffineExprKind::FloorDiv) {
+      // j floordiv block_sz: the level size equals dim[j] / block_sz.
+      auto dimExpr = binExpr.getLHS().dyn_cast<AffineDimExpr>();
+      if (!dimExpr)
+        return {};
+      DynSize dimSz = shape[dimExpr.getPosition()];
+      // Only fold a static size that is an exact multiple of the block size;
+      // otherwise floor division would under-count the number of blocks.
+      if (ShapedType::isDynamic(dimSz) || dimSz % blockSz != 0)
+        return {};
+      return getIndexAttr(dimSz / blockSz);
+    }
+    return {};
+  }
+
+  // Plain dimension expression: the level size is that dimension's size,
+  // which is only foldable when the dimension itself is static.
+  auto dimExpr = lvlExpr.dyn_cast<AffineDimExpr>();
+  if (!dimExpr)
+    return {};
+  DynSize dimSz = shape[dimExpr.getPosition()];
+  if (ShapedType::isDynamic(dimSz))
+    return {};
+  return getIndexAttr(dimSz);
+}
+
LogicalResult ToPositionsOp::verify() {
auto e = getSparseTensorEncoding(getTensor().getType());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
@@ -1639,6 +1717,16 @@ LogicalResult YieldOp::verify() {
// TensorDialect Methods.
//===----------------------------------------------------------------------===//
+/// Dialect constant-materialization hook: creates an `arith.constant` holding
+/// `value` with result type `type` at `loc`. Yields nullptr when
+/// `arith.constant` cannot represent the given attribute/type pair.
+Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
+                                                    Attribute value, Type type,
+                                                    Location loc) {
+  return arith::ConstantOp::materialize(builder, value, type, loc);
+}
+
void SparseTensorDialect::initialize() {
addAttributes<
#define GET_ATTRDEF_LIST
diff --git a/mlir/test/Dialect/SparseTensor/fold.mlir b/mlir/test/Dialect/SparseTensor/fold.mlir
index 3428f6d4ae5a117..cf70cdc61ba783f 100644
--- a/mlir/test/Dialect/SparseTensor/fold.mlir
+++ b/mlir/test/Dialect/SparseTensor/fold.mlir
@@ -93,3 +93,21 @@ func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index) {
%d0, %d1 = sparse_tensor.crd_translate lvl_to_dim [%l0, %l1, %l2, %l3] as #BSR : index, index
return %d0, %d1 : index, index
}
+
+// With a static leading dimension of 10, level 0 folds to the constant 5.
+// CHECK-LABEL: func.func @sparse_lvl_0(
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: return %[[C5]] : index
+func.func @sparse_lvl_0(%tensor : tensor<10x?xi32, #BSR>) -> index {
+  %c0 = arith.constant 0 : index
+  %size = sparse_tensor.lvl %tensor, %c0 : tensor<10x?xi32, #BSR>
+  return %size : index
+}
+
+// Level 3 folds to the constant block size 3 even with fully dynamic dims.
+// CHECK-LABEL: func.func @sparse_lvl_3(
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: return %[[C3]] : index
+func.func @sparse_lvl_3(%tensor : tensor<?x?xi32, #BSR>) -> index {
+  %c3 = arith.constant 3 : index
+  %size = sparse_tensor.lvl %tensor, %c3 : tensor<?x?xi32, #BSR>
+  return %size : index
+}
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 1ab1bac6b592ee7..33aa81c5a747d9b 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -895,3 +895,21 @@ func.func @sparse_crd_translate(%arg0: index, %arg1: index, %arg2: index) -> (in
%l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1, %arg2] as #BSR : index, index, index, index
return %l0, %l1, %l2, %l3 : index, index, index, index
}
+
+// -----
+
+// 2x3 block-sparse encoding: levels 0-1 index the blocks (dense block rows,
+// compressed block columns); levels 2-3 index the dense entries of a block.
+#BSR = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i floordiv 2 : dense,
+ j floordiv 3 : compressed,
+ i mod 2 : dense,
+ j mod 3 : dense
+ )
+}>
+
+// A constant level index of 5 is out of bounds for the 4 levels of #BSR.
+func.func @sparse_lvl(%t : tensor<?x?xi32, #BSR>) -> index {
+  %lvl = arith.constant 5 : index
+  // expected-error@+1 {{Level index exceeds the rank of the input sparse tensor}}
+  %l0 = sparse_tensor.lvl %t, %lvl : tensor<?x?xi32, #BSR>
+  return %l0 : index
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index af9618ebe380d80..d3332eb3bbe33d2 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -669,3 +669,24 @@ func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index, in
%l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index, index
return %l0, %l1, %l2, %l3 : index, index, index, index
}
+
+// -----
+
+// 2x3 block-sparse encoding: levels 0-1 index the blocks (dense block rows,
+// compressed block columns); levels 2-3 index the dense entries of a block.
+#BSR = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i floordiv 2 : dense,
+ j floordiv 3 : compressed,
+ i mod 2 : dense,
+ j mod 3 : dense
+ )
+}>
+
+// Round-trips a `sparse_tensor.lvl` whose level operand is not a constant.
+// CHECK-LABEL: func.func @sparse_lvl(
+// CHECK-SAME: %[[VAL_0:.*]]: index,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor
+// CHECK: %[[VAL_2:.*]] = sparse_tensor.lvl %[[VAL_1]], %[[VAL_0]]
+// CHECK: return %[[VAL_2]]
+func.func @sparse_lvl(%level: index, %tensor : tensor<?x?xi32, #BSR>) -> index {
+  %size = sparse_tensor.lvl %tensor, %level : tensor<?x?xi32, #BSR>
+  return %size : index
+}
More information about the Mlir-commits
mailing list