[Mlir-commits] [mlir] [mlir][sparse] implement `sparse_tensor.extract_value` operation. (PR #101220)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 30 11:50:19 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/101220.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+25) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+13) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (+20-4) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+10) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h (+5) 
- (modified) mlir/test/Dialect/SparseTensor/invalid.mlir (+36) 
- (modified) mlir/test/Dialect/SparseTensor/roundtrip.mlir (+21) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir (+5-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index f31df080d7811..14fb471939f11 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1531,6 +1531,31 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
   let hasVerifier = 1;
 }
 
+def ExtractValOp : SparseTensor_Op<"extract_value", [
+    Pure,
+    TypesMatchWith<"result type matches element type of tensor",
+                   "tensor", "result",
+                   "::llvm::cast<TensorType>($_self).getElementType()">]> {
+  let summary = "Extracts a value from a sparse tensor.";
+  let description = [{
+      The `sparse_tensor.extract_value` operation extracts the value
+      pointed to by a last-level sparse iterator from a sparse tensor.
+
+      Example:
+
+      ```mlir
+      %val = sparse_tensor.extract_value %sp at %it
+           : tensor<?x?xf32, #CSR>, !sparse_tensor.iterator<#CSR, lvls = 1>
+      ```
+  }];
+
+  let arguments = (ins AnySparseTensor:$tensor, AnySparseIterator:$iterator);
+  let results = (outs AnyType:$result);
+
+  let assemblyFormat = "$tensor `at` $iterator attr-dict `:` type($tensor)`,` qualified(type($iterator))";
+  let hasVerifier = 1;
+}
+
 def IterateOp : SparseTensor_Op<"iterate",
     [RecursiveMemoryEffects, RecursivelySpeculatable,
      DeclareOpInterfaceMethods<LoopLikeOpInterface,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 616e91ae04055..0a276d87f3bca 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2267,6 +2267,19 @@ LogicalResult ExtractIterSpaceOp::verify() {
   return success();
 }
 
+LogicalResult ExtractValOp::verify() {
+  auto stt = getSparseTensorType(getTensor());
+  auto itTp = getIterator().getType();
+  // The iterator must traverse a tensor with exactly the same encoding.
+  if (stt.getEncoding() != itTp.getEncoding())
+    return emitOpError("mismatch in tensor encoding and iterator encoding.");
+  // Values can only be addressed by an iterator that reaches the last level.
+  if (stt.getLvlRank() != itTp.getHiLvl())
+    return emitOpError("must use last-level iterator to extract values.");
+
+  return success();
+}
+
 struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
   using OpRewritePattern::OpRewritePattern;
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index 1d614b7b29361..b1451dee738ac 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -2,6 +2,7 @@
 #include "Utils/CodegenUtils.h"
 #include "Utils/SparseTensorIterator.h"
 
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -10,8 +11,8 @@
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
-void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
-                      SmallVectorImpl<Type> &fields) {
+static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
+                             SmallVectorImpl<Type> &fields) {
   // Position and coordinate buffer in the sparse structure.
   if (enc.getLvlType(lvl).isWithPosLT())
     fields.push_back(enc.getPosMemRefType());
@@ -71,6 +72,21 @@ class ExtractIterSpaceConverter
   }
 };
 
+/// Lowers extract_value to a memref.load from the tensor's values buffer.
+class ExtractValOpConverter : public OneToNOpConversionPattern<ExtractValOp> {
+public:
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ExtractValOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value pos = adaptor.getIterator().back();
+    Value valBuf = rewriter.create<ToValuesOp>(loc, op.getTensor());
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
+    return success();
+  }
+};
+
 class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
 public:
   using OneToNOpConversionPattern::OneToNOpConversionPattern;
@@ -193,6 +209,6 @@ void mlir::populateLowerSparseIterationToSCFPatterns(
     TypeConverter &converter, RewritePatternSet &patterns) {
 
   IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
-  patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(
-      converter, patterns.getContext());
+  patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
+               SparseIterateOpConverter>(converter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index c612a52aa8d50..08fc104fcbeea 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -357,6 +357,9 @@ static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
     const auto pos = env.emitter().getValPosits(tid);
     assert(!pos.empty());
     args.append(pos);
+    // Simply returns the tensor to extract value using iterators.
+    if (env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator)
+      return t->get();
   } else {
     // For dense tensors we push all level's coordinates onto `args`.
     const Level lvlRank = stt.getLvlRank();
@@ -512,9 +515,16 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
       return genInsertionLoadReduce(env, builder, t);
     return genInsertionLoad(env, builder, t);
   }
+
   // Actual load.
   SmallVector<Value> args;
   Value ptr = genSubscript(env, builder, t, args);
+  if (llvm::isa<TensorType>(ptr.getType())) {
+    assert(env.options().sparseEmitStrategy ==
+               SparseEmitStrategy::kSparseIterator &&
+           args.size() == 1);
+    return builder.create<ExtractValOp>(loc, ptr, args.front());
+  }
   return builder.create<memref::LoadOp>(loc, ptr, args);
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 2a884b10e36b0..f3e73e4692c1f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -221,6 +221,11 @@ class LoopEmitter {
   /// Getters.
   ///
   SmallVector<Value> getValPosits(TensorId tid) const {
+    // Returns the iterator if we are generating sparse (co)iterate-based loops.
+    if (emitStrategy == SparseEmitStrategy::kSparseIterator)
+      return {spIterVals[tid].back()};
+
+    // Returns {[batch coords], last-level position}.
     SmallVector<Value> batchCrds = iters[tid].back().back()->getBatchCrds();
     Value lastLvlPos = iters[tid].back().back()->getCurPosition().front();
     batchCrds.push_back(lastLvlPos);
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index eb0dc01be25b9..61cc9be88685c 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1099,6 +1099,42 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
   return
 }
 
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+#CSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : dense,
+    j : compressed
+  )
+}>
+
+func.func @sparse_extract_value(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 1>) -> f32 {
+  // expected-error@+1 {{'sparse_tensor.extract_value' op mismatch in tensor encoding and iterator encoding.}}
+  %f = sparse_tensor.extract_value %sp at %it1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 1>
+  return %f : f32
+}
+
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+func.func @sparse_extract_value(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) -> f32 {
+  // expected-error@+1 {{'sparse_tensor.extract_value' op must use last-level iterator to extract values.}}
+  %f = sparse_tensor.extract_value %sp at %it1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+  return %f : f32
+}
 
 // -----
 
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index bce0b41a99828..055709ee69eb7 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -739,6 +739,27 @@ func.func @sparse_has_runtime() -> i1 {
   return %has_runtime : i1
 }
 
+// -----
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// CHECK-LABEL:   func.func @sparse_extract_value(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>,
+// CHECK-SAME:      %[[VAL_1:.*]]: !sparse_tensor.iterator<#sparse, lvls = 1>) -> f32 {
+// CHECK:           %[[VAL_2:.*]] = sparse_tensor.extract_value %[[VAL_0]] at %[[VAL_1]] : tensor<4x8xf32, #sparse>, !sparse_tensor.iterator<#sparse, lvls = 1>
+// CHECK:           return %[[VAL_2]] : f32
+// CHECK:         }
+func.func @sparse_extract_value(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 1>) -> f32 {
+  %f = sparse_tensor.extract_value %sp at %it1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 1>
+  return %f : f32
+}
+
+
 // -----
 
 #COO = #sparse_tensor.encoding<{
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
index f5bbea0d340fb..268b3940418b7 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --sparse-space-collapse --lower-sparse-iteration-to-scf | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse --sparse-space-collapse --lower-sparse-iteration-to-scf --loop-invariant-code-motion | FileCheck %s
 
 
 #COO = #sparse_tensor.encoding<{
@@ -7,8 +7,7 @@
     d1 : singleton(nonunique, soa),
     d2 : singleton(nonunique, soa),
     d3 : singleton(soa)
-  ),
-  explicitVal = 1 : i32
+  )
 }>
 
 // CHECK-LABEL:   func.func @sqsum(
@@ -17,7 +16,10 @@
 // CHECK-DAG:       %[[POS_BUF:.*]] = sparse_tensor.positions %{{.*}} {level = 0 : index} : tensor<?x?x?x?xi32, #sparse> to memref<?xindex>
 // CHECK:           %[[POS_LO:.*]] = memref.load %[[POS_BUF]]{{\[}}%[[C0]]] : memref<?xindex>
 // CHECK:           %[[POS_HI:.*]] = memref.load %[[POS_BUF]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK:           %[[VAL_BUF:.*]] = sparse_tensor.values %{{.*}} : tensor<?x?x?x?xi32, #sparse> to memref<?xi32>
 // CHECK:           %[[SQ_SUM:.*]] = scf.for %[[POS:.*]] = %[[POS_LO]] to %[[POS_HI]] step %[[C1]] {{.*}} {
+// CHECK:             %[[VAL:.*]] = memref.load %[[VAL_BUF]]{{\[}}%[[POS]]] : memref<?xi32>
+// CHECK:             %[[MUL:.*]] = arith.muli %[[VAL]], %[[VAL]] : i32
 // CHECK:             %[[SUM:.*]] = arith.addi
 // CHECK:             scf.yield %[[SUM]] : i32
 // CHECK:           }

``````````

</details>


https://github.com/llvm/llvm-project/pull/101220


More information about the Mlir-commits mailing list