[Mlir-commits] [mlir] [mlir][sparse] implement lowering rules for ExtractIterSpaceOp. (PR #89143)

Peiming Liu llvmlistbot at llvm.org
Tue Jun 11 13:59:22 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/89143

>From 66e26d222a35caa8afe54a91b9920412f7ccd8df Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 17 Apr 2024 18:51:24 +0000
Subject: [PATCH 1/4] [mlir][sparse] implement lowering rules for
 ExtractIterSpaceOp.

---
 .../mlir/Dialect/SparseTensor/IR/Enums.h      |   4 +
 .../Dialect/SparseTensor/Transforms/Passes.h  |  15 +++
 .../Dialect/SparseTensor/Transforms/Passes.td |  15 +++
 .../SparseTensor/Transforms/CMakeLists.txt    |   1 +
 .../Transforms/SparseIterationToScf.cpp       |  78 ++++++++++++
 .../Transforms/SparseTensorPasses.cpp         |  29 +++++
 .../Transforms/Utils/SparseTensorIterator.cpp | 117 ++++++++++++++----
 .../Transforms/Utils/SparseTensorIterator.h   |  67 +++++++++-
 .../SparseTensor/sparse_iteration_to_scf.mlir |  23 ++++
 .../SparseTensor/sparse_space_collapse.mlir   |  11 +-
 10 files changed, 325 insertions(+), 35 deletions(-)
 create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
 create mode 100644 mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 33f613a46bad8..96ee7111fea2c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -357,6 +357,12 @@ struct LevelType {
     return hasSparseSemantic();
   }
 
+  /// Returns the number of storage buffers backing this level: 0 for levels
+  /// with dense semantics, 2 for levels with a position buffer, 1 otherwise.
+  constexpr unsigned getNumBuffer() const {
+    return hasDenseSemantic() ? 0 : (isWithPosLT() ? 2 : 1);
+  }
+
   std::string toMLIRString() const {
     std::string lvlStr = toFormatString(getLvlFmt());
     std::string propStr = "";
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 3043a0c4dc410..c9164e39a3a75 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -16,6 +16,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
 
 //===----------------------------------------------------------------------===//
 // Include the generated pass header (which needs some early definitions).
@@ -143,6 +144,23 @@ void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);
 
 std::unique_ptr<Pass> createLowerForeachToSCFPass();
 
+//===----------------------------------------------------------------------===//
+// The LowerSparseIterationToSCF pass.
+//===----------------------------------------------------------------------===//
+
+/// Type converter that lowers sparse iteration types (e.g. iteration spaces).
+struct SparseIterationTypeConverter : public OneToNTypeConverter {
+  SparseIterationTypeConverter();
+};
+
+/// Populates `patterns` with the 1:N conversion patterns that lower sparse
+/// iteration operations, using `converter` to convert their types.
+void populateLowerSparseIterationToSCFPatterns(TypeConverter &converter,
+                                               RewritePatternSet &patterns);
+
+/// Creates a pass that lowers sparse iteration operations to SCF.
+std::unique_ptr<Pass> createLowerSparseIterationToSCFPass();
+
 //===----------------------------------------------------------------------===//
 // The SparseTensorConversion pass.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index c6554e1c94a4a..ec25bb42aa174 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -480,4 +480,18 @@ def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
   ];
 }
 
+def LowerSparseIterationToSCF : Pass<"lower-sparse-iteration-to-scf", "func::FuncOp"> {
+  let summary = "lower sparse_tensor.iterate/coiterate into scf loops";
+  let description = [{
+     This pass lowers `sparse_tensor.iterate` operations into `scf.for/while` operations.
+     The pass is not yet stablized.
+  }];
+  let constructor = "mlir::createLowerSparseIterationToSCFPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect",
+    "scf::SCFDialect",
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index 2a29ee8a7a87c..e4acfa8889e5f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseAssembler.cpp
   SparseBufferRewriting.cpp
   SparseGPUCodegen.cpp
+  SparseIterationToScf.cpp
   SparseReinterpretMap.cpp
   SparseStorageSpecifierToLLVM.cpp
   SparseSpaceCollapse.cpp
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
new file mode 100644
index 0000000000000..d89b0b192ffcd
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -0,0 +1,104 @@
+//===- SparseIterationToScf.cpp - Lower sparse iteration ops to SCF -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Utils/CodegenUtils.h"
+#include "Utils/SparseTensorIterator.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+/// Appends to `fields` the types that represent level `lvl` of `enc`: the
+/// position buffer (if the level has one), the coordinate buffer (if the
+/// level has one), and one index for the level size.
+static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
+                             SmallVectorImpl<Type> &fields) {
+  // Position and coordinate buffer in the sparse structure.
+  if (enc.getLvlType(lvl).isWithPosLT())
+    fields.push_back(enc.getPosMemRefType());
+  if (enc.getLvlType(lvl).isWithCrdLT())
+    fields.push_back(enc.getCrdMemRefType());
+  // One index for shape bound (result from lvlOp)
+  fields.push_back(IndexType::get(enc.getContext()));
+}
+
+/// Flattens an iteration space type into the types of the values that
+/// represent it: the per-level fields of every level in [lo, hi), followed by
+/// two indices for the lower and upper bound of the iteration range.
+static std::optional<LogicalResult>
+convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
+  auto idxTp = IndexType::get(itSp.getContext());
+  for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
+    convertLevelType(itSp.getEncoding(), l, fields);
+
+  // Two indices for lower and upper bound (we only need one pair for the last
+  // iteration space).
+  fields.append({idxTp, idxTp});
+  return success();
+}
+
+namespace {
+
+/// Sparse codegen rule for the `sparse_tensor.extract_iteration_space`
+/// operator: materializes the iteration space and replaces the op with the
+/// flattened values that represent it.
+class ExtractIterSpaceConverter
+    : public OneToNOpConversionPattern<ExtractIterSpaceOp> {
+public:
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    // Iterators are not converted by SparseIterationTypeConverter yet, so a
+    // parent iterator cannot be used as a position into the parent level.
+    // TODO: support iteration spaces extracted at a parent iterator.
+    if (op.getParentIter())
+      return rewriter.notifyMatchFailure(op, "parent iterator not supported");
+
+    Location loc = op.getLoc();
+    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+
+    // Construct the iteration space.
+    SparseIterationSpace space(loc, rewriter, op.getTensor(), /*tid=*/0,
+                               op.getLvlRange(), adaptor.getParentIter());
+
+    SmallVector<Value> result = space.toValues();
+    rewriter.replaceOp(op, result, resultMapping);
+    return success();
+  }
+};
+
+} // namespace
+
+/// Type converter that flattens every `IterSpaceType` into the values given by
+/// `convertIterSpaceType`; all other types are left unchanged.
+mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
+  // Conversions are tried in reverse registration order, so this identity
+  // conversion only applies to types not handled by the ones below.
+  addConversion([](Type type) { return type; });
+  addConversion(convertIterSpaceType);
+
+  // Casts the flattened values back to an iteration space for uses that still
+  // expect the original type.
+  addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
+                              ValueRange inputs,
+                              Location loc) -> std::optional<Value> {
+    return builder
+        .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
+        .getResult(0);
+  });
+}
+
+void mlir::populateLowerSparseIterationToSCFPatterns(
+    TypeConverter &converter, RewritePatternSet &patterns) {
+  patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index b42d58634a36c..ffbc85e9a17f5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -26,6 +26,7 @@ namespace mlir {
 #define GEN_PASS_DEF_SPARSEREINTERPRETMAP
 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
 #define GEN_PASS_DEF_SPARSIFICATIONPASS
+#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
 #define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
 #define GEN_PASS_DEF_LOWERFOREACHTOSCF
 #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
@@ -153,10 +154,34 @@ struct LowerForeachToSCFPass
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
     populateLowerForeachToSCFPatterns(patterns);
+
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
 
+/// Pass that lowers sparse iteration ops via 1:N type conversion.
+struct LowerSparseIterationToSCFPass
+    : public impl::LowerSparseIterationToSCFBase<
+          LowerSparseIterationToSCFPass> {
+  LowerSparseIterationToSCFPass() = default;
+  LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
+      default;
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    SparseIterationTypeConverter converter;
+
+    // The 1:N conversion driver takes no ConversionTarget; ops are rewritten
+    // only by the patterns populated here.
+    populateLowerSparseIterationToSCFPatterns(converter, patterns);
+
+    if (failed(applyPartialOneToNConversion(getOperation(), converter,
+                                            std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
 struct SparseTensorConversionPass
     : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
   SparseTensorConversionPass() = default;
@@ -439,6 +464,10 @@ std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
   return std::make_unique<LowerForeachToSCFPass>();
 }
 
+std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
+  return std::make_unique<LowerSparseIterationToSCFPass>();
+}
+
 std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
   return std::make_unique<SparseTensorConversionPass>();
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index dbec46d2616d9..be8e15d6ae6f4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -210,6 +210,13 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
     // Use the segHi as the loop upper bound.
     return {p, segHi};
   }
+
+  ValuePair
+  collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
+                       std::pair<Value, Value> parentRange) const override {
+    // Singleton level keeps the same range after collapsing.
+    return parentRange;
+  }
 };
 
 class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
@@ -1474,10 +1481,86 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
   return getCursor();
 }
 
+//===----------------------------------------------------------------------===//
+// SparseIterationSpace Implementation
+//===----------------------------------------------------------------------===//
+
+mlir::sparse_tensor::SparseIterationSpace::SparseIterationSpace(
+    Location l, OpBuilder &b, Value t, unsigned tid,
+    std::pair<Level, Level> lvlRange, ValueRange parentPos) {
+  auto [lvlLo, lvlHi] = lvlRange;
+  assert(lvlLo < lvlHi && "iteration space must span at least one level");
+
+  // Without a parent position, look up the range at position 0. `c0` lives in
+  // this scope because ValueRange does not own the values it refers to.
+  Value c0;
+  if (parentPos.empty()) {
+    c0 = C_IDX(0);
+    parentPos = c0;
+  }
+
+  for (Level lvl = lvlLo; lvl < lvlHi; lvl++)
+    lvls.emplace_back(makeSparseTensorLevel(b, l, t, tid, lvl));
+
+  // Look up the range of the first level at the parent position, then collapse
+  // every following level into it.
+  bound = lvls.front()->peekRangeAt(b, l, /*batchPrefix=*/{}, parentPos);
+  for (auto &lvl : getLvlRef().drop_front())
+    bound = lvl->collapseRangeBetween(b, l, /*batchPrefix=*/{}, bound);
+}
+
+SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
+    IterSpaceType dstTp, ValueRange values, unsigned int tid) {
+  // Reconstruct every sparse tensor level; each one is encoded as its buffers
+  // followed by its size, in the layout produced by `toValues`.
+  SparseIterationSpace space;
+  for (auto [i, lt] : llvm::enumerate(dstTp.getLvlTypes())) {
+    // Sparse tensor buffers.
+    unsigned bufferCnt = lt.getNumBuffer();
+    ValueRange buffers = values.take_front(bufferCnt);
+    values = values.drop_front(bufferCnt);
+
+    // Level size.
+    Value sz = values.front();
+    values = values.drop_front();
+    space.lvls.push_back(
+        makeSparseTensorLevel(lt, sz, buffers, tid, i + dstTp.getLoLvl()));
+  }
+  // Two bounds, which must be the only remaining values.
+  assert(values.size() == 2 && "expected exactly two bound values");
+  space.bound = std::make_pair(values[0], values[1]);
+  return space;
+}
+
 //===----------------------------------------------------------------------===//
 // SparseIterator factory functions.
 //===----------------------------------------------------------------------===//
 
+/// Helper function to create a TensorLevel object from the given level type
+/// `lt`, level size `sz`, and level buffers `b`.
+std::unique_ptr<SparseTensorLevel>
+sparse_tensor::makeSparseTensorLevel(LevelType lt, Value sz, ValueRange b,
+                                     unsigned t, Level l) {
+  assert(lt.getNumBuffer() == b.size() && "wrong number of level buffers");
+  switch (lt.getLvlFmt()) {
+  case LevelFormat::Dense:
+    return std::make_unique<DenseLevel>(t, l, sz);
+  case LevelFormat::Batch:
+    return std::make_unique<BatchLevel>(t, l, sz);
+  case LevelFormat::Compressed:
+    return std::make_unique<CompressedLevel>(t, l, lt, sz, b[0], b[1]);
+  case LevelFormat::LooseCompressed:
+    return std::make_unique<LooseCompressedLevel>(t, l, lt, sz, b[0], b[1]);
+  case LevelFormat::Singleton:
+    return std::make_unique<SingletonLevel>(t, l, lt, sz, b[0]);
+  case LevelFormat::NOutOfM:
+    return std::make_unique<NOutOfMLevel>(t, l, lt, sz, b[0]);
+  case LevelFormat::Undef:
+    llvm_unreachable("undefined level format");
+  }
+  llvm_unreachable("unrecognizable level format");
+}
+
 std::unique_ptr<SparseTensorLevel>
 sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
                                      unsigned tid, Level lvl) {
@@ -1487,33 +1569,17 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
   Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
                                : b.create<tensor::DimOp>(l, t, lvl).getResult();
 
-  switch (lt.getLvlFmt()) {
-  case LevelFormat::Dense:
-    return std::make_unique<DenseLevel>(tid, lvl, sz);
-  case LevelFormat::Batch:
-    return std::make_unique<BatchLevel>(tid, lvl, sz);
-  case LevelFormat::Compressed: {
-    Value pos = b.create<ToPositionsOp>(l, t, lvl);
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
-  }
-  case LevelFormat::LooseCompressed: {
+  // Level buffers in the order expected below: positions, then coordinates.
+  SmallVector<Value, 2> buffers;
+  if (lt.isWithPosLT()) {
     Value pos = b.create<ToPositionsOp>(l, t, lvl);
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
-  }
-  case LevelFormat::Singleton: {
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
+    buffers.push_back(pos);
   }
-  case LevelFormat::NOutOfM: {
-    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
-    return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
+  if (lt.isWithCrdLT()) {
+    Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
+    buffers.push_back(crd);
   }
-  case LevelFormat::Undef:
-    llvm_unreachable("undefined level format");
-  }
-  llvm_unreachable("unrecognizable level format");
+  return makeSparseTensorLevel(lt, sz, buffers, tid, lvl);
 }
 
 std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 120a806536f19..09503d4b6a099 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -15,6 +15,9 @@
 namespace mlir {
 namespace sparse_tensor {
 
+// Forward declaration.
+class SparseIterator;
+
 /// The base class for all types of sparse tensor levels. It provides interfaces
 /// to query the loop range (see `peekRangeAt`) and look up the coordinates (see
 /// `peekCrdAt`).
@@ -50,6 +53,14 @@ class SparseTensorLevel {
   peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
               ValueRange parentPos, Value inPadZone = nullptr) const = 0;
 
+  /// Returns this level's range when it is collapsed with the preceding
+  /// level(s), whose range is `parentRange`. Unsupported unless overridden.
+  virtual std::pair<Value, Value>
+  collapseRangeBetween(OpBuilder &b, Location l, ValueRange batchPrefix,
+                       std::pair<Value, Value> parentRange) const {
+    llvm_unreachable("Not Implemented");
+  }
+
   Level getLevel() const { return lvl; }
   LevelType getLT() const { return lt; }
   Value getSize() const { return lvlSize; }
@@ -79,6 +88,56 @@ enum class IterKind : uint8_t {
   kPad,
 };
 
+class SparseIterationSpace {
+public:
+  SparseIterationSpace() = default;
+
+  /// Constructs an N-D iteration space.
+  SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
+                       std::pair<Level, Level> lvlRange, ValueRange parentPos);
+
+  /// Constructs a 1-D iteration space.
+  SparseIterationSpace(Location loc, OpBuilder &b, Value t, unsigned tid,
+                       Level lvl, ValueRange parentPos)
+      : SparseIterationSpace(loc, b, t, tid, {lvl, lvl + 1}, parentPos) {}
+
+  /// Returns whether the last level of this space is unique.
+  bool isUnique() const { return lvls.back()->isUnique(); }
+
+  /// Returns the number of levels spanned by this space.
+  unsigned getSpaceDim() const { return lvls.size(); }
+
+  /// Reconstructs an iteration space directly from the provided ValueRange.
+  static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values,
+                                         unsigned tid);
+
+  /// The inverse operation of `fromValues`: flattens the space into, for each
+  /// level, its buffers followed by its size, and then the two bounds.
+  SmallVector<Value> toValues() const {
+    SmallVector<Value> vals;
+    for (auto &stl : lvls) {
+      llvm::append_range(vals, stl->getLvlBuffers());
+      vals.push_back(stl->getSize());
+    }
+    vals.append({bound.first, bound.second});
+    return vals;
+  }
+
+  /// Accessors for the levels spanned by this space.
+  const SparseTensorLevel &getLastLvl() const { return *lvls.back(); }
+  ArrayRef<std::unique_ptr<SparseTensorLevel>> getLvlRef() const {
+    return lvls;
+  }
+
+  /// The lower and upper bound of the iteration range.
+  Value getBoundLo() const { return bound.first; }
+  Value getBoundHi() const { return bound.second; }
+
+private:
+  SmallVector<std::unique_ptr<SparseTensorLevel>> lvls;
+  std::pair<Value, Value> bound;
+};
+
 /// Helper class that generates loop conditions, etc, to traverse a
 /// sparse tensor level.
 class SparseIterator {
@@ -287,10 +341,16 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
                                                           unsigned tid,
                                                           Level lvl);
 
-/// Helper function to create a simple SparseIterator object that iterate over
-/// the SparseTensorLevel.
-std::unique_ptr<SparseIterator> makeSimpleIterator(const SparseTensorLevel &stl,
-                                                   SparseEmitStrategy strategy);
+/// Helper function to create a TensorLevel object from the given level type
+/// `lt`, level size `sz`, and level `buffers`.
+std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
+                                                         ValueRange buffers,
+                                                         unsigned tid, Level l);
+/// Helper function to create a simple SparseIterator object that iterate
+/// over the SparseTensorLevel.
+std::unique_ptr<SparseIterator> makeSimpleIterator(
+    const SparseTensorLevel &stl,
+    SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);
 
 /// Helper function to create a synthetic SparseIterator object that iterates
 /// over a dense space specified by [0,`sz`).
diff --git a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
new file mode 100644
index 0000000000000..b51c11ebf8a8c
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s --lower-sparse-iteration-to-scf | FileCheck %s
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// CHECK-LABEL:   func.func @sparse_1D_space(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>) -> !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> {
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[LVL_SIZE:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[C0]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK:           %[[POS_MEM:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK:           %[[CRD_MEM:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK:           %[[POS_LO:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK:           %[[POS_HI:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK:           %[[ITER_SPACE:.*]] = builtin.unrealized_conversion_cast %[[POS_MEM]], %[[CRD_MEM]], %[[LVL_SIZE]], %[[POS_LO]], %[[POS_HI]]
+func.func @sparse_1D_space(%sp : tensor<?x?xf32, #COO>) -> !sparse_tensor.iter_space<#COO, lvls = 0> {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO>
+  return %l1 : !sparse_tensor.iter_space<#COO, lvls = 0>
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
index baa6199f12bc3..b99bf915c71f8 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
@@ -18,20 +18,21 @@
 // CHECK:           "test.sink"(%[[VAL_4]]) : (index) -> ()
 // CHECK:           return
 // CHECK:         }
-func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>, %i : index) {
+func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>) -> index {
+  %i = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
   %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
       : tensor<4x8xf32, #COO>
      -> !sparse_tensor.iter_space<#COO, lvls = 0>
-  %r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+    %r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
     %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
         : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
        -> !sparse_tensor.iter_space<#COO, lvls = 1>
     %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
-      %k ="test.op"(%inner) : (index) -> index
+      %k = arith.addi %inner, %c1 : index
       sparse_tensor.yield %k : index
     }
     sparse_tensor.yield %r2 : index
   }
-  "test.sink"(%r1) : (index) -> ()
-  return
+  return %r1 : index
 }

>From 9e21f22cfb1069061e9f56e26bbf9ed9ade90382 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 11 Jun 2024 20:14:43 +0000
Subject: [PATCH 2/4] rebase

---
 .../SparseTensor/sparse_iteration_to_scf.mlir    |  2 +-
 .../SparseTensor/sparse_space_collapse.mlir      | 16 ++++++++--------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
index b51c11ebf8a8c..5fcd661bb69b2 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
@@ -18,6 +18,6 @@
 // CHECK:           %[[POS_HI:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C1]]] : memref<?xindex>
 // CHECK:           %[[ITER_SPACE:.*]] = builtin.unrealized_conversion_cast %[[POS_MEM]], %[[CRD_MEM]], %[[LVL_SIZE]], %[[POS_LO]], %[[POS_HI]]
 func.func @sparse_1D_space(%sp : tensor<?x?xf32, #COO>) -> !sparse_tensor.iter_space<#COO, lvls = 0> {
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO>
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
   return %l1 : !sparse_tensor.iter_space<#COO, lvls = 0>
 }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
index b99bf915c71f8..03a21e5d2c789 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
@@ -8,15 +8,15 @@
 }>
 
 // CHECK-LABEL:   func.func @sparse_sparse_collapse(
-// CHECK-SAME:         %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>,
-// CHECK-SAME:         %[[VAL_1:.*]]: index) {
-// CHECK:           %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 to 2 : tensor<4x8xf32, #sparse>
-// CHECK:           %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]], _) iter_args(%[[VAL_7:.*]] = %[[VAL_1]])
-// CHECK:             %[[VAL_8:.*]] = "test.op"(%[[VAL_7]]) : (index) -> index
-// CHECK:             sparse_tensor.yield %[[VAL_8]] : index
+// CHECK-SAME:        %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>) -> index {
+// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 to 2
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] iter_args(%[[VAL_6:.*]] = %[[VAL_1]])
+// CHECK:             %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_2]] : index
+// CHECK:             sparse_tensor.yield %[[VAL_7]] : index
 // CHECK:           }
-// CHECK:           "test.sink"(%[[VAL_4]]) : (index) -> ()
-// CHECK:           return
+// CHECK:           return %[[VAL_4]] : index
 // CHECK:         }
 func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>) -> index {
   %i = arith.constant 0 : index

>From 1d6af0ab307cf86cffaf01b6f6b9348bb670db45 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 11 Jun 2024 20:27:50 +0000
Subject: [PATCH 3/4] address comments

---
 .../Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp  | 1 -
 .../SparseTensor/Transforms/Utils/SparseTensorIterator.h    | 6 +++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index ffbc85e9a17f5..8004bdb904b8a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -154,7 +154,6 @@ struct LowerForeachToSCFPass
     auto *ctx = &getContext();
     RewritePatternSet patterns(ctx);
     populateLowerForeachToSCFPatterns(patterns);
-
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 09503d4b6a099..5f452a9354006 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -88,6 +88,10 @@ enum class IterKind : uint8_t {
   kPad,
 };
 
+/// A `SparseIterationSpace` represents a sparse set of coordinates defined by
+/// (possibly multiple) levels of a specific sparse tensor.
+/// TODO: remove `SparseTensorLevel` and switch to SparseIterationSpace when
+/// feature complete.
 class SparseIterationSpace {
 public:
   SparseIterationSpace() = default;
@@ -345,7 +349,7 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
 std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
                                                          ValueRange buffers,
                                                          unsigned tid, Level l);
-/// Helper function to create a simple SparseIterator object that iterate
+/// Helper function to create a simple SparseIterator object that iterates
 /// over the SparseTensorLevel.
 std::unique_ptr<SparseIterator> makeSimpleIterator(
     const SparseTensorLevel &stl,

>From 90e5f1f6c9b191876644fb7b3298dc98daf59ff3 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 11 Jun 2024 20:55:51 +0000
Subject: [PATCH 4/4] address comments

---
 .../mlir/Dialect/SparseTensor/Transforms/Passes.td        | 4 ++--
 .../SparseTensor/Transforms/SparseIterationToScf.cpp      | 2 +-
 mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir | 8 +++-----
 3 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index ec25bb42aa174..de48a8eb184be 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -472,7 +472,7 @@ def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
   let summary = "sparse space collapsing pass";
   let description = [{
      This pass collapses consecutive sparse spaces (extracted from the same tensor)
-     into one multi-dimensional space. The pass is not yet stablized.
+     into one multi-dimensional space. The pass is not yet stabilized.
   }];
   let constructor = "mlir::createSparseSpaceCollapsePass()";
   let dependentDialects = [
@@ -484,7 +484,7 @@ def LowerSparseIterationToSCF : Pass<"lower-sparse-iteration-to-scf", "func::Fun
   let summary = "lower sparse_tensor.iterate/coiterate into scf loops";
   let description = [{
      This pass lowers `sparse_tensor.iterate` operations into `scf.for/while` operations.
-     The pass is not yet stablized.
+     The pass is not yet stabilized.
   }];
   let constructor = "mlir::createLowerSparseIterationToSCFPass()";
   let dependentDialects = [
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index d89b0b192ffcd..62887c75c872b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -17,7 +17,7 @@ void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
     fields.push_back(enc.getPosMemRefType());
   if (enc.getLvlType(lvl).isWithCrdLT())
     fields.push_back(enc.getCrdMemRefType());
-  // One index for shape bound (result from lvlOp)
+  // One index for shape bound (result from lvlOp).
   fields.push_back(IndexType::get(enc.getContext()));
 }
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
index 03a21e5d2c789..b5d041273f440 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir
@@ -22,12 +22,10 @@ func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>) -> index {
   %i = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
-      : tensor<4x8xf32, #COO>
-     -> !sparse_tensor.iter_space<#COO, lvls = 0>
-    %r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+      : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
+  %r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
     %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
-        : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
-       -> !sparse_tensor.iter_space<#COO, lvls = 1>
+        : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1> -> !sparse_tensor.iter_space<#COO, lvls = 1>
     %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
       %k = arith.addi %inner, %c1 : index
       sparse_tensor.yield %k : index



More information about the Mlir-commits mailing list