[Mlir-commits] [mlir] [mlir][sparse] implement lowering rules for ExtractIterSpaceOp. (PR #89143)
Peiming Liu
llvmlistbot at llvm.org
Tue Jun 11 13:40:53 PDT 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/89143
>From 66e26d222a35caa8afe54a91b9920412f7ccd8df Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 17 Apr 2024 18:51:24 +0000
Subject: [PATCH 1/3] [mlir][sparse] implement lowering rules for
ExtractIterSpaceOp.
---
.../mlir/Dialect/SparseTensor/IR/Enums.h | 4 +
.../Dialect/SparseTensor/Transforms/Passes.h | 15 +++
.../Dialect/SparseTensor/Transforms/Passes.td | 15 +++
.../SparseTensor/Transforms/CMakeLists.txt | 1 +
.../Transforms/SparseIterationToScf.cpp | 78 ++++++++++++
.../Transforms/SparseTensorPasses.cpp | 29 +++++
.../Transforms/Utils/SparseTensorIterator.cpp | 117 ++++++++++++++----
.../Transforms/Utils/SparseTensorIterator.h | 67 +++++++++-
.../SparseTensor/sparse_iteration_to_scf.mlir | 23 ++++
.../SparseTensor/sparse_space_collapse.mlir | 11 +-
10 files changed, 325 insertions(+), 35 deletions(-)
create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
create mode 100644 mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 33f613a46bad8..96ee7111fea2c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -357,6 +357,10 @@ struct LevelType {
return hasSparseSemantic();
}
+ constexpr unsigned getNumBuffer() const {
+ return hasDenseSemantic() ? 0 : (isWithPosLT() ? 2 : 1);
+ }
+
std::string toMLIRString() const {
std::string lvlStr = toFormatString(getLvlFmt());
std::string propStr = "";
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 3043a0c4dc410..c9164e39a3a75 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -16,6 +16,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
//===----------------------------------------------------------------------===//
// Include the generated pass header (which needs some early definitions).
@@ -143,6 +144,20 @@ void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);
std::unique_ptr<Pass> createLowerForeachToSCFPass();
+//===----------------------------------------------------------------------===//
+// The LowerSparseIterationToSCF pass.
+//===----------------------------------------------------------------------===//
+
+/// Type converter for iter_space and iterator.
+struct SparseIterationTypeConverter : public OneToNTypeConverter {
+ SparseIterationTypeConverter();
+};
+
+void populateLowerSparseIterationToSCFPatterns(TypeConverter &converter,
+ RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createLowerSparseIterationToSCFPass();
+
//===----------------------------------------------------------------------===//
// The SparseTensorConversion pass.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index c6554e1c94a4a..ec25bb42aa174 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -480,4 +480,19 @@ def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
];
}
+def LowerSparseIterationToSCF : Pass<"lower-sparse-iteration-to-scf", "func::FuncOp"> {
+ let summary = "lower sparse_tensor.iterate/coiterate into scf loops";
+ let description = [{
+ This pass lowers `sparse_tensor.iterate` operations into `scf.for/while` operations.
+ The pass is not yet stabilized.
+ }];
+ let constructor = "mlir::createLowerSparseIterationToSCFPass()";
+ let dependentDialects = [
+ "memref::MemRefDialect",
+ "scf::SCFDialect",
+ "sparse_tensor::SparseTensorDialect",
+ ];
+}
+
+
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 2a29ee8a7a87c..e4acfa8889e5f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
SparseAssembler.cpp
SparseBufferRewriting.cpp
SparseGPUCodegen.cpp
+ SparseIterationToScf.cpp
SparseReinterpretMap.cpp
SparseStorageSpecifierToLLVM.cpp
SparseSpaceCollapse.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
new file mode 100644
index 0000000000000..d89b0b192ffcd
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -0,0 +1,78 @@
+
+#include "Utils/CodegenUtils.h"
+#include "Utils/SparseTensorIterator.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
+ SmallVectorImpl<Type> &fields) {
+ // Position and coordinate buffer in the sparse structure.
+ if (enc.getLvlType(lvl).isWithPosLT())
+ fields.push_back(enc.getPosMemRefType());
+ if (enc.getLvlType(lvl).isWithCrdLT())
+ fields.push_back(enc.getCrdMemRefType());
+ // One index for shape bound (result from lvlOp)
+ fields.push_back(IndexType::get(enc.getContext()));
+}
+
+static std::optional<LogicalResult>
+convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
+
+ auto idxTp = IndexType::get(itSp.getContext());
+ for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
+ convertLevelType(itSp.getEncoding(), l, fields);
+
+ // Two indices for lower and upper bound (we only need one pair for the last
+ // iteration space).
+ fields.append({idxTp, idxTp});
+ return success();
+}
+
+namespace {
+
+/// Sparse codegen rule for the sparse_tensor.extract_iteration_space operator.
+class ExtractIterSpaceConverter
+ : public OneToNOpConversionPattern<ExtractIterSpaceOp> {
+public:
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ LogicalResult
+ matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+
+ // Construct the iteration space.
+ SparseIterationSpace space(loc, rewriter, op.getTensor(), 0,
+ op.getLvlRange(), adaptor.getParentIter());
+
+ SmallVector<Value> result = space.toValues();
+ rewriter.replaceOp(op, result, resultMapping);
+ return success();
+ }
+};
+
+} // namespace
+
+mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
+ addConversion([](Type type) { return type; });
+ addConversion(convertIterSpaceType);
+
+ addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
+ ValueRange inputs,
+ Location loc) -> std::optional<Value> {
+ return builder
+ .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
+ .getResult(0);
+ });
+}
+
+void mlir::populateLowerSparseIterationToSCFPatterns(
+ TypeConverter &converter, RewritePatternSet &patterns) {
+ patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index b42d58634a36c..ffbc85e9a17f5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -26,6 +26,7 @@ namespace mlir {
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
+#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
@@ -153,10 +154,34 @@ struct LowerForeachToSCFPass
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateLowerForeachToSCFPatterns(patterns);
+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
+struct LowerSparseIterationToSCFPass
+ : public impl::LowerSparseIterationToSCFBase<
+ LowerSparseIterationToSCFPass> {
+ LowerSparseIterationToSCFPass() = default;
+ LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
+ default;
+
+ void runOnOperation() override {
+ auto *ctx = &getContext();
+ RewritePatternSet patterns(ctx);
+ SparseIterationTypeConverter converter;
+ ConversionTarget target(*ctx);
+
+ // The actual conversion.
+ target.addIllegalOp<ExtractIterSpaceOp, IterateOp>();
+ populateLowerSparseIterationToSCFPatterns(converter, patterns);
+
+ if (failed(applyPartialOneToNConversion(getOperation(), converter,
+ std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
struct SparseTensorConversionPass
: public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
SparseTensorConversionPass() = default;
@@ -439,6 +464,10 @@ std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
return std::make_unique<LowerForeachToSCFPass>();
}
+std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
+ return std::make_unique<LowerSparseIterationToSCFPass>();
+}
+
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
return std::make_unique<SparseTensorConversionPass>();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index dbec46d2616d9..be8e15d6ae6f4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -168,7 +168,7 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
ValueRange posRange = posRangeIf.getResults();
return {posRange.front(), posRange.back()};
}
-};
+};
class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
public:
@@ -190,7 +190,7 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
return {pLo, pHi};
}
-};
+};
class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
public:
@@ -210,6 +210,13 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
// Use the segHi as the loop upper bound.
return {p, segHi};
}
+
+ ValuePair
+ collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
+ std::pair<Value, Value> parentRange) const override {
+ // Singleton level keeps the same range after collapsing.
+ return parentRange;
+ };
};
class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
@@ -1474,10 +1481,85 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
return getCursor();
}
+//===----------------------------------------------------------------------===//
+// SparseIterationSpace Implementation
+//===----------------------------------------------------------------------===//
+
+mlir::sparse_tensor::SparseIterationSpace::SparseIterationSpace(
+ Location l, OpBuilder &b, Value t, unsigned tid,
+ std::pair<Level, Level> lvlRange, ValueRange parentPos)
+ : lvls() {
+ auto [lvlLo, lvlHi] = lvlRange;
+
+ Value c0 = C_IDX(0);
+ if (parentPos.empty())
+ parentPos = c0;
+
+ for (Level lvl = lvlLo; lvl < lvlHi; lvl++)
+ lvls.emplace_back(makeSparseTensorLevel(b, l, t, tid, lvl));
+
+ bound = lvls.front()->peekRangeAt(b, l, /*batchPrefix=*/{}, parentPos);
+ for (auto &lvl : getLvlRef().drop_front())
+ bound = lvl->collapseRangeBetween(b, l, /*batchPrefix=*/{}, bound);
+}
+
+SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
+ IterSpaceType dstTp, ValueRange values, unsigned int tid) {
+ // Reconstruct every sparse tensor level.
+ SparseIterationSpace space;
+ for (auto [i, lt] : llvm::enumerate(dstTp.getLvlTypes())) {
+ unsigned bufferCnt = 0;
+ if (lt.isWithPosLT())
+ bufferCnt++;
+ if (lt.isWithCrdLT())
+ bufferCnt++;
+ // Sparse tensor buffers.
+ ValueRange buffers = values.take_front(bufferCnt);
+ values = values.drop_front(bufferCnt);
+
+ // Level size.
+ Value sz = values.front();
+ values = values.drop_front();
+ space.lvls.push_back(
+ makeSparseTensorLevel(lt, sz, buffers, tid, i + dstTp.getLoLvl()));
+ }
+ // Two bounds.
+ space.bound = std::make_pair(values[0], values[1]);
+ values = values.drop_front(2);
+
+ // Must have consumed all values.
+ assert(values.empty());
+ return space;
+}
+
//===----------------------------------------------------------------------===//
// SparseIterator factory functions.
//===----------------------------------------------------------------------===//
+/// Helper function to create a TensorLevel object from the given level type, level size, and level buffers.
+std::unique_ptr<SparseTensorLevel>
+sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b,
+ unsigned t, Level l) {
+ assert(lt.getNumBuffer() == b.size());
+ switch (lt.getLvlFmt()) {
+ case LevelFormat::Dense:
+ return std::make_unique<DenseLevel>(t, l, sz);
+ case LevelFormat::Batch:
+ return std::make_unique<BatchLevel>(t, l, sz);
+ case LevelFormat::Compressed:
+ return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
+ case LevelFormat::LooseCompressed:
+ return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
+ case LevelFormat::Singleton:
+ return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
+ case LevelFormat::NOutOfM:
+ return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
+ case LevelFormat::Undef:
+ llvm_unreachable("undefined level format");
+ }
+ llvm_unreachable("unrecognizable level format");
+}
+
std::unique_ptr<SparseTensorLevel>
sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
unsigned tid, Level lvl) {
@@ -1487,33 +1569,16 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
: b.create<tensor::DimOp>(l, t, lvl).getResult();
- switch (lt.getLvlFmt()) {
- case LevelFormat::Dense:
- return std::make_unique<DenseLevel>(tid, lvl, sz);
- case LevelFormat::Batch:
- return std::make_unique<BatchLevel>(tid, lvl, sz);
- case LevelFormat::Compressed: {
- Value pos = b.create<ToPositionsOp>(l, t, lvl);
- Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
- return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
- }
- case LevelFormat::LooseCompressed: {
+ SmallVector<Value, 2> buffers;
+ if (lt.isWithPosLT()) {
Value pos = b.create<ToPositionsOp>(l, t, lvl);
- Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
- return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
- }
- case LevelFormat::Singleton: {
- Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
- return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
+ buffers.push_back(pos);
}
- case LevelFormat::NOutOfM: {
- Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
- return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
+ if (lt.isWithCrdLT()) {
+ Value pos = b.create<ToCoordinatesOp>(l, t, lvl);
+ buffers.push_back(pos);
}
- case LevelFormat::Undef:
- llvm_unreachable("undefined level format");
- }
- llvm_unreachable("unrecognizable level format");
+ return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
}
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 120a806536f19..09503d4b6a099 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -15,6 +15,9 @@
namespace mlir {
namespace sparse_tensor {
+// Forward declaration.
+class SparseIterator;
+
/// The base class for all types of sparse tensor levels. It provides interfaces
/// to query the loop range (see `peekRangeAt`) and look up the coordinates (see
/// `peekCrdAt`).
@@ -50,6 +53,12 @@ class SparseTensorLevel {
peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
ValueRange parentPos, Value inPadZone = nullptr) const = 0;
+ virtual std::pair<Value, Value>
+ collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
+ std::pair<Value, Value> parentRange) const {
+ llvm_unreachable("Not Implemented");
+ };
+
Level getLevel() const { return lvl; }
LevelType getLT() const { return lt; }
Value getSize() const { return lvlSize; }
@@ -79,6 +88,51 @@ enum class IterKind : uint8_t {
kPad,
};
+class SparseIterationSpace {
+public:
+ SparseIterationSpace() = default;
+
+ // Constructs an N-D iteration space.
+ SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
+ std::pair<Level, Level> lvlRange, ValueRange parentPos);
+
+ // Constructs a 1-D iteration space.
+ SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
+ Level lvl, ValueRange parentPos)
+ : SparseIterationSpace(loc, b, t, tid, {lvl, lvl + 1}, parentPos){};
+
+ bool isUnique() const { return lvls.back()->isUnique(); }
+
+ unsigned getSpaceDim() const { return lvls.size(); }
+
+ // Reconstructs an iteration space directly from the provided ValueRange.
+ static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values,
+ unsigned tid);
+
+ // The inverse operation of `fromValues`.
+ SmallVector<Value> toValues() const {
+ SmallVector<Value> vals;
+ for (auto &stl : lvls) {
+ llvm::append_range(vals, stl->getLvlBuffers());
+ vals.push_back(stl->getSize());
+ }
+ vals.append({bound.first, bound.second});
+ return vals;
+ }
+
+ const SparseTensorLevel &getLastLvl() const { return *lvls.back(); }
+ ArrayRef<std::unique_ptr<SparseTensorLevel>> getLvlRef() const {
+ return lvls;
+ }
+
+ Value getBoundLo() const { return bound.first; }
+ Value getBoundHi() const { return bound.second; }
+
+private:
+ SmallVector<std::unique_ptr<SparseTensorLevel>> lvls;
+ std::pair<Value, Value> bound;
+};
+
/// Helper class that generates loop conditions, etc, to traverse a
/// sparse tensor level.
class SparseIterator {
@@ -287,10 +341,15 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
unsigned tid,
Level lvl);
-/// Helper function to create a simple SparseIterator object that iterate over
-/// the SparseTensorLevel.
-std::unique_ptr<SparseIterator> makeSimpleIterator(const SparseTensorLevel &stl,
- SparseEmitStrategy strategy);
+/// Helper function to create a TensorLevel object from the given level type, level size, and level buffers.
+std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
+ ValueRange buffers,
+ unsigned tid, Level l);
+/// Helper function to create a simple SparseIterator object that iterate
+/// over the SparseTensorLevel.
+std::unique_ptr<SparseIterator> makeSimpleIterator(
+ const SparseTensorLevel &stl,
+ SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);
/// Helper function to create a synthetic SparseIterator object that iterates
/// over a dense space specified by [0,`sz`).
diff --git a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
new file mode 100644
index 0000000000000..b51c11ebf8a8c
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s --lower-sparse-iteration-to-scf | FileCheck %s
+
+#COO = #sparse_tensor.encoding<{
+ map = (i, j) -> (
+ i : compressed(nonunique),
+ j : singleton(soa)
+ )
+}>
+
+// CHECK-LABEL: func.func @sparse_1D_space(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>) -> !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[LVL_SIZE:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[C0]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK: %[[POS_MEM:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK: %[[CRD_MEM:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK: %[[POS_LO:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK: %[[POS_HI:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK: %[[ITER_SPACE:.*]] = builtin.unrealized_conversion_cast %[[POS_MEM]], %[[CRD_MEM]], %[[LVL_SIZE]], %[[POS_LO]], %[[POS_HI]]
+func.func @sparse_1D_space(%sp : tensor<?x?xf32, #COO>) -> !sparse_tensor.iter_space<#COO, lvls = 0> {
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO>
+ return %l1 : !sparse_tensor.iter_space<#COO, lvls = 0>
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
index baa6199f12bc3..b99bf915c71f8 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
@@ -18,20 +18,21 @@
// CHECK: "test.sink"(%[[VAL_4]]) : (index) -> ()
// CHECK: return
// CHECK: }
-func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>, %i : index) {
+func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>) -> index {
+ %i = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
: tensor<4x8xf32, #COO>
-> !sparse_tensor.iter_space<#COO, lvls = 0>
- %r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+ %r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
-> !sparse_tensor.iter_space<#COO, lvls = 1>
%r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
- %k ="test.op"(%inner) : (index) -> index
+ %k = arith.addi %inner, %c1 : index
sparse_tensor.yield %k : index
}
sparse_tensor.yield %r2 : index
}
- "test.sink"(%r1) : (index) -> ()
- return
+ return %r1 : index
}
>From 9e21f22cfb1069061e9f56e26bbf9ed9ade90382 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 11 Jun 2024 20:14:43 +0000
Subject: [PATCH 2/3] rebase
---
.../SparseTensor/sparse_iteration_to_scf.mlir | 2 +-
.../SparseTensor/sparse_space_collapse.mlir | 16 ++++++++--------
2 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
index b51c11ebf8a8c..5fcd661bb69b2 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
@@ -18,6 +18,6 @@
// CHECK: %[[POS_HI:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C1]]] : memref<?xindex>
// CHECK: %[[ITER_SPACE:.*]] = builtin.unrealized_conversion_cast %[[POS_MEM]], %[[CRD_MEM]], %[[LVL_SIZE]], %[[POS_LO]], %[[POS_HI]]
func.func @sparse_1D_space(%sp : tensor<?x?xf32, #COO>) -> !sparse_tensor.iter_space<#COO, lvls = 0> {
- %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO>
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
return %l1 : !sparse_tensor.iter_space<#COO, lvls = 0>
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
index b99bf915c71f8..03a21e5d2c789 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
@@ -8,15 +8,15 @@
}>
// CHECK-LABEL: func.func @sparse_sparse_collapse(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>,
-// CHECK-SAME: %[[VAL_1:.*]]: index) {
-// CHECK: %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 to 2 : tensor<4x8xf32, #sparse>
-// CHECK: %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]], _) iter_args(%[[VAL_7:.*]] = %[[VAL_1]])
-// CHECK: %[[VAL_8:.*]] = "test.op"(%[[VAL_7]]) : (index) -> index
-// CHECK: sparse_tensor.yield %[[VAL_8]] : index
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>) -> index {
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 to 2
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] iter_args(%[[VAL_6:.*]] = %[[VAL_1]])
+// CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_2]] : index
+// CHECK: sparse_tensor.yield %[[VAL_7]] : index
// CHECK: }
-// CHECK: "test.sink"(%[[VAL_4]]) : (index) -> ()
-// CHECK: return
+// CHECK: return %[[VAL_4]] : index
// CHECK: }
func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>) -> index {
%i = arith.constant 0 : index
>From 1d6af0ab307cf86cffaf01b6f6b9348bb670db45 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 11 Jun 2024 20:27:50 +0000
Subject: [PATCH 3/3] address comments
---
.../Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp | 1 -
.../SparseTensor/Transforms/Utils/SparseTensorIterator.h | 6 +++++-
2 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index ffbc85e9a17f5..8004bdb904b8a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -154,7 +154,6 @@ struct LowerForeachToSCFPass
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateLowerForeachToSCFPatterns(patterns);
-
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 09503d4b6a099..5f452a9354006 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -88,6 +88,10 @@ enum class IterKind : uint8_t {
kPad,
};
+/// A `SparseIterationSpace` represents a sparse set of coordinates defined by
+/// (possibly multiple) levels of a specific sparse tensor.
+/// TODO: remove `SparseTensorLevel` and switch to SparseIterationSpace when
+/// feature complete.
class SparseIterationSpace {
public:
SparseIterationSpace() = default;
@@ -345,7 +349,7 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
ValueRange buffers,
unsigned tid, Level l);
-/// Helper function to create a simple SparseIterator object that iterate
+/// Helper function to create a simple SparseIterator object that iterates
/// over the SparseTensorLevel.
std::unique_ptr<SparseIterator> makeSimpleIterator(
const SparseTensorLevel &stl,
More information about the Mlir-commits
mailing list