[Mlir-commits] [mlir] [mlir][sparse] implement lowering rules for IterateOp. (PR #95286)

Peiming Liu llvmlistbot at llvm.org
Mon Jun 17 10:20:07 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/95286

>From d91da20b0da1b3306bda7496b1b97b99b42cb06d Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 28 May 2024 17:50:37 +0000
Subject: [PATCH 1/2] [mlir][sparse] implement lowering rules for IterateOp.

---
 .../Transforms/SparseIterationToScf.cpp       | 119 +++++++++++++++++-
 .../Transforms/Utils/SparseTensorIterator.cpp |  40 ++++++
 .../Transforms/Utils/SparseTensorIterator.h   |  26 +++-
 .../SparseTensor/sparse_iteration_to_scf.mlir |  54 ++++++--
 4 files changed, 222 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index 62887c75c872b..440fc3efb4edf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -34,6 +34,19 @@ convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
   return success();
 }
 
+static std::optional<LogicalResult>
+convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
+  // The actual iterator values (that are updated every iteration).
+  auto idxTp = IndexType::get(itTp.getContext());
+  // TODO: This assumes there is no batch dimension in the sparse tensor.
+  if (!itTp.isUnique()) {
+    // Segment high for non-unique iterator.
+    fields.push_back(idxTp);
+  }
+  fields.push_back(idxTp);
+  return success();
+}
+
 namespace {
 
 /// Sparse codegen rule for number of entries operator.
@@ -57,10 +70,113 @@ class ExtractIterSpaceConverter
   }
 };
 
+class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
+public:
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  LogicalResult
+  matchAndRewrite(IterateOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (!op.getCrdUsedLvls().empty())
+      llvm_unreachable("Not implemented.");
+
+    Location loc = op.getLoc();
+
+    auto iterSpace = SparseIterationSpace::fromValues(
+        op.getIterSpace().getType(), adaptor.getIterSpace(), 0);
+
+    std::unique_ptr<SparseIterator> it =
+        iterSpace.extractIterator(rewriter, loc);
+
+    if (it->iteratableByFor()) {
+      auto [lo, hi] = it->genForCond(rewriter, loc);
+      Value step = constantIndex(rewriter, loc, 1);
+      SmallVector<Value> ivs;
+      for (ValueRange inits : adaptor.getInitArgs())
+        llvm::append_range(ivs, inits);
+      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs);
+
+      Block *loopBody = op.getBody();
+      OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
+      if (failed(typeConverter->convertSignatureArgs(
+              loopBody->getArgumentTypes(), bodyTypeMapping)))
+        return failure();
+      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
+
+      forOp.getBody()->erase();
+      Region &dstRegion = forOp.getRegion();
+      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
+
+      auto yieldOp =
+          llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator());
+
+      rewriter.setInsertionPointToEnd(forOp.getBody());
+      // replace sparse_tensor.yield with scf.yield.
+      rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
+      yieldOp.erase();
+
+      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+      rewriter.replaceOp(op, forOp.getResults(), resultMapping);
+    } else {
+      SmallVector<Value> ivs;
+      llvm::append_range(ivs, it->getCursor());
+      for (ValueRange inits : adaptor.getInitArgs())
+        llvm::append_range(ivs, inits);
+
+      assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
+
+      TypeRange types = ValueRange(ivs).getTypes();
+      auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
+      SmallVector<Location> l(types.size(), op.getIterator().getLoc());
+
+      // Generates loop conditions.
+      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
+      rewriter.setInsertionPointToStart(before);
+      ValueRange bArgs = before->getArguments();
+      auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
+      assert(remArgs.size() == adaptor.getInitArgs().size());
+      rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
+
+      // Generates loop body.
+      Block *loopBody = op.getBody();
+      OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
+      if (failed(typeConverter->convertSignatureArgs(
+              loopBody->getArgumentTypes(), bodyTypeMapping)))
+        return failure();
+      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
+
+      Region &dstRegion = whileOp.getAfter();
+      // TODO: handle uses of coordinate!
+      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
+      ValueRange aArgs = whileOp.getAfterArguments();
+      auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
+          whileOp.getAfterBody()->getTerminator());
+
+      rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
+
+      aArgs = it->linkNewScope(aArgs);
+      ValueRange nx = it->forward(rewriter, loc);
+      SmallVector<Value> yields;
+      llvm::append_range(yields, nx);
+      llvm::append_range(yields, yieldOp.getResults());
+
+      // replace sparse_tensor.yield with scf.yield.
+      yieldOp->erase();
+      rewriter.create<scf::YieldOp>(loc, yields);
+
+      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
+      rewriter.replaceOp(
+          op, whileOp.getResults().drop_front(it->getCursor().size()),
+          resultMapping);
+    }
+    return success();
+  }
+};
+
 } // namespace
 
 mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
   addConversion([](Type type) { return type; });
+  addConversion(convertIteratorType);
   addConversion(convertIterSpaceType);
 
   addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
@@ -74,5 +190,6 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
 
 void mlir::populateLowerSparseIterationToSCFPatterns(
     TypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
+  patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(
+      converter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index be8e15d6ae6f4..ef95fcc84bd90 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -331,6 +331,13 @@ class TrivialIterator : public ConcreteIterator {
   TrivialIterator(const SparseTensorLevel &stl)
       : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
 
+  TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
+                  Value posLo, Value posHi)
+      : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1), posLo(posLo),
+        posHi(posHi) {
+    seek(posLo);
+  }
+
   std::string getDebugInterfacePrefix() const override {
     return std::string("trivial<") + stl.toString() + ">";
   }
@@ -420,6 +427,14 @@ class DedupIterator : public ConcreteIterator {
       : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
     assert(!stl.isUnique());
   }
+
+  DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
+                Value posLo, Value posHi)
+      : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) {
+    assert(!stl.isUnique());
+    seek({posLo, genSegmentHigh(b, l, posLo)});
+  }
+
   // For LLVM-style RTTI.
   static bool classof(const SparseIterator *from) {
     return from->kind == IterKind::kDedup;
@@ -1532,6 +1547,11 @@ SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
   return space;
 }
 
+std::unique_ptr<SparseIterator>
+SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const {
+  return makeSimpleIterator(b, l, *this);
+}
+
 //===----------------------------------------------------------------------===//
 // SparseIterator factory functions.
 //===----------------------------------------------------------------------===//
@@ -1590,6 +1610,26 @@ sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
   return std::make_pair(std::move(stl), std::move(it));
 }
 
+std::unique_ptr<SparseIterator>
+sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l,
+                                  const SparseIterationSpace &iterSpace) {
+  // assert(iterSpace.getSpaceDim() == 1);
+  std::unique_ptr<SparseIterator> ret;
+  if (!iterSpace.isUnique()) {
+  // We always deduplicate the non-unique level, but we should optimize it away
+    // if possible.
+    ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(),
+                                          iterSpace.getBoundLo(),
+                                          iterSpace.getBoundHi());
+  } else {
+    ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(),
+                                            iterSpace.getBoundLo(),
+                                            iterSpace.getBoundHi());
+  }
+  ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional);
+  return ret;
+}
+
 std::unique_ptr<SparseIterator>
 sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
                                   SparseEmitStrategy strategy) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 17636af2b2f9d..91f363db93f1d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -132,6 +132,10 @@ class SparseIterationSpace {
   Value getBoundLo() const { return bound.first; }
   Value getBoundHi() const { return bound.second; }
 
+  // Extract an iterator to iterate over the sparse iteration space.
+  std::unique_ptr<SparseIterator> extractIterator(OpBuilder &b,
+                                                  Location l) const;
+
 private:
   SmallVector<std::unique_ptr<SparseTensorLevel>> lvls;
   std::pair<Value, Value> bound;
@@ -192,6 +196,13 @@ class SparseIterator {
     crd = nullptr;
   }
 
+  // Reconstructs an iterator directly from the provided ValueRange.
+  static std::unique_ptr<SparseIterator>
+  fromValues(IteratorType dstTp, ValueRange values, unsigned tid);
+
+  // The inverse operation of `fromValues`.
+  SmallVector<Value> toValues() const { llvm_unreachable("Not implemented"); }
+
   //
   // Iterator properties.
   //
@@ -345,12 +356,21 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
                                                          unsigned tid,
                                                          Level lvl);
 
-/// Helper function to create a TensorLevel object from given `tensor`.
+/// Helper function to create a TensorLevel object from given ValueRange.
 std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
                                                          ValueRange buffers,
                                                          unsigned tid, Level l);
-/// Helper function to create a simple SparseIterator object that iterates
-/// over the SparseTensorLevel.
+
+/// Helper function to create a simple SparseIterator object that iterates
+/// over the entire iteration space.
+std::unique_ptr<SparseIterator>
+makeSimpleIterator(OpBuilder &b, Location l,
+                   const SparseIterationSpace &iterSpace);
+
+/// Helper function to create a simple SparseIterator object that iterates
+/// over the sparse tensor level.
+/// TODO: switch to `SparseIterationSpace` (which support N-D iterator) when
+/// feature complete.
 std::unique_ptr<SparseIterator> makeSimpleIterator(
     const SparseTensorLevel &stl,
     SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);
diff --git a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
index 5fcd661bb69b2..77a0e89dc7c81 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_iteration_to_scf.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s --lower-sparse-iteration-to-scf | FileCheck %s
+// RUN: mlir-opt %s --sparse-space-collapse --lower-sparse-iteration-to-scf | FileCheck %s --check-prefix COLLAPSED
 
 #COO = #sparse_tensor.encoding<{
   map = (i, j) -> (
@@ -7,17 +8,44 @@
   )
 }>
 
-// CHECK-LABEL:   func.func @sparse_1D_space(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>) -> !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> {
-// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[LVL_SIZE:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[C0]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
-// CHECK:           %[[POS_MEM:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[CRD_MEM:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
-// CHECK:           %[[POS_LO:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C0]]] : memref<?xindex>
-// CHECK:           %[[POS_HI:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C1]]] : memref<?xindex>
-// CHECK:           %[[ITER_SPACE:.*]] = builtin.unrealized_conversion_cast %[[POS_MEM]], %[[CRD_MEM]], %[[LVL_SIZE]], %[[POS_LO]], %[[POS_HI]]
-func.func @sparse_1D_space(%sp : tensor<?x?xf32, #COO>) -> !sparse_tensor.iter_space<#COO, lvls = 0> {
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
-  return %l1 : !sparse_tensor.iter_space<#COO, lvls = 0>
+// CHECK-LABEL:   @sparse_iteration_to_scf
+//                  // deduplication
+// CHECK:           scf.while {{.*}} {
+// CHECK:           } do {
+// CHECK:           }
+// CHECK:           scf.while {{.*}} {
+// CHECK:           } do {
+//                    // actual computation
+// CHECK:             scf.for {{.*}} {
+// CHECK:               arith.addi
+// CHECK:             }
+//                    // deduplication
+// CHECK:             scf.while {{.*}} {
+// CHECK:             } do {
+// CHECK:             }
+// CHECK:             scf.yield
+// CHECK:           }
+// CHECK:           return
+
+// COLLAPSED-LABEL:   @sparse_iteration_to_scf
+// COLLAPSED:           %[[RET:.*]] = scf.for {{.*}} {
+// COLLAPSED:             %[[VAL:.*]] = arith.addi
+// COLLAPSED:             scf.yield %[[VAL]] : index
+// COLLAPSED:           }
+// COLLAPSED:           return %[[RET]] : index
+func.func @sparse_iteration_to_scf(%sp : tensor<4x8xf32, #COO>) -> index {
+  %i = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
+      : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
+  %r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+    %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+        : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1> -> !sparse_tensor.iter_space<#COO, lvls = 1>
+    %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
+      %k = arith.addi %inner, %c1 : index
+      sparse_tensor.yield %k : index
+    }
+    sparse_tensor.yield %r2 : index
+  }
+  return %r1 : index
 }

>From a5d0ef40d8986fde785b6a42d36df322fb533b6e Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Mon, 17 Jun 2024 17:19:50 +0000
Subject: [PATCH 2/2] address comments

---
 .../SparseTensor/Transforms/SparseIterationToScf.cpp        | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index 440fc3efb4edf..58332034721e4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -38,7 +38,8 @@ static std::optional<LogicalResult>
 convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
   // The actual iterator values (that are updated every iteration).
   auto idxTp = IndexType::get(itTp.getContext());
-  // TODO: This assumes there is no batch dimension in the sparse tensor.
+  // TODO: handle batch dimension.
+  assert(itTp.getEncoding().getBatchLvlRank() == 0);
   if (!itTp.isUnique()) {
     // Segment high for non-unique iterator.
     fields.push_back(idxTp);
@@ -77,7 +78,8 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
   matchAndRewrite(IterateOp op, OpAdaptor adaptor,
                   OneToNPatternRewriter &rewriter) const override {
     if (!op.getCrdUsedLvls().empty())
-      llvm_unreachable("Not implemented.");
+      return rewriter.notifyMatchFailure(
+          op, "non-empty coordinates list not implemented.");
 
     Location loc = op.getLoc();
 



More information about the Mlir-commits mailing list