[Mlir-commits] [mlir] ff21a90 - [mlir][sparse] introduce sparse_tensor.crd_translate operation (#69630)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 19 15:42:13 PDT 2023


Author: Peiming Liu
Date: 2023-10-19T15:42:09-07:00
New Revision: ff21a90e51ac3ad954df4f13adcf030a24c2a6a3

URL: https://github.com/llvm/llvm-project/commit/ff21a90e51ac3ad954df4f13adcf030a24c2a6a3
DIFF: https://github.com/llvm/llvm-project/commit/ff21a90e51ac3ad954df4f13adcf030a24c2a6a3.diff

LOG: [mlir][sparse] introduce sparse_tensor.crd_translate operation (#69630)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/test/Dialect/SparseTensor/fold.mlir
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 47fd18a689d5a8d..b0fbbd747b76604 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -541,4 +541,26 @@ def SparseTensorSortKindAttr
                "SparseTensorSortAlgorithm"> {
 }
 
+
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Coordinate Translation Direction Attribute.
+//===----------------------------------------------------------------------===//
+
+// Direction of a sparse_tensor.crd_translate: `dim2lvl` maps dimension
+// coordinates to level coordinates and `lvl2dim` maps them back. Generates the
+// C++ enum CrdTransDirectionKind in the sparse tensor dialect's namespace.
+def SparseTensorCrdTransDirectionEnum
+    : I32EnumAttr<"CrdTransDirectionKind",
+                  "sparse tensor coordinate translation direction",
+                  [I32EnumAttrCase<"dim2lvl", 0, "dim_to_lvl">,
+                   I32EnumAttrCase<"lvl2dim", 1, "lvl_to_dim">]> {
+  let cppNamespace = SparseTensor_Dialect.cppNamespace;
+  // No standalone I32 attribute class is generated; the dialect attribute is
+  // defined separately through an EnumAttr wrapper.
+  let genSpecializedAttr = 0;
+}
+
+// Dialect attribute carrying a CrdTransDirectionKind value; spelled
+// `dim_to_lvl` or `lvl_to_dim` in the crd_translate assembly format.
+def SparseTensorCrdTransDirectionAttr
+    : EnumAttr<SparseTensor_Dialect, SparseTensorCrdTransDirectionEnum,
+               "CrdTransDirection">;
+
+
 #endif // SPARSETENSOR_ATTRDEFS

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index c446b84c5d34103..7209f2ef8488bcc 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -520,6 +520,32 @@ def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set"
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Coordinate Translation Operation.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_CrdTranslateOp : SparseTensor_Op<"crd_translate", [Pure]>,
+    Arguments<(ins Variadic<Index>:$in_crds,
+                   SparseTensorCrdTransDirectionAttr:$direction,
+                   SparseTensorEncodingAttr:$encoder)>,
+    Results<(outs Variadic<Index>:$out_crds)> {
+  string summary = "Performs coordinate translation between level and dimension coordinate space.";
+  string description = [{
+    Translates a list of coordinates between the dimension and level
+    coordinate spaces of a sparse tensor, following the dimToLvl / lvlToDim
+    affine maps of `$encoder`.
+
+    With `dim_to_lvl` the operands are dimension coordinates and the results
+    are level coordinates; `lvl_to_dim` is the reverse. The number of operands
+    and results must equal the dimension rank and level rank of `$encoder`
+    (in the order implied by the direction), otherwise the op fails to verify.
+
+    Example (a block-sparse encoding `#BSR` with two dimensions and four
+    levels):
+
+    ```mlir
+    %l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%d0, %d1] as #BSR
+                       : index, index, index, index
+    ```
+  }];
+  let hasFolder = 1;
+  let hasVerifier = 1;
+  let assemblyFormat = "$direction `[` $in_crds `]` `as` $encoder attr-dict `:` type($out_crds)";
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Management Operations. These operations are "impure" in the
 // sense that some behavior is defined by side-effects. These operations provide

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index b03a9140a9f1994..c6e7bfaf47d04d3 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1140,6 +1140,61 @@ bool ConvertOp::needsExtraSort() {
   return true;
 }
 
+/// Checks that the operand/result counts agree with the encoding's ranks for
+/// the requested direction: dim2lvl consumes dimRank coordinates and produces
+/// lvlRank coordinates, while lvl2dim consumes lvlRank and produces dimRank.
+LogicalResult CrdTranslateOp::verify() {
+  const uint64_t dimRank = getEncoder().getDimRank();
+  const uint64_t lvlRank = getEncoder().getLvlRank();
+  const bool toLevels = getDirection() == CrdTransDirectionKind::dim2lvl;
+  const uint64_t expectedIn = toLevels ? dimRank : lvlRank;
+  const uint64_t expectedOut = toLevels ? lvlRank : dimRank;
+
+  const bool countsMatch = getInCrds().size() == expectedIn &&
+                           getOutCrds().size() == expectedOut;
+  if (!countsMatch)
+    return emitError("Coordinate rank mismatch with encoding");
+  return success();
+}
+
+/// Folds a coordinate translation whose results are statically known:
+///  * identity encoding: the coordinates pass through unchanged;
+///  * permutation encoding: the results are a reordering of the inputs;
+///  * a dim2lvl/lvl2dim round trip through the same dimToLvl map: the results
+///    are the inputs of the inner translation.
+/// Returns failure when none of these apply.
+LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
+                                   SmallVectorImpl<OpFoldResult> &results) {
+  // Nothing to replace. Returning success with an empty `results` list would
+  // signal an in-place fold, and the pair-fusion path below indexes [0].
+  if (getInCrds().empty())
+    return failure();
+
+  // Identity encodings translate coordinates verbatim. NOTE: such encodings
+  // may be stored without an explicit dimToLvl/lvlToDim map (they still
+  // satisfy isPermutation()), so handle them before the permutation path
+  // would dereference a null AffineMap.
+  if (getEncoder().isIdentity()) {
+    results.append(getInCrds().begin(), getInCrds().end());
+    return success();
+  }
+
+  // A pure permutation just reorders the input coordinates.
+  if (getEncoder().isPermutation()) {
+    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
+                         ? getEncoder().getDimToLvl()
+                         : getEncoder().getLvlToDim();
+    for (AffineExpr exp : perm.getResults())
+      results.push_back(getInCrds()[exp.cast<AffineDimExpr>().getPosition()]);
+    return success();
+  }
+
+  // Fuse dim2lvl/lvl2dim pairs: every input must come from one translate op.
+  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
+  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
+                   return v.getDefiningOp() == def;
+                 });
+  if (!sameDef)
+    return failure();
+
+  bool oppositeDir = def.getDirection() != getDirection();
+  bool sameOracle =
+      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
+  // Also guarantees equal lengths for the zip_equal below.
+  bool sameCount = def.getNumResults() == getInCrds().size();
+  if (!oppositeDir || !sameOracle || !sameCount)
+    return failure();
+
+  // The definition produces the coordinates in the same order as the input
+  // coordinates.
+  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
+                                [](auto valuePair) {
+                                  auto [lhs, rhs] = valuePair;
+                                  return lhs == rhs;
+                                });
+
+  if (!sameOrder)
+    return failure();
+  // l1 = dim2lvl (lvl2dim l0)
+  // ==> l0
+  results.append(def.getInCrds().begin(), def.getInCrds().end());
+  return success();
+}
+
 LogicalResult ToPositionsOp::verify() {
   auto e = getSparseTensorEncoding(getTensor().getType());
   if (failed(lvlIsInBounds(getLevel(), getTensor())))

diff  --git a/mlir/test/Dialect/SparseTensor/fold.mlir b/mlir/test/Dialect/SparseTensor/fold.mlir
index 3dd1a629c129fff..3428f6d4ae5a117 100644
--- a/mlir/test/Dialect/SparseTensor/fold.mlir
+++ b/mlir/test/Dialect/SparseTensor/fold.mlir
@@ -75,3 +75,21 @@ func.func @sparse_reorder_coo(%arg0 : tensor<?x?xf32, #COO>) -> tensor<?x?xf32,
   %ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #COO> to tensor<?x?xf32, #COO>
   return %ret : tensor<?x?xf32, #COO>
 }
+
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i floordiv 2 : dense,
+    j floordiv 3 : compressed,
+    i mod 2      : dense,
+    j mod 3      : dense
+  )
+}>
+
+// A dim_to_lvl translation immediately undone by lvl_to_dim under the same
+// encoding folds away, forwarding the original dimension coordinates.
+// CHECK-LABEL: func @sparse_crd_translate(
+//  CHECK-SAME:   %[[D0:.*]]: index, %[[D1:.*]]: index)
+//   CHECK-NOT:   sparse_tensor.crd_translate
+//       CHECK:   return %[[D0]], %[[D1]] : index, index
+func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index) {
+  %l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index, index
+  %d0, %d1 = sparse_tensor.crd_translate lvl_to_dim [%l0, %l1, %l2, %l3] as #BSR : index, index
+  return %d0, %d1 : index, index
+}

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 805f3d161921c13..1ab1bac6b592ee7 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -861,3 +861,37 @@ func.func @sparse_permuted_reorder_coo(%arg0 : tensor<?x?xf32, #UnorderedCOO>) -
   %ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf64, #OrderedCOO>
   return %ret : tensor<?x?xf64, #OrderedCOO>
 }
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i floordiv 2 : dense,
+    j floordiv 3 : compressed,
+    i mod 2      : dense,
+    j mod 3      : dense
+  )
+}>
+
+// dim_to_lvl under #BSR must produce four level coordinates, not three.
+func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index, index) {
+  // expected-error@+1 {{Coordinate rank mismatch with encoding}}
+  %l0, %l1, %l2 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index
+  return  %l0, %l1, %l2 : index, index, index
+}
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i floordiv 2 : dense,
+    j floordiv 3 : compressed,
+    i mod 2      : dense,
+    j mod 3      : dense
+  )
+}>
+
+// dim_to_lvl under #BSR takes two dimension coordinates, not three.
+func.func @sparse_crd_translate(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+  // expected-error@+1 {{Coordinate rank mismatch with encoding}}
+  %l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1, %arg2] as #BSR : index, index, index, index
+  return  %l0, %l1, %l2, %l3 : index, index, index, index
+}

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index cbc3bb824924cdb..af9618ebe380d80 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -647,3 +647,25 @@ func.func @sparse_reorder_coo(%arg0 : tensor<?x?xf32, #UnorderedCOO>) -> tensor<
   %ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf32, #OrderedCOO>
   return %ret : tensor<?x?xf32, #OrderedCOO>
 }
+
+
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i floordiv 2 : dense,
+    j floordiv 3 : compressed,
+    i mod 2      : dense,
+    j mod 3      : dense
+  )
+}>
+
+// Round-trips a dim_to_lvl translation producing the four #BSR levels.
+// CHECK-LABEL:   func.func @sparse_crd_translate(
+// CHECK-SAME:      %[[VAL_0:.*]]: index,
+// CHECK-SAME:      %[[VAL_1:.*]]: index)
+// CHECK:           %[[VAL_2:.*]]:4 = sparse_tensor.crd_translate  dim_to_lvl{{\[}}%[[VAL_0]], %[[VAL_1]]]
+// CHECK:           return %[[VAL_2]]#0, %[[VAL_2]]#1, %[[VAL_2]]#2, %[[VAL_2]]#3
+func.func @sparse_crd_translate(%i: index, %j: index) -> (index, index, index, index) {
+  %bi, %bj, %ii, %jj = sparse_tensor.crd_translate dim_to_lvl [%i, %j] as #BSR : index, index, index, index
+  return %bi, %bj, %ii, %jj : index, index, index, index
+}


        


More information about the Mlir-commits mailing list