[Mlir-commits] [mlir] [mlir][sparse] introduce sparse_tensor.reinterpret_map operation. (PR #70378)

Peiming Liu llvmlistbot at llvm.org
Thu Oct 26 13:43:38 PDT 2023


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/70378

None

>From 4186744f14280e4de43714a2e2707aec54fb40f2 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 26 Oct 2023 20:43:02 +0000
Subject: [PATCH] [mlir][sparse] introduce sparse_tensor.reinterpret_map
 operation.

---
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  3 +-
 .../SparseTensor/IR/SparseTensorOps.td        | 45 +++++++++
 .../SparseTensor/IR/SparseTensorType.h        |  7 ++
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 97 +++++++++++++++++++
 mlir/test/Dialect/SparseTensor/fold.mlir      | 15 +++
 mlir/test/Dialect/SparseTensor/invalid.mlir   | 63 ++++++++++++
 mlir/test/Dialect/SparseTensor/roundtrip.mlir | 20 ++++
 7 files changed, 249 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 2dd7f8e961929cf..48ff37b3828d3c6 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -430,8 +430,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     std::optional<uint64_t> getStaticLvlSliceStride(::mlir::sparse_tensor::Level lvl) const;
 
     //
-    // Helper function to build IR related to the encoding.
+    // Helper functions to translate between level/dimension space.
     //
+    SmallVector<int64_t> tranlateShape(::mlir::ArrayRef<int64_t> srcShape, ::mlir::sparse_tensor::CrdTransDirectionKind) const;
     ValueRange translateCrds(::mlir::OpBuilder &builder, ::mlir::Location loc, ::mlir::ValueRange crds, ::mlir::sparse_tensor::CrdTransDirectionKind) const;
 
     //
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 8c33e8651b1694e..e410a801796276f 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -208,6 +208,51 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
   let hasVerifier = 1;
 }
 
+def SparseTensor_ReinterpretMapOp : SparseTensor_Op<"reinterpret_map", [NoMemoryEffect]>,
+    Arguments<(ins AnySparseTensor:$source)>,
+    Results<(outs AnySparseTensor:$dest)> {
+  let summary = "Reinterprets the dimension/level maps of the source tensor";
+  let description = [{
+    Reinterprets the dimension-to-level and level-to-dimension map specified in
+    `source` according to the type of `dest`.
+    `reinterpret_map` is a no-op and is introduced merely to resolve type conflicts.
+    It does not make any modification to the source tensor and source/dest tensors
+    are considered to be aliases.
+
+    `source` and `dest` tensors are "reinterpretable" if and only if they have
+    exactly the same storage at a low level.
+    That is, both `source` and `dest` have the same number of levels and level types,
+    and their shape is consistent before and after `reinterpret_map`.
+
+    Example:
+    ```mlir
+    #CSC = #sparse_tensor.encoding<{
+      map = (d0, d1) -> (d1: dense, d0: compressed)
+    }>
+    #CSR = #sparse_tensor.encoding<{
+      map = (d0, d1) -> (d0: dense, d1: compressed)
+    }>
+    %t1 = sparse_tensor.reinterpret_map %t0 : tensor<3x4xi32, #CSC> to tensor<4x3xi32, #CSR>
+
+    #BSR = #sparse_tensor.encoding<{
+      map = ( i, j ) -> ( i floordiv 2 : dense,
+                          j floordiv 3 : compressed,
+                          i mod 2      : dense,
+                          j mod 3      : dense
+      )
+    }>
+    #DSDD = #sparse_tensor.encoding<{
+      map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+    }>
+    %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR> to tensor<3x4x2x3xi32, #DSDD>
+    ```
+    }];
+
+  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
 def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions", [Pure]>,
     Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index c3e967fdcd90fc0..7a1f1e2144e049d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -245,6 +245,12 @@ class SparseTensorType {
   /// Returns the dimension-shape.
   ArrayRef<DynSize> getDimShape() const { return rtp.getShape(); }
 
+  /// Returns the Level-shape.
+  SmallVector<DynSize> getLvlShape() const {
+    return getEncoding().tranlateShape(getDimShape(),
+                                       CrdTransDirectionKind::dim2lvl);
+  }
+
   /// Safely looks up the requested dimension-DynSize.  If you intend
   /// to check the result with `ShapedType::isDynamic`, then see the
   /// `getStaticDimSize` method instead.
@@ -281,6 +287,7 @@ class SparseTensorType {
   /// `ShapedType::Trait<T>::getNumDynamicDims`.
   int64_t getNumDynamicDims() const { return rtp.getNumDynamicDims(); }
 
+  ArrayRef<DimLevelType> getLvlTypes() const { return enc.getLvlTypes(); }
   DimLevelType getLvlType(Level l) const {
     // This OOB check is for dense-tensors, since this class knows
     // their lvlRank (whereas STEA::getLvlType will/can only check
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 17e6ef53fe596e0..d392f1c52d61a1d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -415,6 +415,56 @@ SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
   return getStaticDimSliceStride(toOrigDim(*this, lvl));
 }
 
+SmallVector<int64_t>
+SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
+                                        CrdTransDirectionKind dir) const {
+  if (isIdentity()) {
+    return SmallVector<int64_t>(srcShape);
+  }
+
+  SmallVector<int64_t> ret;
+  unsigned rank =
+      dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
+  ret.reserve(rank);
+
+  if (isPermutation()) {
+    for (unsigned r = 0; r < rank; r++) {
+      unsigned trans = dir == CrdTransDirectionKind::dim2lvl
+                           ? toOrigDim(*this, r)
+                           : toStoredDim(*this, r);
+      ret.push_back(srcShape[trans]);
+    }
+    return ret;
+  }
+
+  // Non-permutation
+  AffineMap transMap =
+      dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
+
+  SmallVector<AffineExpr> dimRep;
+  dimRep.reserve(srcShape.size());
+  for (int64_t sz : srcShape) {
+    if (!ShapedType::isDynamic(sz)) {
+      // Push back the max coordinate for the given dimension/level size
+      dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
+    } else {
+      // A dynamic size; use an AffineDimExpr to symbolize the value.
+      dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
+    }
+  };
+
+  for (AffineExpr exp : transMap.getResults()) {
+    // Do constant propagation on the affine map.
+    AffineExpr evalExp =
+        simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
+    if (auto c = evalExp.dyn_cast<AffineConstantExpr>())
+      ret.push_back(c.getValue() + 1);
+    else
+      ret.push_back(ShapedType::kDynamic);
+  }
+  return ret;
+}
+
 ValueRange
 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
                                         ValueRange crds,
@@ -1292,6 +1342,53 @@ OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+LogicalResult ReinterpretMapOp::verify() {
+  auto srcStt = getSparseTensorType(getSource());
+  auto dstStt = getSparseTensorType(getDest());
+  ArrayRef<DimLevelType> srcLvlTps = srcStt.getLvlTypes();
+  ArrayRef<DimLevelType> dstLvlTps = dstStt.getLvlTypes();
+
+  if (srcLvlTps.size() != dstLvlTps.size())
+    return emitError("Level rank mismatch between source/dest tensors");
+
+  for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
+    if (srcLvlTp != dstLvlTp)
+      return emitError("Level type mismatch between source/dest tensors");
+
+  if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
+      srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
+    return emitError("Crd/Pos width mismatch between source/dest tensors");
+  }
+
+  if (srcStt.getElementType() != dstStt.getElementType())
+    return emitError("Element type mismatch between source/dest tensors");
+
+  SmallVector<DynSize> srcLvlShape = srcStt.getLvlShape();
+  SmallVector<DynSize> dstLvlShape = dstStt.getLvlShape();
+  for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
+    if (srcLvlSz != dstLvlSz) {
+      // Should we allow one side to have a dynamic size, e.g., should <?x?> be
+      // compatible with <3x4>? For now, we require all the level sizes to
+      // match *exactly* for simplicity.
+      return emitError("Level size mismatch between source/dest tensors");
+    }
+  }
+
+  return success();
+}
+
+OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
+  if (getSource().getType() == getDest().getType())
+    return getSource();
+
+  if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
+    // A -> B, B -> A ==> A
+    if (def.getSource().getType() == getDest().getType())
+      return def.getSource();
+  }
+  return {};
+}
+
 LogicalResult ToPositionsOp::verify() {
   auto e = getSparseTensorEncoding(getTensor().getType());
   if (failed(lvlIsInBounds(getLevel(), getTensor())))
diff --git a/mlir/test/Dialect/SparseTensor/fold.mlir b/mlir/test/Dialect/SparseTensor/fold.mlir
index cf70cdc61ba783f..73851c086aa2425 100644
--- a/mlir/test/Dialect/SparseTensor/fold.mlir
+++ b/mlir/test/Dialect/SparseTensor/fold.mlir
@@ -111,3 +111,18 @@ func.func @sparse_lvl_3(%t : tensor<?x?xi32, #BSR>) -> index {
   %l0 = sparse_tensor.lvl %t, %lvl : tensor<?x?xi32, #BSR>
   return  %l0 : index
 }
+
+#DSDD = #sparse_tensor.encoding<{
+  map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+}>
+
+
+// CHECK-LABEL:   func.func @sparse_reinterpret_map(
+// CHECK-NOT: sparse_tensor.reinterpret_map
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<6x12xi32, #BSR> {
+  %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+                                         to tensor<3x4x2x3xi32, #DSDD>
+  %t2 = sparse_tensor.reinterpret_map %t1 : tensor<3x4x2x3xi32, #DSDD>
+                                         to tensor<6x12xi32, #BSR>
+  return %t2 : tensor<6x12xi32, #BSR>
+}
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 33aa81c5a747d9b..7f9bb4fea7bd459 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -913,3 +913,66 @@ func.func @sparse_lvl(%t : tensor<?x?xi32, #BSR>) -> index {
   %l0 = sparse_tensor.lvl %t, %lvl : tensor<?x?xi32, #BSR>
   return  %l0 : index
 }
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) -> ( i floordiv 2 : dense,
+                      j floordiv 3 : compressed,
+                      i mod 2      : dense,
+                      j mod 3      : dense
+  )
+}>
+
+#DSDC = #sparse_tensor.encoding<{
+  map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: compressed)
+}>
+
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x3xf32, #DSDC> {
+  // expected-error@+1 {{Level type mismatch between source/dest tensors}}
+  %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+                                         to tensor<3x4x2x3xf32, #DSDC>
+  return %t1 : tensor<3x4x2x3xf32, #DSDC>
+}
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) -> ( i floordiv 2 : dense,
+                      j floordiv 3 : compressed,
+                      i mod 2      : dense,
+                      j mod 3      : dense
+  )
+}>
+
+#DSDD = #sparse_tensor.encoding<{
+  map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+}>
+
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x3xf32, #DSDD> {
+  // expected-error@+1 {{Element type mismatch between source/dest tensors}}
+  %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+                                         to tensor<3x4x2x3xf32, #DSDD>
+  return %t1 : tensor<3x4x2x3xf32, #DSDD>
+}
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) -> ( i floordiv 2 : dense,
+                      j floordiv 3 : compressed,
+                      i mod 2      : dense,
+                      j mod 3      : dense
+  )
+}>
+
+#DSDD = #sparse_tensor.encoding<{
+  map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+}>
+
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x4xi32, #DSDD> {
+  // expected-error@+1 {{Level size mismatch between source/dest tensors}}
+  %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+                                         to tensor<3x4x2x4xi32, #DSDD>
+  return %t1 : tensor<3x4x2x4xi32, #DSDD>
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index d3332eb3bbe33d2..17ae8c065945a1b 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -690,3 +690,23 @@ func.func @sparse_lvl(%arg0: index, %t : tensor<?x?xi32, #BSR>) -> index {
   %l0 = sparse_tensor.lvl %t, %arg0 : tensor<?x?xi32, #BSR>
   return  %l0 : index
 }
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) -> ( i floordiv 2 : dense,
+                      j floordiv 3 : compressed,
+                      i mod 2      : dense,
+                      j mod 3      : dense
+  )
+}>
+
+#DSDD = #sparse_tensor.encoding<{
+  map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+}>
+
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x3xi32, #DSDD> {
+  %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+                                         to tensor<3x4x2x3xi32, #DSDD>
+  return %t1 : tensor<3x4x2x3xi32, #DSDD>
+}



More information about the Mlir-commits mailing list