[Mlir-commits] [mlir] [mlir][sparse] introduce sparse_tensor.lvl operation. (PR #69978)

Peiming Liu llvmlistbot at llvm.org
Mon Oct 23 15:22:36 PDT 2023


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/69978

None

>From a2fd58dd6ed6363917b61eb9e7d38be88111ab4f Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 23 Oct 2023 22:21:22 +0000
Subject: [PATCH] [mlir][sparse] introduce sparse_tensor.lvl operation.

---
 .../SparseTensor/IR/SparseTensorBase.td       |  1 +
 .../SparseTensor/IR/SparseTensorOps.td        | 58 +++++++++++-
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 88 +++++++++++++++++++
 mlir/test/Dialect/SparseTensor/fold.mlir      | 18 ++++
 mlir/test/Dialect/SparseTensor/invalid.mlir   | 18 ++++
 mlir/test/Dialect/SparseTensor/roundtrip.mlir | 21 +++++
 6 files changed, 203 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td
index cb4668c795b5d1c..74e6783e260faf9 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td
@@ -90,6 +90,7 @@ def SparseTensor_Dialect : Dialect {
 
   let useDefaultAttributePrinterParser = 1;
   let useDefaultTypePrinterParser = 1;
+  let hasConstantMaterializer = 1;
 }
 
 #endif // SPARSETENSOR_BASE
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 7209f2ef8488bcc..ce71c5884a93d28 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -521,9 +521,65 @@ def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set"
 }
 
 //===----------------------------------------------------------------------===//
-// Sparse Tensor Coordinate Translation Operation.
+// Sparse Tensor Coordinate Operations.
 //===----------------------------------------------------------------------===//
 
+def SparseTensor_LvlOp : SparseTensor_Op<"lvl", [ConditionallySpeculatable, NoMemoryEffect]>,
+    Arguments<(ins AnySparseTensor:$source, Index:$index)>,
+    Results<(outs Index:$result)> {
+  let summary = "level index operation";
+  let description = [{
+    The `sparse_tensor.lvl` behaves similarly to the `tensor.dim` operation.
+    It takes a sparse tensor and a level operand of type `index` and returns
+    the size of the requested level of the given sparse tensor.
+    If the level index is out of bounds, the behavior is undefined.
+
+    Example:
+
+    ```mlir
+    #BSR = #sparse_tensor.encoding<{
+      map = ( i, j ) ->
+        ( i floordiv 2 : dense,
+          j floordiv 3 : compressed,
+          i mod 2      : dense,
+          j mod 3      : dense
+        )
+    }>
+
+    // Always returns 2 (4 floordiv 2), can be constant folded:
+    %c0 = arith.constant 0 : index
+    %x = sparse_tensor.lvl %A, %c0 : tensor<4x?xf32, #BSR>
+
+    // Returns the dynamic size of level 1 of %A (computed by j floordiv 3):
+    %c1 = arith.constant 1 : index
+    %y = sparse_tensor.lvl %A, %c1 : tensor<4x?xf32, #BSR>
+
+    // Always returns 3 (since j mod 3 < 3), can be constant folded:
+    %c3 = arith.constant 3 : index
+    %z = sparse_tensor.lvl %A, %c3 : tensor<4x?xf32, #BSR>
+    ```
+  }];
+
+  let assemblyFormat = [{
+    attr-dict $source `,` $index `:` type($source)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$source, "int64_t":$index)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns the level index as a plain integer if it is a constant.
+    std::optional<uint64_t> getConstantLvlIndex();
+
+    /// Interface method for ConditionallySpeculatable.
+    Speculation::Speculatability getSpeculatability();
+  }];
+
+  let hasVerifier = 1;
+  let hasFolder = 1;
+}
+
 def SparseTensor_CrdTranslateOp : SparseTensor_Op<"crd_translate", [Pure]>,
     Arguments<(ins Variadic<Index>:$in_crds,
                    SparseTensorCrdTransDirectionAttr:$direction,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 56214c2b41c387b..71eae5b078dddcc 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1208,6 +1208,84 @@ LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
   return success();
 }
 
+// Fails verification when a constant level index is not below the level-rank.
+LogicalResult LvlOp::verify() {
+  if (std::optional<uint64_t> lvl = getConstantLvlIndex())
+    if (*lvl >= getSparseTensorType(getSource()).getLvlRank())
+      return emitError(
+          "Level index exceeds the rank of the input sparse tensor");
+  return success();
+}
+
+std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
+  return getConstantIntValue(getIndex()); // Empty if index is not constant.
+}
+
+// Only a constant index is speculatable: the verifier diagnoses out-of-range
+// constants, whereas a dynamic index may be out of bounds (UB).
+Speculation::Speculatability LvlOp::getSpeculatability() {
+  std::optional<uint64_t> constantIndex = getConstantLvlIndex();
+  if (!constantIndex)
+    return Speculation::NotSpeculatable;
+  assert(*constantIndex < getSparseTensorType(getSource()).getLvlRank());
+  return Speculation::Speculatable;
+}
+
+/// Folds to a constant index when the level index is constant and the size of
+/// that level is statically known from the tensor type; otherwise returns {}.
+OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
+  auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
+  if (!lvlIndex)
+    return {};
+
+  // Like tensor.dim, an out-of-bounds index is undefined behavior but still
+  // valid IR, so do not choke on it: just decline to fold.
+  Level lvl = lvlIndex.getAPSInt().getZExtValue();
+  auto stt = getSparseTensorType(getSource());
+  if (lvl >= stt.getLvlRank())
+    return {};
+
+  // Helper lambda to build an IndexAttr.
+  auto getIndexAttr = [this](int64_t lvlSz) {
+    return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
+  };
+
+  // TODO: we can remove this after SparseTensorEncoding always returns non-null
+  // dimToLvl map.
+  ArrayRef<DynSize> shape = stt.getDimShape();
+  if (stt.isPermutation()) {
+    Dimension dim = toOrigDim(stt, lvl);
+    if (ShapedType::isDynamic(shape[dim]))
+      return {};
+    return getIndexAttr(shape[dim]);
+  }
+
+  // Non-permutation dim2lvl/lvl2dim maps.
+  AffineExpr lvlExpr = stt.getDimToLvl().getResult(lvl);
+  if (auto binExpr = lvlExpr.dyn_cast<AffineBinaryOpExpr>()) {
+    if (lvlExpr.getKind() == AffineExprKind::Mod) {
+      // j % block_sz, the level size equals to the block size.
+      int64_t lvlSz = binExpr.getRHS().cast<AffineConstantExpr>().getValue();
+      return getIndexAttr(lvlSz);
+    }
+    if (lvlExpr.getKind() == AffineExprKind::FloorDiv) {
+      // j / block_sz, the level size equals to dim[j] / block_sz.
+      Dimension dim = binExpr.getLHS().cast<AffineDimExpr>().getPosition();
+      int64_t blockSz = binExpr.getRHS().cast<AffineConstantExpr>().getValue();
+      if (ShapedType::isDynamic(shape[dim]))
+        return {};
+      return getIndexAttr(shape[dim] / blockSz);
+    }
+  }
+
+  // Only a plain dimension expression remains foldable. Note that the dynamic
+  // check must inspect the dimension size, not the dimension position.
+  auto dimExpr = lvlExpr.dyn_cast<AffineDimExpr>();
+  if (!dimExpr || ShapedType::isDynamic(shape[dimExpr.getPosition()]))
+    return {};
+  return getIndexAttr(shape[dimExpr.getPosition()]);
+}
+
 LogicalResult ToPositionsOp::verify() {
   auto e = getSparseTensorEncoding(getTensor().getType());
   if (failed(lvlIsInBounds(getLevel(), getTensor())))
@@ -1639,6 +1717,16 @@ LogicalResult YieldOp::verify() {
 // TensorDialect Methods.
 //===----------------------------------------------------------------------===//
 
+/// Materializes the constant `value` of the given `type` through the arith
+/// dialect; returns nullptr when arith::ConstantOp::materialize yields no op.
+Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
+                                                    Attribute value, Type type,
+                                                    Location loc) {
+  arith::ConstantOp constantOp =
+      arith::ConstantOp::materialize(builder, value, type, loc);
+  return constantOp ? constantOp.getOperation() : nullptr;
+}
+
 void SparseTensorDialect::initialize() {
   addAttributes<
 #define GET_ATTRDEF_LIST
diff --git a/mlir/test/Dialect/SparseTensor/fold.mlir b/mlir/test/Dialect/SparseTensor/fold.mlir
index 3428f6d4ae5a117..cf70cdc61ba783f 100644
--- a/mlir/test/Dialect/SparseTensor/fold.mlir
+++ b/mlir/test/Dialect/SparseTensor/fold.mlir
@@ -93,3 +93,21 @@ func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index) {
   %d0, %d1 = sparse_tensor.crd_translate lvl_to_dim [%l0, %l1, %l2, %l3] as #BSR : index, index
   return  %d0, %d1 : index, index
 }
+
+// CHECK-LABEL:   func.func @sparse_lvl_0(
+// CHECK:           %[[C5:.*]] = arith.constant 5 : index
+// CHECK:           return %[[C5]] : index
+func.func @sparse_lvl_0(%t : tensor<10x?xi32, #BSR>) -> index {
+  %lvl = arith.constant 0 : index // level 0: expected to fold to constant 5
+  %l0 = sparse_tensor.lvl %t, %lvl : tensor<10x?xi32, #BSR>
+  return  %l0 : index
+}
+
+// CHECK-LABEL:   func.func @sparse_lvl_3(
+// CHECK:           %[[C3:.*]] = arith.constant 3 : index
+// CHECK:           return %[[C3]] : index
+func.func @sparse_lvl_3(%t : tensor<?x?xi32, #BSR>) -> index {
+  %lvl = arith.constant 3 : index // level 3: folds to 3 despite dynamic dims
+  %l0 = sparse_tensor.lvl %t, %lvl : tensor<?x?xi32, #BSR>
+  return  %l0 : index
+}
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 1ab1bac6b592ee7..33aa81c5a747d9b 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -895,3 +895,21 @@ func.func @sparse_crd_translate(%arg0: index, %arg1: index, %arg2: index) -> (in
   %l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1, %arg2] as #BSR : index, index, index, index
   return  %l0, %l1, %l2, %l3 : index, index, index, index
 }
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i floordiv 2 : dense,
+    j floordiv 3 : compressed,
+    i mod 2      : dense,
+    j mod 3      : dense
+  )
+}>
+// #BSR has four levels (0..3), so the constant level index 5 must be rejected.
+func.func @sparse_lvl(%t : tensor<?x?xi32, #BSR>) -> index {
+  %lvl = arith.constant 5 : index
+  // expected-error@+1 {{Level index exceeds the rank of the input sparse tensor}}
+  %l0 = sparse_tensor.lvl %t, %lvl : tensor<?x?xi32, #BSR>
+  return  %l0 : index
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index af9618ebe380d80..d3332eb3bbe33d2 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -669,3 +669,24 @@ func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index, in
   %l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index, index
   return  %l0, %l1, %l2, %l3 : index, index, index, index
 }
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i floordiv 2 : dense,
+    j floordiv 3 : compressed,
+    i mod 2      : dense,
+    j mod 3      : dense
+  )
+}>
+// Round-trip of sparse_tensor.lvl with a non-constant level index operand.
+// CHECK-LABEL:   func.func @sparse_lvl(
+// CHECK-SAME:      %[[VAL_0:.*]]: index,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor
+// CHECK:           %[[VAL_2:.*]] = sparse_tensor.lvl %[[VAL_1]], %[[VAL_0]]
+// CHECK:           return %[[VAL_2]]
+func.func @sparse_lvl(%arg0: index, %t : tensor<?x?xi32, #BSR>) -> index {
+  %l0 = sparse_tensor.lvl %t, %arg0 : tensor<?x?xi32, #BSR>
+  return  %l0 : index
+}



More information about the Mlir-commits mailing list