[Mlir-commits] [mlir] [mlir][sparse] introduce sparse_tensor.crd_translate operation (PR #69630)
Peiming Liu
llvmlistbot at llvm.org
Thu Oct 19 13:12:42 PDT 2023
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/69630
From bd0e42673fecc65a81b927be84fa8379457dd0e9 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 19 Oct 2023 18:40:52 +0000
Subject: [PATCH 1/3] [mlir][sparse] introduce sparse_tensor.crd_translate
operation
---
.../SparseTensor/IR/SparseTensorAttrDefs.td | 22 ++++++++
.../SparseTensor/IR/SparseTensorOps.td | 19 +++++++
.../SparseTensor/IR/SparseTensorDialect.cpp | 54 +++++++++++++++++++
mlir/test/Dialect/SparseTensor/fold.mlir | 18 +++++++
mlir/test/Dialect/SparseTensor/invalid.mlir | 34 ++++++++++++
mlir/test/Dialect/SparseTensor/roundtrip.mlir | 22 ++++++++
6 files changed, 169 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 47fd18a689d5a8d..8b80384d59fdd83 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -541,4 +541,26 @@ def SparseTensorSortKindAttr
"SparseTensorSortAlgorithm"> {
}
+
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Coordinate Translation Direction Attribute.
+//===----------------------------------------------------------------------===//
+
+// The C++ enum for sparse tensor sort kind.
+def SparseTensorCrdTransDirectionEnum
+ : I32EnumAttr<"CrdTransDirectionKind", "sparse tensor coordinate translation direction", [
+ I32EnumAttrCase<"dim2lvl", 0, "dim_to_lvl">,
+ I32EnumAttrCase<"lvl2dim", 1, "lvl_to_dim">,
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = SparseTensor_Dialect.cppNamespace;
+}
+
+// Define the enum sparse tensor sort kind attribute.
+def SparseTensorCrdTransDirectionAttr
+ : EnumAttr<SparseTensor_Dialect, SparseTensorCrdTransDirectionEnum,
+ "CrdTransDirection"> {
+}
+
+
#endif // SPARSETENSOR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index c446b84c5d34103..03cfc5ca8f6011f 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -520,6 +520,25 @@ def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set"
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Coordinate Translation Operation.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_CrdTranslateOp : SparseTensor_Op<"crd_translate", [Pure]>,
+ Arguments<(ins Variadic<Index>:$in_crds,
+ SparseTensorCrdTransDirectionAttr:$direction,
+ SparseTensorEncodingAttr:$oracle)>,
+ Results<(outs Variadic<Index>:$out_crds)> {
+ string summary = "Performs coordinate translation between level and dimension coordinate space.";
+ string description = [{
+ Performs coordinate translation between level and dimension coordinate space according
+ to the provided affine maps.
+ }];
+ let assemblyFormat = "$direction `[` $in_crds `]` `as` $oracle attr-dict `:` type($out_crds)";
+ let hasVerifier = 1;
+ let hasFolder = 1;
+}
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Management Operations. These operations are "impure" in the
// sense that some behavior is defined by side-effects. These operations provide
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index fd87bbfa905ed69..0ab3bd13cb72533 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1160,6 +1160,60 @@ bool ConvertOp::needsExtraSort() {
return true;
}
+LogicalResult CrdTranslateOp::verify() {
+ size_t inRank = getOracle().getLvlRank();
+ size_t outRank = getOracle().getDimRank();
+
+ if (getDirection() == CrdTransDirectionKind::dim2lvl)
+ std::swap(inRank, outRank);
+
+ if (inRank != getInCrds().size() || outRank != getOutCrds().size())
+ return emitError("Coordinate rank mismatch with encoding");
+
+ return success();
+}
+
+LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
+ SmallVectorImpl<OpFoldResult> &results) {
+ if (getOracle().isPermutation()) {
+ AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
+ ? getOracle().getDimToLvl()
+ : getOracle().getLvlToDim();
+ for (AffineExpr exp : perm.getResults())
+ results.push_back(getInCrds()[exp.cast<AffineDimExpr>().getPosition()]);
+ return success();
+ }
+
+ // Fuse dim2lvl/lvl2dim pairs.
+ auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
+ bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
+ return v.getDefiningOp() == def;
+ });
+ if (!sameDef)
+ return failure();
+
+ bool oppositeDir = def.getDirection() != getDirection();
+ bool sameOracle = def.getOracle().getDimToLvl() == getOracle().getDimToLvl();
+ bool sameCount = def.getNumResults() == getInCrds().size();
+ if (!oppositeDir || !sameOracle || !sameCount)
+ return failure();
+
+ // The definition produce the coordinate in the same order as the input
+ // coordinates.
+ bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
+ [](auto valuePair) {
+ auto [lhs, rhs] = valuePair;
+ return lhs == rhs;
+ });
+
+ if (!sameOrder)
+ return failure();
+ // l1 = dim2lvl (lvl2dim l0)
+ // ==> l0
+ results.append(def.getInCrds().begin(), def.getInCrds().end());
+ return success();
+}
+
LogicalResult ToPositionsOp::verify() {
auto e = getSparseTensorEncoding(getTensor().getType());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
diff --git a/mlir/test/Dialect/SparseTensor/fold.mlir b/mlir/test/Dialect/SparseTensor/fold.mlir
index 3dd1a629c129fff..3428f6d4ae5a117 100644
--- a/mlir/test/Dialect/SparseTensor/fold.mlir
+++ b/mlir/test/Dialect/SparseTensor/fold.mlir
@@ -75,3 +75,21 @@ func.func @sparse_reorder_coo(%arg0 : tensor<?x?xf32, #COO>) -> tensor<?x?xf32,
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #COO> to tensor<?x?xf32, #COO>
return %ret : tensor<?x?xf32, #COO>
}
+
+
+#BSR = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i floordiv 2 : dense,
+ j floordiv 3 : compressed,
+ i mod 2 : dense,
+ j mod 3 : dense
+ )
+}>
+
+// CHECK-LABEL: func @sparse_crd_translate(
+// CHECK-NOT: sparse_tensor.crd_translate
+func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index) {
+ %l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index, index
+ %d0, %d1 = sparse_tensor.crd_translate lvl_to_dim [%l0, %l1, %l2, %l3] as #BSR : index, index
+ return %d0, %d1 : index, index
+}
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 805f3d161921c13..1ab1bac6b592ee7 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -861,3 +861,37 @@ func.func @sparse_permuted_reorder_coo(%arg0 : tensor<?x?xf32, #UnorderedCOO>) -
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf64, #OrderedCOO>
return %ret : tensor<?x?xf64, #OrderedCOO>
}
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i floordiv 2 : dense,
+ j floordiv 3 : compressed,
+ i mod 2 : dense,
+ j mod 3 : dense
+ )
+}>
+
+func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index, index) {
+ // expected-error@+1 {{Coordinate rank mismatch with encoding}}
+ %l0, %l1, %l2 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index
+ return %l0, %l1, %l2 : index, index, index
+}
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i floordiv 2 : dense,
+ j floordiv 3 : compressed,
+ i mod 2 : dense,
+ j mod 3 : dense
+ )
+}>
+
+func.func @sparse_crd_translate(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+ // expected-error@+1 {{Coordinate rank mismatch with encoding}}
+ %l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1, %arg2] as #BSR : index, index, index, index
+ return %l0, %l1, %l2, %l3 : index, index, index, index
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index cbc3bb824924cdb..af9618ebe380d80 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -647,3 +647,25 @@ func.func @sparse_reorder_coo(%arg0 : tensor<?x?xf32, #UnorderedCOO>) -> tensor<
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf32, #OrderedCOO>
return %ret : tensor<?x?xf32, #OrderedCOO>
}
+
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i floordiv 2 : dense,
+ j floordiv 3 : compressed,
+ i mod 2 : dense,
+ j mod 3 : dense
+ )
+}>
+
+// CHECK-LABEL: func.func @sparse_crd_translate(
+// CHECK-SAME: %[[VAL_0:.*]]: index,
+// CHECK-SAME: %[[VAL_1:.*]]: index)
+// CHECK: %[[VAL_2:.*]]:4 = sparse_tensor.crd_translate dim_to_lvl{{\[}}%[[VAL_0]], %[[VAL_1]]]
+// CHECK: return %[[VAL_2]]#0, %[[VAL_2]]#1, %[[VAL_2]]#2, %[[VAL_2]]#3
+func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index, index, index) {
+ %l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index, index
+ return %l0, %l1, %l2, %l3 : index, index, index, index
+}
From cc1f014d24405c5b62bd28ff55390c6254cfd52f Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 19 Oct 2023 18:49:00 +0000
Subject: [PATCH 2/3] avoid using size_t
---
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 0ab3bd13cb72533..8b11c144137dbcd 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1161,8 +1161,8 @@ bool ConvertOp::needsExtraSort() {
}
LogicalResult CrdTranslateOp::verify() {
- size_t inRank = getOracle().getLvlRank();
- size_t outRank = getOracle().getDimRank();
+ uint64_t inRank = getOracle().getLvlRank();
+ uint64_t outRank = getOracle().getDimRank();
if (getDirection() == CrdTransDirectionKind::dim2lvl)
std::swap(inRank, outRank);
From 611eb1ab45b6ed1eb72a6946375b094eb62474df Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 19 Oct 2023 20:12:26 +0000
Subject: [PATCH 3/3] address comments
---
.../SparseTensor/IR/SparseTensorAttrDefs.td | 4 ++--
.../Dialect/SparseTensor/IR/SparseTensorOps.td | 13 ++++++++++---
.../SparseTensor/IR/SparseTensorDialect.cpp | 15 ++++++++-------
3 files changed, 20 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 8b80384d59fdd83..b0fbbd747b76604 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -546,7 +546,7 @@ def SparseTensorSortKindAttr
// Sparse Tensor Coordinate Translation Direction Attribute.
//===----------------------------------------------------------------------===//
-// The C++ enum for sparse tensor sort kind.
+// The C++ enum for sparse tensor coordinate translation direction.
def SparseTensorCrdTransDirectionEnum
: I32EnumAttr<"CrdTransDirectionKind", "sparse tensor coordinate translation direction", [
I32EnumAttrCase<"dim2lvl", 0, "dim_to_lvl">,
@@ -556,7 +556,7 @@ def SparseTensorCrdTransDirectionEnum
let cppNamespace = SparseTensor_Dialect.cppNamespace;
}
-// Define the enum sparse tensor sort kind attribute.
+// Define the sparse tensor coordinate translation direction enum attribute.
def SparseTensorCrdTransDirectionAttr
: EnumAttr<SparseTensor_Dialect, SparseTensorCrdTransDirectionEnum,
"CrdTransDirection"> {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 03cfc5ca8f6011f..4ed3f8a58cbd492 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -527,14 +527,21 @@ def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set"
def SparseTensor_CrdTranslateOp : SparseTensor_Op<"crd_translate", [Pure]>,
Arguments<(ins Variadic<Index>:$in_crds,
SparseTensorCrdTransDirectionAttr:$direction,
- SparseTensorEncodingAttr:$oracle)>,
+ SparseTensorEncodingAttr:$encoder)>,
Results<(outs Variadic<Index>:$out_crds)> {
string summary = "Performs coordinate translation between level and dimension coordinate space.";
string description = [{
Performs coordinate translation between level and dimension coordinate space according
- to the provided affine maps.
+ to the affine maps defined by $encoder.
+
+ Example:
+
+ ```mlir
+ %l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%d0, %d1] as #BSR
+ : index, index, index, index
+ ```
}];
- let assemblyFormat = "$direction `[` $in_crds `]` `as` $oracle attr-dict `:` type($out_crds)";
+ let assemblyFormat = "$direction `[` $in_crds `]` `as` $encoder attr-dict `:` type($out_crds)";
let hasVerifier = 1;
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 8b11c144137dbcd..5c6bd8620234e49 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1161,8 +1161,8 @@ bool ConvertOp::needsExtraSort() {
}
LogicalResult CrdTranslateOp::verify() {
- uint64_t inRank = getOracle().getLvlRank();
- uint64_t outRank = getOracle().getDimRank();
+ uint64_t inRank = getEncoder().getLvlRank();
+ uint64_t outRank = getEncoder().getDimRank();
if (getDirection() == CrdTransDirectionKind::dim2lvl)
std::swap(inRank, outRank);
@@ -1175,10 +1175,10 @@ LogicalResult CrdTranslateOp::verify() {
LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
- if (getOracle().isPermutation()) {
+ if (getEncoder().isPermutation()) {
AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
- ? getOracle().getDimToLvl()
- : getOracle().getLvlToDim();
+ ? getEncoder().getDimToLvl()
+ : getEncoder().getLvlToDim();
for (AffineExpr exp : perm.getResults())
results.push_back(getInCrds()[exp.cast<AffineDimExpr>().getPosition()]);
return success();
@@ -1193,12 +1193,13 @@ LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
return failure();
bool oppositeDir = def.getDirection() != getDirection();
- bool sameOracle = def.getOracle().getDimToLvl() == getOracle().getDimToLvl();
+ bool sameOracle =
+ def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
bool sameCount = def.getNumResults() == getInCrds().size();
if (!oppositeDir || !sameOracle || !sameCount)
return failure();
- // The definition produce the coordinate in the same order as the input
+ // The definition produces the coordinates in the same order as the input
// coordinates.
bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
[](auto valuePair) {
More information about the Mlir-commits
mailing list