[Mlir-commits] [mlir] ff21a90 - [mlir][sparse] introduce sparse_tensor.crd_translate operation (#69630)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 19 15:42:13 PDT 2023
Author: Peiming Liu
Date: 2023-10-19T15:42:09-07:00
New Revision: ff21a90e51ac3ad954df4f13adcf030a24c2a6a3
URL: https://github.com/llvm/llvm-project/commit/ff21a90e51ac3ad954df4f13adcf030a24c2a6a3
DIFF: https://github.com/llvm/llvm-project/commit/ff21a90e51ac3ad954df4f13adcf030a24c2a6a3.diff
LOG: [mlir][sparse] introduce sparse_tensor.crd_translate operation (#69630)
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/test/Dialect/SparseTensor/fold.mlir
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 47fd18a689d5a8d..b0fbbd747b76604 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -541,4 +541,26 @@ def SparseTensorSortKindAttr
"SparseTensorSortAlgorithm"> {
}
+
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Coordinate Translation Direction Attribute.
+//===----------------------------------------------------------------------===//
+
+// The C++ enum `CrdTransDirectionKind` selecting the direction of a sparse
+// tensor coordinate translation: `dim2lvl` (printed as `dim_to_lvl`) maps
+// dimension coordinates to level coordinates, and `lvl2dim` (printed as
+// `lvl_to_dim`) maps level coordinates back to dimension coordinates.
+// `genSpecializedAttr = 0` suppresses a standalone attribute class for the
+// enum; the dialect attribute wrapping it is the `EnumAttr` defined below.
+def SparseTensorCrdTransDirectionEnum
+ : I32EnumAttr<"CrdTransDirectionKind", "sparse tensor coordinate translation direction", [
+ I32EnumAttrCase<"dim2lvl", 0, "dim_to_lvl">,
+ I32EnumAttrCase<"lvl2dim", 1, "lvl_to_dim">,
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = SparseTensor_Dialect.cppNamespace;
+}
+
+// The sparse tensor dialect attribute (mnemonic `CrdTransDirection`) that
+// wraps `CrdTransDirectionKind`; it carries the `direction` of a
+// `sparse_tensor.crd_translate` operation.
+def SparseTensorCrdTransDirectionAttr
+ : EnumAttr<SparseTensor_Dialect, SparseTensorCrdTransDirectionEnum,
+ "CrdTransDirection"> {
+}
+
+
#endif // SPARSETENSOR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index c446b84c5d34103..7209f2ef8488bcc 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -520,6 +520,32 @@ def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set"
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Coordinate Translation Operation.
+//===----------------------------------------------------------------------===//
+
+// `sparse_tensor.crd_translate` translates a list of index coordinates between
+// the dimension and level coordinate spaces of `$encoder`. With `dim_to_lvl`
+// it takes one coordinate per dimension and yields one per level; with
+// `lvl_to_dim` it does the reverse. The op is marked Pure (no side effects).
+def SparseTensor_CrdTranslateOp : SparseTensor_Op<"crd_translate", [Pure]>,
+ Arguments<(ins Variadic<Index>:$in_crds,
+ SparseTensorCrdTransDirectionAttr:$direction,
+ SparseTensorEncodingAttr:$encoder)>,
+ Results<(outs Variadic<Index>:$out_crds)> {
+ string summary = "Performs coordinate translation between level and dimension coordinate space.";
+ string description = [{
+ Performs coordinate translation between level and dimension coordinate space according
+ to the affine maps defined by $encoder.
+
+ Example:
+
+ ```mlir
+ %l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%d0, %d1] as #BSR
+ : index, index, index, index
+ ```
+ }];
+ // Printed form: `<direction> [<in_crds>] as <encoding> {attrs} : <result types>`.
+ let assemblyFormat = "$direction `[` $in_crds `]` `as` $encoder attr-dict `:` type($out_crds)";
+ // CrdTranslateOp::verify() checks that the number of input and output
+ // coordinates matches the encoding's dimension/level ranks for the chosen
+ // direction.
+ let hasVerifier = 1;
+ // CrdTranslateOp::fold() forwards existing values when the encoding is a
+ // permutation, and cancels dim2lvl/lvl2dim round trips through the same
+ // mapping.
+ let hasFolder = 1;
+}
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Management Operations. These operations are "impure" in the
// sense that some behavior is defined by side-effects. These operations provide
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index b03a9140a9f1994..c6e7bfaf47d04d3 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1140,6 +1140,61 @@ bool ConvertOp::needsExtraSort() {
return true;
}
+/// Checks that the coordinate counts agree with the encoding. A `dim2lvl`
+/// translation consumes one coordinate per dimension and produces one per
+/// level; a `lvl2dim` translation consumes levels and produces dimensions.
+LogicalResult CrdTranslateOp::verify() {
+  const uint64_t dimRank = getEncoder().getDimRank();
+  const uint64_t lvlRank = getEncoder().getLvlRank();
+  const bool toLevels = getDirection() == CrdTransDirectionKind::dim2lvl;
+  const uint64_t expectedIn = toLevels ? dimRank : lvlRank;
+  const uint64_t expectedOut = toLevels ? lvlRank : dimRank;
+
+  const bool inMatches = getInCrds().size() == expectedIn;
+  const bool outMatches = getOutCrds().size() == expectedOut;
+  if (inMatches && outMatches)
+    return success();
+  return emitError("Coordinate rank mismatch with encoding");
+}
+
+/// Folds `crd_translate` when its results are already available as existing
+/// SSA values, so no translation needs to be materialized:
+///   * identity encoding: the outputs are the inputs, unchanged;
+///   * permutation encoding: the outputs are the inputs reordered according
+///     to the dimToLvl (or lvlToDim) map;
+///   * round trip: `lvl2dim(dim2lvl(x))` or `dim2lvl(lvl2dim(x))` through the
+///     same dimToLvl mapping folds to `x`, provided this op consumes every
+///     result of the inner translation, in order.
+/// Returns failure() (no fold) in every other case.
+LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
+                                   SmallVectorImpl<OpFoldResult> &results) {
+  // Handle the identity first. isPermutation() also holds for identity
+  // encodings, which may be stored without an explicit dimToLvl/lvlToDim map
+  // (a null AffineMap); iterating getResults() of a null map would crash.
+  // NOTE(review): null-map representation per SparseTensorEncodingAttr.
+  if (getEncoder().isIdentity()) {
+    results.assign(getInCrds().begin(), getInCrds().end());
+    return success();
+  }
+
+  if (getEncoder().isPermutation()) {
+    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
+                         ? getEncoder().getDimToLvl()
+                         : getEncoder().getLvlToDim();
+    // Output i is the input coordinate selected by the i-th map result.
+    for (AffineExpr exp : perm.getResults())
+      results.push_back(getInCrds()[exp.cast<AffineDimExpr>().getPosition()]);
+    return success();
+  }
+
+  // Fuse dim2lvl/lvl2dim pairs: every input must be produced by one and the
+  // same crd_translate op (and there must be at least one input to inspect).
+  if (getInCrds().empty())
+    return failure();
+  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
+  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
+                   return v.getDefiningOp() == def;
+                 });
+  if (!sameDef)
+    return failure();
+
+  // The inner op must translate in the opposite direction, through the same
+  // mapping, and produce exactly as many coordinates as are consumed here.
+  bool oppositeDir = def.getDirection() != getDirection();
+  bool sameOracle =
+      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
+  bool sameCount = def.getNumResults() == getInCrds().size();
+  if (!oppositeDir || !sameOracle || !sameCount)
+    return failure();
+
+  // The definition must produce the coordinates in the same order as they
+  // are consumed by this op.
+  if (!llvm::equal(def.getOutCrds(), getInCrds()))
+    return failure();
+
+  // l1 = dim2lvl (lvl2dim l0)
+  // ==> l0
+  results.append(def.getInCrds().begin(), def.getInCrds().end());
+  return success();
+}
+
LogicalResult ToPositionsOp::verify() {
auto e = getSparseTensorEncoding(getTensor().getType());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
diff --git a/mlir/test/Dialect/SparseTensor/fold.mlir b/mlir/test/Dialect/SparseTensor/fold.mlir
index 3dd1a629c129fff..3428f6d4ae5a117 100644
--- a/mlir/test/Dialect/SparseTensor/fold.mlir
+++ b/mlir/test/Dialect/SparseTensor/fold.mlir
@@ -75,3 +75,21 @@ func.func @sparse_reorder_coo(%arg0 : tensor<?x?xf32, #COO>) -> tensor<?x?xf32,
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #COO> to tensor<?x?xf32, #COO>
return %ret : tensor<?x?xf32, #COO>
}
+
+
+// 2x3 block-sparse-row (BSR) layout: the first two levels index 2x3 blocks
+// (i floordiv 2, j floordiv 3); the last two index within a block
+// (i mod 2, j mod 3).
+#BSR = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i floordiv 2 : dense,
+ j floordiv 3 : compressed,
+ i mod 2 : dense,
+ j mod 3 : dense
+ )
+}>
+
+// Translating dimension coordinates into BSR levels and straight back through
+// the same encoding is a round trip: folding should forward the original
+// arguments so that no translation op is left in the function.
+// CHECK-LABEL: func @sparse_crd_translate(
+// CHECK-NOT: sparse_tensor.crd_translate
+func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index) {
+ %l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index, index
+ %d0, %d1 = sparse_tensor.crd_translate lvl_to_dim [%l0, %l1, %l2, %l3] as #BSR : index, index
+ return %d0, %d1 : index, index
+}
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 805f3d161921c13..1ab1bac6b592ee7 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -861,3 +861,37 @@ func.func @sparse_permuted_reorder_coo(%arg0 : tensor<?x?xf32, #UnorderedCOO>) -
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf64, #OrderedCOO>
return %ret : tensor<?x?xf64, #OrderedCOO>
}
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i floordiv 2 : dense,
+ j floordiv 3 : compressed,
+ i mod 2 : dense,
+ j mod 3 : dense
+ )
+}>
+
+// dim_to_lvl over the 4-level #BSR encoding must produce 4 level coordinates;
+// declaring only 3 results must be rejected by the op verifier.
+func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index, index) {
+ // expected-error@+1 {{Coordinate rank mismatch with encoding}}
+ %l0, %l1, %l2 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index
+ return %l0, %l1, %l2 : index, index, index
+}
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i floordiv 2 : dense,
+ j floordiv 3 : compressed,
+ i mod 2 : dense,
+ j mod 3 : dense
+ )
+}>
+
+// dim_to_lvl over the 2-D #BSR encoding takes exactly 2 dimension
+// coordinates; passing 3 must be rejected by the op verifier.
+func.func @sparse_crd_translate(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+ // expected-error@+1 {{Coordinate rank mismatch with encoding}}
+ %l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1, %arg2] as #BSR : index, index, index, index
+ return %l0, %l1, %l2, %l3 : index, index, index, index
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index cbc3bb824924cdb..af9618ebe380d80 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -647,3 +647,25 @@ func.func @sparse_reorder_coo(%arg0 : tensor<?x?xf32, #UnorderedCOO>) -> tensor<
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf32, #OrderedCOO>
return %ret : tensor<?x?xf32, #OrderedCOO>
}
+
+
+// -----
+
+// 2x3 block-sparse-row (BSR) layout: two block-index levels followed by two
+// intra-block levels.
+#BSR = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i floordiv 2 : dense,
+ j floordiv 3 : compressed,
+ i mod 2 : dense,
+ j mod 3 : dense
+ )
+}>
+
+// Parse/print round trip of a dim_to_lvl translation: two dimension
+// coordinates in, four BSR level coordinates out.
+// CHECK-LABEL: func.func @sparse_crd_translate(
+// CHECK-SAME: %[[VAL_0:.*]]: index,
+// CHECK-SAME: %[[VAL_1:.*]]: index)
+// CHECK: %[[VAL_2:.*]]:4 = sparse_tensor.crd_translate dim_to_lvl{{\[}}%[[VAL_0]], %[[VAL_1]]]
+// CHECK: return %[[VAL_2]]#0, %[[VAL_2]]#1, %[[VAL_2]]#2, %[[VAL_2]]#3
+func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index, index, index) {
+ %l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index, index
+ return %l0, %l1, %l2, %l3 : index, index, index, index
+}
More information about the Mlir-commits mailing list