[Mlir-commits] [mlir] [mlir][sparse] introduce sparse_tensor.crd_translate operation (PR #69630)

Peiming Liu llvmlistbot at llvm.org
Thu Oct 19 13:12:42 PDT 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/69630

>From bd0e42673fecc65a81b927be84fa8379457dd0e9 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 19 Oct 2023 18:40:52 +0000
Subject: [PATCH 1/3] [mlir][sparse] introduce sparse_tensor.crd_translate
 operation

---
 .../SparseTensor/IR/SparseTensorAttrDefs.td   | 22 ++++++++
 .../SparseTensor/IR/SparseTensorOps.td        | 19 +++++++
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 54 +++++++++++++++++++
 mlir/test/Dialect/SparseTensor/fold.mlir      | 18 +++++++
 mlir/test/Dialect/SparseTensor/invalid.mlir   | 34 ++++++++++++
 mlir/test/Dialect/SparseTensor/roundtrip.mlir | 22 ++++++++
 6 files changed, 169 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 47fd18a689d5a8d..8b80384d59fdd83 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -541,4 +541,26 @@ def SparseTensorSortKindAttr
                "SparseTensorSortAlgorithm"> {
 }
 
+
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Coordinate Translation Direction Attribute.
+//===----------------------------------------------------------------------===//
+
+// The C++ enum for sparse tensor sort kind.
+def SparseTensorCrdTransDirectionEnum
+    : I32EnumAttr<"CrdTransDirectionKind", "sparse tensor coordinate translation direction", [
+        I32EnumAttrCase<"dim2lvl", 0, "dim_to_lvl">,
+        I32EnumAttrCase<"lvl2dim", 1, "lvl_to_dim">,
+      ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = SparseTensor_Dialect.cppNamespace;
+}
+
+// Define the enum sparse tensor sort kind attribute.
+def SparseTensorCrdTransDirectionAttr
+    : EnumAttr<SparseTensor_Dialect, SparseTensorCrdTransDirectionEnum,
+               "CrdTransDirection"> {
+}
+
+
 #endif // SPARSETENSOR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index c446b84c5d34103..03cfc5ca8f6011f 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -520,6 +520,25 @@ def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set"
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Coordinate Translation Operation.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_CrdTranslateOp : SparseTensor_Op<"crd_translate", [Pure]>,
+    Arguments<(ins Variadic<Index>:$in_crds,
+                   SparseTensorCrdTransDirectionAttr:$direction,
+                   SparseTensorEncodingAttr:$oracle)>,
+    Results<(outs Variadic<Index>:$out_crds)> {
+  string summary = "Performs coordinate translation between level and dimension coordinate space.";
+  string description = [{
+    Performs coordinate translation between level and dimension coordinate space according
+    to the provided affine maps.
+  }];
+  let assemblyFormat = "$direction `[` $in_crds `]` `as` $oracle attr-dict `:` type($out_crds)";
+  let hasVerifier = 1;
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Management Operations. These operations are "impure" in the
 // sense that some behavior is defined by side-effects. These operations provide
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index fd87bbfa905ed69..0ab3bd13cb72533 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1160,6 +1160,60 @@ bool ConvertOp::needsExtraSort() {
   return true;
 }
 
+LogicalResult CrdTranslateOp::verify() {
+  size_t inRank = getOracle().getLvlRank();
+  size_t outRank = getOracle().getDimRank();
+
+  if (getDirection() == CrdTransDirectionKind::dim2lvl)
+    std::swap(inRank, outRank);
+
+  if (inRank != getInCrds().size() || outRank != getOutCrds().size())
+    return emitError("Coordinate rank mismatch with encoding");
+
+  return success();
+}
+
+LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
+                                   SmallVectorImpl<OpFoldResult> &results) {
+  if (getOracle().isPermutation()) {
+    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
+                         ? getOracle().getDimToLvl()
+                         : getOracle().getLvlToDim();
+    for (AffineExpr exp : perm.getResults())
+      results.push_back(getInCrds()[exp.cast<AffineDimExpr>().getPosition()]);
+    return success();
+  }
+
+  // Fuse dim2lvl/lvl2dim pairs.
+  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
+  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
+                   return v.getDefiningOp() == def;
+                 });
+  if (!sameDef)
+    return failure();
+
+  bool oppositeDir = def.getDirection() != getDirection();
+  bool sameOracle = def.getOracle().getDimToLvl() == getOracle().getDimToLvl();
+  bool sameCount = def.getNumResults() == getInCrds().size();
+  if (!oppositeDir || !sameOracle || !sameCount)
+    return failure();
+
+  // The definition produce the coordinate in the same order as the input
+  // coordinates.
+  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
+                                [](auto valuePair) {
+                                  auto [lhs, rhs] = valuePair;
+                                  return lhs == rhs;
+                                });
+
+  if (!sameOrder)
+    return failure();
+  // l1 = dim2lvl (lvl2dim l0)
+  // ==> l0
+  results.append(def.getInCrds().begin(), def.getInCrds().end());
+  return success();
+}
+
 LogicalResult ToPositionsOp::verify() {
   auto e = getSparseTensorEncoding(getTensor().getType());
   if (failed(lvlIsInBounds(getLevel(), getTensor())))
diff --git a/mlir/test/Dialect/SparseTensor/fold.mlir b/mlir/test/Dialect/SparseTensor/fold.mlir
index 3dd1a629c129fff..3428f6d4ae5a117 100644
--- a/mlir/test/Dialect/SparseTensor/fold.mlir
+++ b/mlir/test/Dialect/SparseTensor/fold.mlir
@@ -75,3 +75,21 @@ func.func @sparse_reorder_coo(%arg0 : tensor<?x?xf32, #COO>) -> tensor<?x?xf32,
   %ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #COO> to tensor<?x?xf32, #COO>
   return %ret : tensor<?x?xf32, #COO>
 }
+
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i floordiv 2 : dense,
+    j floordiv 3 : compressed,
+    i mod 2      : dense,
+    j mod 3      : dense
+  )
+}>
+
+// CHECK-LABEL: func @sparse_crd_translate(
+//   CHECK-NOT:   sparse_tensor.crd_translate
+func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index) {
+  %l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index, index
+  %d0, %d1 = sparse_tensor.crd_translate lvl_to_dim [%l0, %l1, %l2, %l3] as #BSR : index, index
+  return  %d0, %d1 : index, index
+}
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 805f3d161921c13..1ab1bac6b592ee7 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -861,3 +861,37 @@ func.func @sparse_permuted_reorder_coo(%arg0 : tensor<?x?xf32, #UnorderedCOO>) -
   %ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf64, #OrderedCOO>
   return %ret : tensor<?x?xf64, #OrderedCOO>
 }
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i floordiv 2 : dense,
+    j floordiv 3 : compressed,
+    i mod 2      : dense,
+    j mod 3      : dense
+  )
+}>
+
+func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index, index) {
+  // expected-error at +1 {{Coordinate rank mismatch with encoding}}
+  %l0, %l1, %l2 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index
+  return  %l0, %l1, %l2 : index, index, index
+}
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i floordiv 2 : dense,
+    j floordiv 3 : compressed,
+    i mod 2      : dense,
+    j mod 3      : dense
+  )
+}>
+
+func.func @sparse_crd_translate(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+  // expected-error at +1 {{Coordinate rank mismatch with encoding}}
+  %l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1, %arg2] as #BSR : index, index, index, index
+  return  %l0, %l1, %l2, %l3 : index, index, index, index
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index cbc3bb824924cdb..af9618ebe380d80 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -647,3 +647,25 @@ func.func @sparse_reorder_coo(%arg0 : tensor<?x?xf32, #UnorderedCOO>) -> tensor<
   %ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf32, #OrderedCOO>
   return %ret : tensor<?x?xf32, #OrderedCOO>
 }
+
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i floordiv 2 : dense,
+    j floordiv 3 : compressed,
+    i mod 2      : dense,
+    j mod 3      : dense
+  )
+}>
+
+// CHECK-LABEL:   func.func @sparse_crd_translate(
+// CHECK-SAME:      %[[VAL_0:.*]]: index,
+// CHECK-SAME:      %[[VAL_1:.*]]: index)
+// CHECK:           %[[VAL_2:.*]]:4 = sparse_tensor.crd_translate  dim_to_lvl{{\[}}%[[VAL_0]], %[[VAL_1]]]
+// CHECK:           return %[[VAL_2]]#0, %[[VAL_2]]#1, %[[VAL_2]]#2, %[[VAL_2]]#3
+func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index, index, index) {
+  %l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index, index
+  return  %l0, %l1, %l2, %l3 : index, index, index, index
+}

>From cc1f014d24405c5b62bd28ff55390c6254cfd52f Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 19 Oct 2023 18:49:00 +0000
Subject: [PATCH 2/3] avoid using size_t

---
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 0ab3bd13cb72533..8b11c144137dbcd 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1161,8 +1161,8 @@ bool ConvertOp::needsExtraSort() {
 }
 
 LogicalResult CrdTranslateOp::verify() {
-  size_t inRank = getOracle().getLvlRank();
-  size_t outRank = getOracle().getDimRank();
+  uint64_t inRank = getOracle().getLvlRank();
+  uint64_t outRank = getOracle().getDimRank();
 
   if (getDirection() == CrdTransDirectionKind::dim2lvl)
     std::swap(inRank, outRank);

>From 611eb1ab45b6ed1eb72a6946375b094eb62474df Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 19 Oct 2023 20:12:26 +0000
Subject: [PATCH 3/3] address comments

---
 .../SparseTensor/IR/SparseTensorAttrDefs.td       |  4 ++--
 .../Dialect/SparseTensor/IR/SparseTensorOps.td    | 13 ++++++++++---
 .../SparseTensor/IR/SparseTensorDialect.cpp       | 15 ++++++++-------
 3 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 8b80384d59fdd83..b0fbbd747b76604 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -546,7 +546,7 @@ def SparseTensorSortKindAttr
 // Sparse Tensor Coordinate Translation Direction Attribute.
 //===----------------------------------------------------------------------===//
 
-// The C++ enum for sparse tensor sort kind.
+// The C++ enum for the sparse tensor coordinate translation direction.
 def SparseTensorCrdTransDirectionEnum
     : I32EnumAttr<"CrdTransDirectionKind", "sparse tensor coordinate translation direction", [
         I32EnumAttrCase<"dim2lvl", 0, "dim_to_lvl">,
@@ -556,7 +556,7 @@ def SparseTensorCrdTransDirectionEnum
   let cppNamespace = SparseTensor_Dialect.cppNamespace;
 }
 
-// Define the enum sparse tensor sort kind attribute.
+// Define the sparse tensor coordinate translation direction attribute.
 def SparseTensorCrdTransDirectionAttr
     : EnumAttr<SparseTensor_Dialect, SparseTensorCrdTransDirectionEnum,
                "CrdTransDirection"> {
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 03cfc5ca8f6011f..4ed3f8a58cbd492 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -527,14 +527,21 @@ def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set"
 def SparseTensor_CrdTranslateOp : SparseTensor_Op<"crd_translate", [Pure]>,
     Arguments<(ins Variadic<Index>:$in_crds,
                    SparseTensorCrdTransDirectionAttr:$direction,
-                   SparseTensorEncodingAttr:$oracle)>,
+                   SparseTensorEncodingAttr:$encoder)>,
     Results<(outs Variadic<Index>:$out_crds)> {
   string summary = "Performs coordinate translation between level and dimension coordinate space.";
   string description = [{
     Performs coordinate translation between level and dimension coordinate space according
-    to the provided affine maps.
+    to the affine maps defined by $encoder.
+
+    Example:
+
+    ```mlir
+    %l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%d0, %d1] as #BSR
+                       : index, index, index, index
+    ```
   }];
-  let assemblyFormat = "$direction `[` $in_crds `]` `as` $oracle attr-dict `:` type($out_crds)";
+  let assemblyFormat = "$direction `[` $in_crds `]` `as` $encoder attr-dict `:` type($out_crds)";
   let hasVerifier = 1;
   let hasFolder = 1;
 }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 8b11c144137dbcd..5c6bd8620234e49 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1161,8 +1161,8 @@ bool ConvertOp::needsExtraSort() {
 }
 
 LogicalResult CrdTranslateOp::verify() {
-  uint64_t inRank = getOracle().getLvlRank();
-  uint64_t outRank = getOracle().getDimRank();
+  uint64_t inRank = getEncoder().getLvlRank();
+  uint64_t outRank = getEncoder().getDimRank();
 
   if (getDirection() == CrdTransDirectionKind::dim2lvl)
     std::swap(inRank, outRank);
@@ -1175,10 +1175,16 @@ LogicalResult CrdTranslateOp::verify() {
 
 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                    SmallVectorImpl<OpFoldResult> &results) {
-  if (getOracle().isPermutation()) {
+  // NOTE: isPermutation() is also true for an identity encoding, whose
+  // dimToLvl/lvlToDim maps may be null. Forward the input coordinates as-is
+  // so the permutation path below never walks a null AffineMap.
+  if (getEncoder().isIdentity()) {
+    results.assign(getInCrds().begin(), getInCrds().end());
+    return success();
+  }
+  if (getEncoder().isPermutation()) {
     AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
-                         ? getOracle().getDimToLvl()
-                         : getOracle().getLvlToDim();
+                         ? getEncoder().getDimToLvl()
+                         : getEncoder().getLvlToDim();
     for (AffineExpr exp : perm.getResults())
       results.push_back(getInCrds()[exp.cast<AffineDimExpr>().getPosition()]);
     return success();
@@ -1193,12 +1193,13 @@ LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
     return failure();
 
   bool oppositeDir = def.getDirection() != getDirection();
-  bool sameOracle = def.getOracle().getDimToLvl() == getOracle().getDimToLvl();
+  bool sameEncoder =
+      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
   bool sameCount = def.getNumResults() == getInCrds().size();
-  if (!oppositeDir || !sameOracle || !sameCount)
+  if (!oppositeDir || !sameEncoder || !sameCount)
     return failure();
 
-  // The definition produce the coordinate in the same order as the input
+  // The definition produces the coordinates in the same order as the input
   // coordinates.
   bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
                                 [](auto valuePair) {



More information about the Mlir-commits mailing list