[Mlir-commits] [mlir] 12189f8 - [mlir][sparse] introduce `sparse_tensor.extract_value` operation. (#101219)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 30 16:26:21 PDT 2024


Author: Peiming Liu
Date: 2024-07-30T16:26:18-07:00
New Revision: 12189f800585ab459afa20b4f159db93ae474b57

URL: https://github.com/llvm/llvm-project/commit/12189f800585ab459afa20b4f159db93ae474b57
DIFF: https://github.com/llvm/llvm-project/commit/12189f800585ab459afa20b4f159db93ae474b57.diff

LOG: [mlir][sparse] introduce `sparse_tensor.extract_value` operation. (#101219)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index f31df080d7811..ff9858d5832ba 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1531,6 +1531,31 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
   let hasVerifier = 1;
 }
 
+def ExtractValOp : SparseTensor_Op<"extract_value", [
+    Pure,
+    TypesMatchWith<"result type matches element type of tensor",
+                   "tensor", "result",
+                   "::llvm::cast<TensorType>($_self).getElementType()">]> {
+  let summary = "Extracts a value from a sparse tensor using an iterator.";
+  let description = [{
+      The `sparse_tensor.extract_value` operation extracts the value pointed
+      to by a last-level sparse iterator over the given sparse tensor.
+
+      Example:
+
+      ```mlir
+      %val = sparse_tensor.extract_value %sp at %it
+           : tensor<?x?xf32, #CSR>, !sparse_tensor.iterator<#CSR, lvls = 1>
+      ```
+  }];
+
+  let arguments = (ins AnySparseTensor:$tensor, AnySparseIterator:$iterator);
+  let results = (outs AnyType:$result);
+
+  let assemblyFormat = "$tensor `at` $iterator attr-dict `:` type($tensor)`,` qualified(type($iterator))";
+  let hasVerifier = 1;
+}
+
 def IterateOp : SparseTensor_Op<"iterate",
     [RecursiveMemoryEffects, RecursivelySpeculatable,
      DeclareOpInterfaceMethods<LoopLikeOpInterface,

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 616e91ae04055..0a276d87f3bca 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2267,6 +2267,19 @@ LogicalResult ExtractIterSpaceOp::verify() {
   return success();
 }
 
+/// Verifies that the iterator shares the tensor's encoding and is last-level.
+LogicalResult ExtractValOp::verify() {
+  auto stt = getSparseTensorType(getTensor());
+  auto itTp = getIterator().getType();
+  // The iterator type must carry exactly the same encoding as the tensor.
+  if (stt.getEncoding() != itTp.getEncoding())
+    return emitOpError("mismatch in tensor encoding and iterator encoding.");
+  // Only an iterator whose high level equals the level rank may be used.
+  if (stt.getLvlRank() != itTp.getHiLvl())
+    return emitOpError("must use last-level iterator to extract values.");
+  return success();
+}
+
 struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
   using OpRewritePattern::OpRewritePattern;
 

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index eb0dc01be25b9..61cc9be88685c 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1099,6 +1099,42 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
   return
 }
 
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+#CSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : dense,
+    j : compressed
+  )
+}>
+// Negative test: the tensor is #COO while the iterator type is built on #CSR.
+func.func @sparse_extract_value(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 1>) -> f32 {
+  // expected-error@+1 {{'sparse_tensor.extract_value' op mismatch in tensor encoding and iterator encoding.}}
+  %f = sparse_tensor.extract_value %sp at %it1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 1>
+  return %f : f32
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+// Negative test: a `lvls = 0` iterator on this two-level tensor is rejected.
+func.func @sparse_extract_value(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) -> f32 {
+  // expected-error@+1 {{'sparse_tensor.extract_value' op must use last-level iterator to extract values.}}
+  %f = sparse_tensor.extract_value %sp at %it1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+  return %f : f32
+}
 
 // -----
 

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index bce0b41a99828..055709ee69eb7 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -739,6 +739,27 @@ func.func @sparse_has_runtime() -> i1 {
   return %has_runtime : i1
 }
 
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+// Checks the printed form of `sparse_tensor.extract_value` on a #COO tensor.
+// CHECK-LABEL:   func.func @sparse_extract_value(
+// CHECK-SAME:      %[[TENSOR:.*]]: tensor<4x8xf32, #sparse>,
+// CHECK-SAME:      %[[ITER:.*]]: !sparse_tensor.iterator<#sparse, lvls = 1>) -> f32 {
+// CHECK:           %[[VALUE:.*]] = sparse_tensor.extract_value %[[TENSOR]] at %[[ITER]] : tensor<4x8xf32, #sparse>, !sparse_tensor.iterator<#sparse, lvls = 1>
+// CHECK:           return %[[VALUE]] : f32
+// CHECK:         }
+func.func @sparse_extract_value(%tensor : tensor<4x8xf32, #COO>, %iter : !sparse_tensor.iterator<#COO, lvls = 1>) -> f32 {
+  %value = sparse_tensor.extract_value %tensor at %iter : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 1>
+  return %value : f32
+}
+
+
 // -----
 
 #COO = #sparse_tensor.encoding<{


        


More information about the Mlir-commits mailing list