[Mlir-commits] [mlir] d808d92 - [mlir][sparse] introduce sparse_tensor.reinterpret_map operation. (#70378)

llvmlistbot at llvm.org
Thu Oct 26 15:04:13 PDT 2023


Author: Peiming Liu
Date: 2023-10-26T15:04:09-07:00
New Revision: d808d922b42b82ecb21a45e446d612e334d96488

URL: https://github.com/llvm/llvm-project/commit/d808d922b42b82ecb21a45e446d612e334d96488
DIFF: https://github.com/llvm/llvm-project/commit/d808d922b42b82ecb21a45e446d612e334d96488.diff

LOG: [mlir][sparse] introduce sparse_tensor.reinterpret_map operation. (#70378)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/test/Dialect/SparseTensor/fold.mlir
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 2dd7f8e961929cf..6c3cfc2c49b2e91 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -430,8 +430,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     std::optional<uint64_t> getStaticLvlSliceStride(::mlir::sparse_tensor::Level lvl) const;
 
     //
-    // Helper function to build IR related to the encoding.
+    // Helper functions to translate between the level and dimension spaces.
     //
+    SmallVector<int64_t> tranlateShape(::mlir::ArrayRef<int64_t> srcShape, ::mlir::sparse_tensor::CrdTransDirectionKind) const;
     ValueRange translateCrds(::mlir::OpBuilder &builder, ::mlir::Location loc, ::mlir::ValueRange crds, ::mlir::sparse_tensor::CrdTransDirectionKind) const;
 
     //

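For readers unfamiliar with the dim-to-lvl translation performed by the new shape-translation helper declared above (`tranlateShape`): for static sizes it constant-folds the maximum coordinate of each dimension through the dim-to-lvl map and adds one to obtain each level size. A worked illustration (not part of the patch), using the block-sparse encoding that also appears in the tests below:

```mlir
// Illustration only: 2x3 block-sparse (BSR) encoding, as in the tests below.
#BSR = #sparse_tensor.encoding<{
  map = ( i, j ) -> ( i floordiv 2 : dense,
                      j floordiv 3 : compressed,
                      i mod 2      : dense,
                      j mod 3      : dense )
}>
// For a tensor<6x12xf64, #BSR>, the dimension shape 6x12 translates to the
// level shape 3x4x2x3:
//   level 0: max(i) floordiv 2 = 5 floordiv 2 = 2, so the level size is 2 + 1 = 3
//   level 1: max(j) floordiv 3 = 11 floordiv 3 = 3, so the level size is 3 + 1 = 4
//   level 2: max(i) mod 2 = 5 mod 2 = 1,            so the level size is 1 + 1 = 2
//   level 3: max(j) mod 3 = 11 mod 3 = 2,           so the level size is 2 + 1 = 3
```

The lvl2dim direction works the same way through the inverse map; the new `reinterpret_map` builder later in this patch uses it to derive the destination dimension shape.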
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 50f5e7335dc923b..c5cb0ac155d6828 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -208,6 +208,55 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
   let hasVerifier = 1;
 }
 
+def SparseTensor_ReinterpretMapOp : SparseTensor_Op<"reinterpret_map", [NoMemoryEffect]>,
+    Arguments<(ins AnySparseTensor:$source)>,
+    Results<(outs AnySparseTensor:$dest)> {
+  let summary = "Reinterprets the dimension/level maps of the source tensor";
+  let description = [{
+    Reinterprets the dimension-to-level and level-to-dimension maps specified in
+    `source` according to the type of `dest`.
+    `reinterpret_map` is a no-op and is introduced merely to resolve type conflicts.
+    It does not modify the source tensor, and the source/dest tensors
+    are considered to be aliases.
+
+    The `source` and `dest` tensors are "reinterpretable" if and only if they have
+    exactly the same storage at a low level.
+    That is, both `source` and `dest` have the same number of levels and level types,
+    and their shapes are consistent before and after `reinterpret_map`.
+
+    Example:
+    ```mlir
+    #CSC = #sparse_tensor.encoding<{
+      map = (d0, d1) -> (d1: dense, d0: compressed)
+    }>
+    #CSR = #sparse_tensor.encoding<{
+      map = (d0, d1) -> (d0: dense, d1: compressed)
+    }>
+    %t1 = sparse_tensor.reinterpret_map %t0 : tensor<3x4xi32, #CSC> to tensor<4x3xi32, #CSR>
+
+    #BSR = #sparse_tensor.encoding<{
+      map = ( i, j ) -> ( i floordiv 2 : dense,
+                          j floordiv 3 : compressed,
+                          i mod 2      : dense,
+                          j mod 3      : dense
+      )
+    }>
+    #DSDD = #sparse_tensor.encoding<{
+      map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+    }>
+    %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR> to tensor<3x4x2x3xi32, #DSDD>
+    ```
+    }];
+
+  let builders = [
+    OpBuilder<(ins "SparseTensorEncodingAttr":$dstEnc, "Value":$source)>
+  ];
+
+  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
 def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions", [Pure]>,
     Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$result)> {

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index c3e967fdcd90fc0..7a1f1e2144e049d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -245,6 +245,12 @@ class SparseTensorType {
   /// Returns the dimension-shape.
   ArrayRef<DynSize> getDimShape() const { return rtp.getShape(); }
 
+  /// Returns the Level-shape.
+  SmallVector<DynSize> getLvlShape() const {
+    return getEncoding().tranlateShape(getDimShape(),
+                                       CrdTransDirectionKind::dim2lvl);
+  }
+
   /// Safely looks up the requested dimension-DynSize.  If you intend
   /// to check the result with `ShapedType::isDynamic`, then see the
   /// `getStaticDimSize` method instead.
@@ -281,6 +287,7 @@ class SparseTensorType {
   /// `ShapedType::Trait<T>::getNumDynamicDims`.
   int64_t getNumDynamicDims() const { return rtp.getNumDynamicDims(); }
 
+  ArrayRef<DimLevelType> getLvlTypes() const { return enc.getLvlTypes(); }
   DimLevelType getLvlType(Level l) const {
     // This OOB check is for dense-tensors, since this class knows
     // their lvlRank (whereas STEA::getLvlType will/can only check

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 359c0a696858329..be44a0d31c92a7d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -412,6 +412,55 @@ SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
   return getStaticDimSliceStride(toOrigDim(*this, lvl));
 }
 
+SmallVector<int64_t>
+SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
+                                        CrdTransDirectionKind dir) const {
+  if (isIdentity())
+    return SmallVector<int64_t>(srcShape);
+
+  SmallVector<int64_t> ret;
+  unsigned rank =
+      dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
+  ret.reserve(rank);
+
+  if (isPermutation()) {
+    for (unsigned r = 0; r < rank; r++) {
+      unsigned trans = dir == CrdTransDirectionKind::dim2lvl
+                           ? toOrigDim(*this, r)
+                           : toStoredDim(*this, r);
+      ret.push_back(srcShape[trans]);
+    }
+    return ret;
+  }
+
+  // Handle non-permutation maps.
+  AffineMap transMap =
+      dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
+
+  SmallVector<AffineExpr> dimRep;
+  dimRep.reserve(srcShape.size());
+  for (int64_t sz : srcShape) {
+    if (!ShapedType::isDynamic(sz)) {
+      // Push back the max coordinate for the given dimension/level size.
+      dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
+    } else {
+      // A dynamic size, use an AffineDimExpr to symbolize the value.
+      dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
+    }
+  };
+
+  for (AffineExpr exp : transMap.getResults()) {
+    // Do constant propagation on the affine map.
+    AffineExpr evalExp =
+        simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
+    if (auto c = evalExp.dyn_cast<AffineConstantExpr>())
+      ret.push_back(c.getValue() + 1);
+    else
+      ret.push_back(ShapedType::kDynamic);
+  }
+  return ret;
+}
+
 ValueRange
 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
                                         ValueRange crds,
@@ -1286,6 +1335,64 @@ OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                             SparseTensorEncodingAttr dstEnc, Value source) {
+  auto srcStt = getSparseTensorType(source);
+  SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
+  SmallVector<int64_t> dstDimShape =
+      dstEnc.tranlateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
+  auto dstTp =
+      RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
+  return build(odsBuilder, odsState, dstTp, source);
+}
+
+LogicalResult ReinterpretMapOp::verify() {
+  auto srcStt = getSparseTensorType(getSource());
+  auto dstStt = getSparseTensorType(getDest());
+  ArrayRef<DimLevelType> srcLvlTps = srcStt.getLvlTypes();
+  ArrayRef<DimLevelType> dstLvlTps = dstStt.getLvlTypes();
+
+  if (srcLvlTps.size() != dstLvlTps.size())
+    return emitError("Level rank mismatch between source/dest tensors");
+
+  for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
+    if (srcLvlTp != dstLvlTp)
+      return emitError("Level type mismatch between source/dest tensors");
+
+  if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
+      srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
+    return emitError("Crd/Pos width mismatch between source/dest tensors");
+  }
+
+  if (srcStt.getElementType() != dstStt.getElementType())
+    return emitError("Element type mismatch between source/dest tensors");
+
+  SmallVector<DynSize> srcLvlShape = srcStt.getLvlShape();
+  SmallVector<DynSize> dstLvlShape = dstStt.getLvlShape();
+  for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
+    if (srcLvlSz != dstLvlSz) {
+      // Should we allow one side to have dynamic sizes, e.g., should <?x?> be
+      // compatible with <3x4>? For now, we require all the level sizes to
+      // match *exactly* for simplicity.
+      return emitError("Level size mismatch between source/dest tensors");
+    }
+  }
+
+  return success();
+}
+
+OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
+  if (getSource().getType() == getDest().getType())
+    return getSource();
+
+  if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
+    // A -> B, B -> A ==> A
+    if (def.getSource().getType() == getDest().getType())
+      return def.getSource();
+  }
+  return {};
+}
+
 LogicalResult ToPositionsOp::verify() {
   auto e = getSparseTensorEncoding(getTensor().getType());
   if (failed(lvlIsInBounds(getLevel(), getTensor())))

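One subtlety of the shape translation implemented above: a dynamic dimension size is symbolized with an AffineDimExpr, so, as the code reads, any level whose expression does not constant-fold is reported as dynamic, even when (as with `i mod 2`) the size is in principle fixed. A small sketch, illustrative only and reusing the same block-sparse map:

```mlir
// Illustration only: dim-to-lvl translation with one dynamic dimension.
#BSR = #sparse_tensor.encoding<{
  map = ( i, j ) -> ( i floordiv 2 : dense,
                      j floordiv 3 : compressed,
                      i mod 2      : dense,
                      j mod 3      : dense )
}>
// tensor<?x12xf64, #BSR>: dimension shape ?x12 translates to level shape ?x4x?x3.
//   i floordiv 2 and i mod 2 involve the dynamic dimension and stay dynamic;
//   j floordiv 3 folds to 11 floordiv 3 = 3 (size 4), j mod 3 folds to 2 (size 3).
```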
diff --git a/mlir/test/Dialect/SparseTensor/fold.mlir b/mlir/test/Dialect/SparseTensor/fold.mlir
index cf70cdc61ba783f..73851c086aa2425 100644
--- a/mlir/test/Dialect/SparseTensor/fold.mlir
+++ b/mlir/test/Dialect/SparseTensor/fold.mlir
@@ -111,3 +111,18 @@ func.func @sparse_lvl_3(%t : tensor<?x?xi32, #BSR>) -> index {
   %l0 = sparse_tensor.lvl %t, %lvl : tensor<?x?xi32, #BSR>
   return  %l0 : index
 }
+
+#DSDD = #sparse_tensor.encoding<{
+  map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+}>
+
+
+// CHECK-LABEL:   func.func @sparse_reinterpret_map(
+// CHECK-NOT: sparse_tensor.reinterpret_map
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<6x12xi32, #BSR> {
+  %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+                                         to tensor<3x4x2x3xi32, #DSDD>
+  %t2 = sparse_tensor.reinterpret_map %t1 : tensor<3x4x2x3xi32, #DSDD>
+                                         to tensor<6x12xi32, #BSR>
+  return %t2 : tensor<6x12xi32, #BSR>
+}

diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 0217ef152be6a0d..dd2298d4c3c143c 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -964,3 +964,66 @@ func.func @sparse_lvl(%t : tensor<?x?xi32, #BSR>) -> index {
   %l0 = sparse_tensor.lvl %t, %lvl : tensor<?x?xi32, #BSR>
   return  %l0 : index
 }
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) -> ( i floordiv 2 : dense,
+                      j floordiv 3 : compressed,
+                      i mod 2      : dense,
+                      j mod 3      : dense
+  )
+}>
+
+#DSDC = #sparse_tensor.encoding<{
+  map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: compressed)
+}>
+
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x3xf32, #DSDC> {
+  // expected-error@+1 {{Level type mismatch between source/dest tensors}}
+  %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+                                         to tensor<3x4x2x3xf32, #DSDC>
+  return %t1 : tensor<3x4x2x3xf32, #DSDC>
+}
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) -> ( i floordiv 2 : dense,
+                      j floordiv 3 : compressed,
+                      i mod 2      : dense,
+                      j mod 3      : dense
+  )
+}>
+
+#DSDD = #sparse_tensor.encoding<{
+  map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+}>
+
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x3xf32, #DSDD> {
+  // expected-error@+1 {{Element type mismatch between source/dest tensors}}
+  %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+                                         to tensor<3x4x2x3xf32, #DSDD>
+  return %t1 : tensor<3x4x2x3xf32, #DSDD>
+}
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) -> ( i floordiv 2 : dense,
+                      j floordiv 3 : compressed,
+                      i mod 2      : dense,
+                      j mod 3      : dense
+  )
+}>
+
+#DSDD = #sparse_tensor.encoding<{
+  map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+}>
+
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x4xi32, #DSDD> {
+  // expected-error@+1 {{Level size mismatch between source/dest tensors}}
+  %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+                                         to tensor<3x4x2x4xi32, #DSDD>
+  return %t1 : tensor<3x4x2x4xi32, #DSDD>
+}

diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index d3332eb3bbe33d2..17ae8c065945a1b 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -690,3 +690,23 @@ func.func @sparse_lvl(%arg0: index, %t : tensor<?x?xi32, #BSR>) -> index {
   %l0 = sparse_tensor.lvl %t, %arg0 : tensor<?x?xi32, #BSR>
   return  %l0 : index
 }
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) -> ( i floordiv 2 : dense,
+                      j floordiv 3 : compressed,
+                      i mod 2      : dense,
+                      j mod 3      : dense
+  )
+}>
+
+#DSDD = #sparse_tensor.encoding<{
+  map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+}>
+
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x3xi32, #DSDD> {
+  %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+                                         to tensor<3x4x2x3xi32, #DSDD>
+  return %t1 : tensor<3x4x2x3xi32, #DSDD>
+}

