[Mlir-commits] [mlir] [mlir][sparse] introduce sparse_tensor.reinterpret_map operation. (PR #70378)
Peiming Liu
llvmlistbot at llvm.org
Thu Oct 26 13:43:38 PDT 2023
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/70378
None
From 4186744f14280e4de43714a2e2707aec54fb40f2 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 26 Oct 2023 20:43:02 +0000
Subject: [PATCH] [mlir][sparse] introduce sparse_tensor.reinterpret_map
operation.
---
.../SparseTensor/IR/SparseTensorAttrDefs.td | 3 +-
.../SparseTensor/IR/SparseTensorOps.td | 45 +++++++++
.../SparseTensor/IR/SparseTensorType.h | 7 ++
.../SparseTensor/IR/SparseTensorDialect.cpp | 97 +++++++++++++++++++
mlir/test/Dialect/SparseTensor/fold.mlir | 15 +++
mlir/test/Dialect/SparseTensor/invalid.mlir | 63 ++++++++++++
mlir/test/Dialect/SparseTensor/roundtrip.mlir | 20 ++++
7 files changed, 249 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 2dd7f8e961929cf..48ff37b3828d3c6 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -430,8 +430,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
std::optional<uint64_t> getStaticLvlSliceStride(::mlir::sparse_tensor::Level lvl) const;
//
- // Helper function to build IR related to the encoding.
+ // Helper functions to translate between level/dimension space.
//
+ SmallVector<int64_t> translateShape(::mlir::ArrayRef<int64_t> srcShape, ::mlir::sparse_tensor::CrdTransDirectionKind) const;
ValueRange translateCrds(::mlir::OpBuilder &builder, ::mlir::Location loc, ::mlir::ValueRange crds, ::mlir::sparse_tensor::CrdTransDirectionKind) const;
//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 8c33e8651b1694e..e410a801796276f 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -208,6 +208,51 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
let hasVerifier = 1;
}
+def SparseTensor_ReinterpretMapOp : SparseTensor_Op<"reinterpret_map", [NoMemoryEffect]>,
+ Arguments<(ins AnySparseTensor:$source)>,
+ Results<(outs AnySparseTensor:$dest)> {
+ let summary = "Reinterprets the dimension/level maps of the source tensor";
+ let description = [{
+ Reinterprets the dimension-to-level and level-to-dimension maps of the
+ `source` tensor according to the type of `dest`.
+ `reinterpret_map` is a no-op and is introduced merely to resolve type conflicts.
+ It does not modify the source tensor; the `source` and `dest` tensors are
+ considered to be aliases.
+
+ The `source` and `dest` tensors are "reinterpretable" if and only if they
+ have exactly the same low-level storage.
+ That is, both `source` and `dest` have the same number of levels and the same
+ level types, and their level shapes are consistent before and after `reinterpret_map`.
+
+ Example:
+ ```mlir
+ #CSC = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d1: dense, d0: compressed)
+ }>
+ #CSR = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0: dense, d1: compressed)
+ }>
+ %t1 = sparse_tensor.reinterpret_map %t0 : tensor<3x4xi32, #CSC> to tensor<4x3xi32, #CSR>
+
+ #BSR = #sparse_tensor.encoding<{
+ map = ( i, j ) -> ( i floordiv 2 : dense,
+ j floordiv 3 : compressed,
+ i mod 2 : dense,
+ j mod 3 : dense
+ )
+ }>
+ #DSDD = #sparse_tensor.encoding<{
+ map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+ }>
+ %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR> to tensor<3x4x2x3xi32, #DSDD>
+ ```
+ }];
+
+ let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+ let hasFolder = 1;
+ let hasVerifier = 1;
+}
+
def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions", [Pure]>,
Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index c3e967fdcd90fc0..7a1f1e2144e049d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -245,6 +245,12 @@ class SparseTensorType {
/// Returns the dimension-shape.
ArrayRef<DynSize> getDimShape() const { return rtp.getShape(); }
+ /// Returns the level-shape.
+ SmallVector<DynSize> getLvlShape() const {
+ return getEncoding().translateShape(getDimShape(),
+ CrdTransDirectionKind::dim2lvl);
+ }
+
/// Safely looks up the requested dimension-DynSize. If you intend
/// to check the result with `ShapedType::isDynamic`, then see the
/// `getStaticDimSize` method instead.
@@ -281,6 +287,7 @@ class SparseTensorType {
/// `ShapedType::Trait<T>::getNumDynamicDims`.
int64_t getNumDynamicDims() const { return rtp.getNumDynamicDims(); }
+ ArrayRef<DimLevelType> getLvlTypes() const { return enc.getLvlTypes(); }
DimLevelType getLvlType(Level l) const {
// This OOB check is for dense-tensors, since this class knows
// their lvlRank (whereas STEA::getLvlType will/can only check
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 17e6ef53fe596e0..d392f1c52d61a1d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -415,6 +415,56 @@ SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
return getStaticDimSliceStride(toOrigDim(*this, lvl));
}
+SmallVector<int64_t>
+SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
+ CrdTransDirectionKind dir) const {
+ if (isIdentity()) {
+ return SmallVector<int64_t>(srcShape);
+ }
+
+ SmallVector<int64_t> ret;
+ unsigned rank =
+ dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
+ ret.reserve(rank);
+
+ if (isPermutation()) {
+ for (unsigned r = 0; r < rank; r++) {
+ unsigned trans = dir == CrdTransDirectionKind::dim2lvl
+ ? toOrigDim(*this, r)
+ : toStoredDim(*this, r);
+ ret.push_back(srcShape[trans]);
+ }
+ return ret;
+ }
+
+ // Non-permutation
+ AffineMap transMap =
+ dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
+
+ SmallVector<AffineExpr> dimRep;
+ dimRep.reserve(srcShape.size());
+ for (int64_t sz : srcShape) {
+ if (!ShapedType::isDynamic(sz)) {
+ // Push back the max coordinate for the given dimension/level size
+ dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
+ } else {
+ // A dynamic size; use an AffineDimExpr to represent the value.
+ dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
+ }
+ }
+
+ for (AffineExpr exp : transMap.getResults()) {
+ // Do constant propagation on the affine map.
+ AffineExpr evalExp =
+ simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
+ if (auto c = evalExp.dyn_cast<AffineConstantExpr>())
+ ret.push_back(c.getValue() + 1);
+ else
+ ret.push_back(ShapedType::kDynamic);
+ }
+ return ret;
+}
+
ValueRange
SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
ValueRange crds,
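
For readers of this hunk, here is a worked sketch of what the dim2lvl direction of the shape translation above computes. It reuses the #BSR and #DSDD encodings from the tests below; the function name and concrete values are purely illustrative and not part of the patch:

```mlir
#BSR = #sparse_tensor.encoding<{
  map = ( i, j ) -> ( i floordiv 2 : dense,
                      j floordiv 3 : compressed,
                      i mod 2      : dense,
                      j mod 3      : dense )
}>

#DSDD = #sparse_tensor.encoding<{
  map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
}>

// Dimension shape 6x12 gives maximum coordinates (5, 11). Substituting them
// into the dim-to-lvl map and adding 1 to each simplified constant yields:
//   5 floordiv 2  = 2  ->  level size 3
//   11 floordiv 3 = 3  ->  level size 4
//   5 mod 2       = 1  ->  level size 2
//   11 mod 3      = 2  ->  level size 3
// i.e. the 3x4x2x3 level shape that the verifier below expects.
func.func @lvl_shape_example(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x3xi32, #DSDD> {
  %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
     to tensor<3x4x2x3xi32, #DSDD>
  return %t1 : tensor<3x4x2x3xi32, #DSDD>
}
```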
@@ -1292,6 +1342,53 @@ OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
return {};
}
+LogicalResult ReinterpretMapOp::verify() {
+ auto srcStt = getSparseTensorType(getSource());
+ auto dstStt = getSparseTensorType(getDest());
+ ArrayRef<DimLevelType> srcLvlTps = srcStt.getLvlTypes();
+ ArrayRef<DimLevelType> dstLvlTps = dstStt.getLvlTypes();
+
+ if (srcLvlTps.size() != dstLvlTps.size())
+ return emitError("Level rank mismatch between source/dest tensors");
+
+ for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
+ if (srcLvlTp != dstLvlTp)
+ return emitError("Level type mismatch between source/dest tensors");
+
+ if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
+ srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
+ return emitError("Crd/Pos width mismatch between source/dest tensors");
+ }
+
+ if (srcStt.getElementType() != dstStt.getElementType())
+ return emitError("Element type mismatch between source/dest tensors");
+
+ SmallVector<DynSize> srcLvlShape = srcStt.getLvlShape();
+ SmallVector<DynSize> dstLvlShape = dstStt.getLvlShape();
+ for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
+ if (srcLvlSz != dstLvlSz) {
+ // Should we allow one side to have dynamic sizes, e.g., should <?x?> be
+ // compatible with <3x4>? For now, we require all level sizes to match
+ // *exactly* for simplicity.
+ return emitError("Level size mismatch between source/dest tensors");
+ }
+ }
+
+ return success();
+}
+
+OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
+ if (getSource().getType() == getDest().getType())
+ return getSource();
+
+ if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
+ // A -> B, B -> A ==> A
+ if (def.getSource().getType() == getDest().getType())
+ return def.getSource();
+ }
+ return {};
+}
+
LogicalResult ToPositionsOp::verify() {
auto e = getSparseTensorEncoding(getTensor().getType());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
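
The second fold case above (a reinterpret chain A -> B followed by B -> A collapsing to A) is exercised by the fold.mlir test added below. The first case, where source and dest types are already identical, is not covered by a test in this patch; a minimal sketch of it would look like the following (illustrative only; the function name is hypothetical):

```mlir
#BSR = #sparse_tensor.encoding<{
  map = ( i, j ) -> ( i floordiv 2 : dense,
                      j floordiv 3 : compressed,
                      i mod 2      : dense,
                      j mod 3      : dense )
}>

// Source and dest types are identical, so the folder returns %t0 directly and
// the sparse_tensor.reinterpret_map disappears after folding.
func.func @fold_trivial_reinterpret(%t0 : tensor<6x12xi32, #BSR>) -> tensor<6x12xi32, #BSR> {
  %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
     to tensor<6x12xi32, #BSR>
  return %t1 : tensor<6x12xi32, #BSR>
}
```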
diff --git a/mlir/test/Dialect/SparseTensor/fold.mlir b/mlir/test/Dialect/SparseTensor/fold.mlir
index cf70cdc61ba783f..73851c086aa2425 100644
--- a/mlir/test/Dialect/SparseTensor/fold.mlir
+++ b/mlir/test/Dialect/SparseTensor/fold.mlir
@@ -111,3 +111,18 @@ func.func @sparse_lvl_3(%t : tensor<?x?xi32, #BSR>) -> index {
%l0 = sparse_tensor.lvl %t, %lvl : tensor<?x?xi32, #BSR>
return %l0 : index
}
+
+#DSDD = #sparse_tensor.encoding<{
+ map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+}>
+
+
+// CHECK-LABEL: func.func @sparse_reinterpret_map(
+// CHECK-NOT: sparse_tensor.reinterpret_map
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<6x12xi32, #BSR> {
+ %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+ to tensor<3x4x2x3xi32, #DSDD>
+ %t2 = sparse_tensor.reinterpret_map %t1 : tensor<3x4x2x3xi32, #DSDD>
+ to tensor<6x12xi32, #BSR>
+ return %t2 : tensor<6x12xi32, #BSR>
+}
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 33aa81c5a747d9b..7f9bb4fea7bd459 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -913,3 +913,66 @@ func.func @sparse_lvl(%t : tensor<?x?xi32, #BSR>) -> index {
%l0 = sparse_tensor.lvl %t, %lvl : tensor<?x?xi32, #BSR>
return %l0 : index
}
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+ map = ( i, j ) -> ( i floordiv 2 : dense,
+ j floordiv 3 : compressed,
+ i mod 2 : dense,
+ j mod 3 : dense
+ )
+}>
+
+#DSDC = #sparse_tensor.encoding<{
+ map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: compressed)
+}>
+
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x3xf32, #DSDC> {
+ // expected-error at +1 {{Level type mismatch between source/dest tensors}}
+ %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+ to tensor<3x4x2x3xf32, #DSDC>
+ return %t1 : tensor<3x4x2x3xf32, #DSDC>
+}
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+ map = ( i, j ) -> ( i floordiv 2 : dense,
+ j floordiv 3 : compressed,
+ i mod 2 : dense,
+ j mod 3 : dense
+ )
+}>
+
+#DSDD = #sparse_tensor.encoding<{
+ map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+}>
+
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x3xf32, #DSDD> {
+ // expected-error at +1 {{Element type mismatch between source/dest tensors}}
+ %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+ to tensor<3x4x2x3xf32, #DSDD>
+ return %t1 : tensor<3x4x2x3xf32, #DSDD>
+}
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+ map = ( i, j ) -> ( i floordiv 2 : dense,
+ j floordiv 3 : compressed,
+ i mod 2 : dense,
+ j mod 3 : dense
+ )
+}>
+
+#DSDD = #sparse_tensor.encoding<{
+ map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+}>
+
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x4xi32, #DSDD> {
+ // expected-error at +1 {{Level size mismatch between source/dest tensors}}
+ %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+ to tensor<3x4x2x4xi32, #DSDD>
+ return %t1 : tensor<3x4x2x4xi32, #DSDD>
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index d3332eb3bbe33d2..17ae8c065945a1b 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -690,3 +690,23 @@ func.func @sparse_lvl(%arg0: index, %t : tensor<?x?xi32, #BSR>) -> index {
%l0 = sparse_tensor.lvl %t, %arg0 : tensor<?x?xi32, #BSR>
return %l0 : index
}
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+ map = ( i, j ) -> ( i floordiv 2 : dense,
+ j floordiv 3 : compressed,
+ i mod 2 : dense,
+ j mod 3 : dense
+ )
+}>
+
+#DSDD = #sparse_tensor.encoding<{
+ map = (i, j, k, l) -> (i: dense, j: compressed, k: dense, l: dense)
+}>
+
+func.func @sparse_reinterpret_map(%t0 : tensor<6x12xi32, #BSR>) -> tensor<3x4x2x3xi32, #DSDD> {
+ %t1 = sparse_tensor.reinterpret_map %t0 : tensor<6x12xi32, #BSR>
+ to tensor<3x4x2x3xi32, #DSDD>
+ return %t1 : tensor<3x4x2x3xi32, #DSDD>
+}