[Mlir-commits] [mlir] 951a363 - [mlir][sparse] implement `sparse_tensor.extract_value` operation. (#101220)
llvmlistbot at llvm.org
Wed Jul 31 14:47:33 PDT 2024
Author: Peiming Liu
Date: 2024-07-31T14:47:29-07:00
New Revision: 951a36309787c39d102798c7b86b06caa1a35257
URL: https://github.com/llvm/llvm-project/commit/951a36309787c39d102798c7b86b06caa1a35257
DIFF: https://github.com/llvm/llvm-project/commit/951a36309787c39d102798c7b86b06caa1a35257.diff
LOG: [mlir][sparse] implement `sparse_tensor.extract_value` operation. (#101220)
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index 1d614b7b29361..b1451dee738ac 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -2,6 +2,7 @@
#include "Utils/CodegenUtils.h"
#include "Utils/SparseTensorIterator.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -10,8 +11,8 @@
using namespace mlir;
using namespace mlir::sparse_tensor;
-void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
- SmallVectorImpl<Type> &fields) {
+static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
+ SmallVectorImpl<Type> &fields) {
// Position and coordinate buffer in the sparse structure.
if (enc.getLvlType(lvl).isWithPosLT())
fields.push_back(enc.getPosMemRefType());
@@ -71,6 +72,21 @@ class ExtractIterSpaceConverter
}
};
+/// Sparse codegen rule for the extract value operator.
+class ExtractValOpConverter : public OneToNOpConversionPattern<ExtractValOp> {
+public:
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ LogicalResult
+ matchAndRewrite(ExtractValOp op, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Value pos = adaptor.getIterator().back();
+ Value valBuf = rewriter.create<ToValuesOp>(loc, op.getTensor());
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
+ return success();
+ }
+};
+
class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
public:
using OneToNOpConversionPattern::OneToNOpConversionPattern;
@@ -193,6 +209,6 @@ void mlir::populateLowerSparseIterationToSCFPatterns(
TypeConverter &converter, RewritePatternSet &patterns) {
IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
- patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(
- converter, patterns.getContext());
+ patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
+ SparseIterateOpConverter>(converter, patterns.getContext());
}
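Conceptually, the new ExtractValOpConverter turns a value extraction against a live iterator into a plain load from the tensor's values buffer. A minimal before/after sketch, with the iterator type annotation elided for brevity (illustrative, not taken verbatim from the patch):

    // Before lowering: read the value the iterator currently points at.
    %v = sparse_tensor.extract_value %t at %it : tensor<?xf32, #SV>

    // After lowering: materialize the values buffer, then index it with the
    // iterator's trailing position (adaptor.getIterator().back()).
    %vals = sparse_tensor.values %t : tensor<?xf32, #SV> to memref<?xf32>
    %v = memref.load %vals[%pos] : memref<?xf32>

The emitted memref::LoadOp is also why the MemRef dialect include is now needed at the top of the file.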
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index c612a52aa8d50..08fc104fcbeea 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -357,6 +357,9 @@ static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
const auto pos = env.emitter().getValPosits(tid);
assert(!pos.empty());
args.append(pos);
+      // Simply return the tensor; the value will be extracted using iterators.
+ if (env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator)
+ return t->get();
} else {
// For dense tensors we push all level's coordinates onto `args`.
const Level lvlRank = stt.getLvlRank();
@@ -512,9 +515,16 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
return genInsertionLoadReduce(env, builder, t);
return genInsertionLoad(env, builder, t);
}
+
// Actual load.
SmallVector<Value> args;
Value ptr = genSubscript(env, builder, t, args);
+ if (llvm::isa<TensorType>(ptr.getType())) {
+ assert(env.options().sparseEmitStrategy ==
+ SparseEmitStrategy::kSparseIterator &&
+ args.size() == 1);
+ return builder.create<ExtractValOp>(loc, ptr, args.front());
+ }
return builder.create<memref::LoadOp>(loc, ptr, args);
}
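Taken together, the two Sparsification.cpp changes keep value loads abstract under sparse-emit-strategy=sparse-iterator: genSubscript hands back the tensor itself with the iterator as the sole "subscript", and genTensorLoad then emits sparse_tensor.extract_value instead of memref.load. A rough sketch of the IR this produces before --lower-sparse-iteration-to-scf runs (names and exact assembly syntax illustrative):

    %r = sparse_tensor.iterate %it in %space iter_args(%acc = %zero)
        : !sparse_tensor.iter_space<#SV, lvls = 0> -> i32 {
      %v = sparse_tensor.extract_value %t at %it : tensor<?xi32, #SV>
      %sq = arith.muli %v, %v : i32
      %sum = arith.addi %acc, %sq : i32
      sparse_tensor.yield %sum : i32
    }

The isa<TensorType> check on ptr is what distinguishes this path from the usual memref-based load, hence the paired assertion that the emit strategy is kSparseIterator and that args holds exactly the one iterator value.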
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 2a884b10e36b0..f3e73e4692c1f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -221,6 +221,11 @@ class LoopEmitter {
/// Getters.
///
SmallVector<Value> getValPosits(TensorId tid) const {
+ // Returns the iterator if we are generating sparse (co)iterate-based loops.
+ if (emitStrategy == SparseEmitStrategy::kSparseIterator)
+ return {spIterVals[tid].back()};
+
+ // Returns {[batch coords], last-level position}.
SmallVector<Value> batchCrds = iters[tid].back().back()->getBatchCrds();
Value lastLvlPos = iters[tid].back().back()->getCurPosition().front();
batchCrds.push_back(lastLvlPos);
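The getter now yields one of two shapes, and each maps to a different load form in genTensorLoad. A hedged sketch of the two endpoints (types illustrative):

    // Default strategies: {[batch coords], last-level position}, used as
    // memref subscripts.
    %v = memref.load %vals[%pos] : memref<?xf32>

    // kSparseIterator: the single returned Value is the iterator itself,
    // which becomes an extract_value.
    %v = sparse_tensor.extract_value %t at %it : tensor<?xf32, #SV>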
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
index f5bbea0d340fb..268b3940418b7 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --sparse-space-collapse --lower-sparse-iteration-to-scf | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse --sparse-space-collapse --lower-sparse-iteration-to-scf --loop-invariant-code-motion | FileCheck %s
#COO = #sparse_tensor.encoding<{
@@ -7,8 +7,7 @@
d1 : singleton(nonunique, soa),
d2 : singleton(nonunique, soa),
d3 : singleton(soa)
- ),
- explicitVal = 1 : i32
+ )
}>
// CHECK-LABEL: func.func @sqsum(
@@ -17,7 +16,10 @@
// CHECK-DAG: %[[POS_BUF:.*]] = sparse_tensor.positions %{{.*}} {level = 0 : index} : tensor<?x?x?x?xi32, #sparse> to memref<?xindex>
// CHECK: %[[POS_LO:.*]] = memref.load %[[POS_BUF]]{{\[}}%[[C0]]] : memref<?xindex>
// CHECK: %[[POS_HI:.*]] = memref.load %[[POS_BUF]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK: %[[VAL_BUF:.*]] = sparse_tensor.values %{{.*}} : tensor<?x?x?x?xi32, #sparse> to memref<?xi32>
// CHECK: %[[SQ_SUM:.*]] = scf.for %[[POS:.*]] = %[[POS_LO]] to %[[POS_HI]] step %[[C1]] {{.*}} {
+// CHECK: %[[VAL:.*]] = memref.load %[[VAL_BUF]]{{\[}}%[[POS]]] : memref<?xi32>
+// CHECK: %[[MUL:.*]] = arith.muli %[[VAL]], %[[VAL]] : i32
// CHECK: %[[SUM:.*]] = arith.addi
// CHECK: scf.yield %[[SUM]] : i32
// CHECK: }
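Dropping explicitVal = 1 : i32 is what makes the new CHECK lines meaningful: with an explicit value attribute every stored entry is the constant 1, so no value load would be emitted and the test would never exercise extract_value. For orientation, the @sqsum kernel under test is a sum-of-squares reduction over a 4-D COO tensor, along these lines (a sketch under the #COO encoding above, not the verbatim test input):

    func.func @sqsum(%arg0: tensor<?x?x?x?xi32, #COO>) -> tensor<i32> {
      %init = arith.constant dense<0> : tensor<i32>
      %r = linalg.generic {
          indexing_maps = [affine_map<(i, j, k, l) -> (i, j, k, l)>,
                           affine_map<(i, j, k, l) -> ()>],
          iterator_types = ["reduction", "reduction", "reduction", "reduction"]}
          ins(%arg0 : tensor<?x?x?x?xi32, #COO>)
          outs(%init : tensor<i32>) {
        ^bb0(%in: i32, %acc: i32):
          %sq = arith.muli %in, %in : i32
          %sum = arith.addi %acc, %sq : i32
          linalg.yield %sum : i32
      } -> tensor<i32>
      return %r : tensor<i32>
    }

The added --cse and --loop-invariant-code-motion in the RUN line hoist the sparse_tensor.values call out of the scf.for loop, which is why the VAL_BUF line is checked before the loop header.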