[Mlir-commits] [mlir] [mlir][sparse_tensor] Migrate `SparseIterationToScf.cpp` to dialect conversion (PR #121054)

Matthias Springer llvmlistbot at llvm.org
Tue Dec 24 06:07:59 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/121054

>From 8547cba1d26e2699874f4688cbf5ef340a09d71b Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Fri, 20 Dec 2024 18:13:00 +0100
Subject: [PATCH] sparse iteration

---
 .../Transforms/SparseIterationToScf.cpp       | 123 +++++++++++-------
 .../Transforms/SparseTensorPasses.cpp         |  11 +-
 2 files changed, 81 insertions(+), 53 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index e8a40b1e033dd5b..9e9fea76416b9ca 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -7,11 +7,17 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
+/// Assert that the given value range contains a single value and return it.
+static Value getSingleValue(ValueRange values) {
+  assert(values.size() == 1 && "expected single value");
+  return values.front();
+}
+
 static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
                              SmallVectorImpl<Type> &fields) {
   // Position and coordinate buffer in the sparse structure.
@@ -54,14 +60,17 @@ static ValueRange
 genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
                        Value loopCrd,
                        ArrayRef<std::unique_ptr<SparseIterator>> iters,
-                       ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
-  if (subCases.empty())
+                       ArrayRef<Block *> newBlocks, ArrayRef<Block *> oldBlocks,
+                       ArrayRef<Value> userReduc) {
+  if (newBlocks.empty())
     return userReduc;
 
   // The current branch that we are handling.
-  Region *b = subCases.front();
+  Block *newBlock = newBlocks.front();
+  Block *oldBlock = oldBlocks.front();
   Value casePred = constantI1(rewriter, loc, true);
-  I64BitSet caseBits = op.getRegionDefinedSpace(b->getRegionNumber());
+  I64BitSet caseBits =
+      op.getRegionDefinedSpace(newBlock->getParent()->getRegionNumber());
   for (unsigned i : caseBits.bits()) {
     SparseIterator *it = iters[i].get();
     Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
@@ -80,16 +89,20 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
   for (unsigned idx : caseBits.bits())
     llvm::append_range(blockArgs, iters[idx]->getCursor());
 
+  // Map the old block arguments, because the dialect conversion driver does
+  // not immediately perform SSA value replacements. This function is still
+  // seeing the old uses.
   IRMapping mapping;
-  for (auto [from, to] :
-       llvm::zip_equal(b->front().getArguments(), blockArgs)) {
+  for (auto [from, to] : llvm::zip_equal(oldBlock->getArguments(), blockArgs)) {
     mapping.map(from, to);
   }
 
   // Clone the region, we can not erase the region now because the same region
   // might be a subcase for multiple lattice point.
-  rewriter.cloneRegionBefore(*b, ifOp.getThenRegion(),
+  rewriter.cloneRegionBefore(*newBlock->getParent(), ifOp.getThenRegion(),
                              ifOp.getThenRegion().begin(), mapping);
+  // Remove the block arguments; they were already replaced via `mapping`.
+  ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size());
 
   // replace sparse_tensor::YieldOp -> scf::YieldOp
   auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
@@ -101,7 +114,8 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
   // Generates remaining case recursively.
   rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
   ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
-                                          subCases.drop_front(), userReduc);
+                                          newBlocks.drop_front(),
+                                          oldBlocks.drop_front(), userReduc);
   if (!res.empty())
     rewriter.create<scf::YieldOp>(loc, res);
 
@@ -119,15 +133,13 @@ static ValueRange genLoopWithIterator(
   if (it->iteratableByFor()) {
     auto [lo, hi] = it->genForCond(rewriter, loc);
     Value step = constantIndex(rewriter, loc, 1);
-    scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, reduc);
+    scf::ForOp forOp = rewriter.create<scf::ForOp>(
+        loc, lo, hi, step, reduc,
+        [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
+          // Empty builder function to ensure that no terminator is created.
+        });
     {
       OpBuilder::InsertionGuard guard(rewriter);
-      // Erase the implicit yield operation created by ForOp when there is no
-      // yielding values.
-      if (!forOp.getBody()->empty())
-        rewriter.eraseOp(&forOp.getBody()->front());
-      assert(forOp.getBody()->empty());
-
       it->linkNewScope(forOp.getInductionVar());
       rewriter.setInsertionPointToStart(forOp.getBody());
       SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
@@ -178,46 +190,47 @@ namespace {
 
 /// Sparse codegen rule for number of entries operator.
 class ExtractIterSpaceConverter
-    : public OneToNOpConversionPattern<ExtractIterSpaceOp> {
+    : public OpConversionPattern<ExtractIterSpaceOp> {
 public:
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+  matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
 
     // Construct the iteration space.
-    SparseIterationSpace space(loc, rewriter, op.getTensor(), 0,
+    SparseIterationSpace space(loc, rewriter,
+                               getSingleValue(adaptor.getTensor()), 0,
                                op.getLvlRange(), adaptor.getParentIter());
 
     SmallVector<Value> result = space.toValues();
-    rewriter.replaceOp(op, result, resultMapping);
+    rewriter.replaceOpWithMultiple(op, {result});
     return success();
   }
 };
 
 /// Sparse codegen rule for number of entries operator.
-class ExtractValOpConverter : public OneToNOpConversionPattern<ExtractValOp> {
+class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
 public:
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ExtractValOp op, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+  matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value pos = adaptor.getIterator().back();
-    Value valBuf = rewriter.create<ToValuesOp>(loc, op.getTensor());
+    Value valBuf =
+        rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor()));
     rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
     return success();
   }
 };
 
-class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
+class SparseIterateOpConverter : public OpConversionPattern<IterateOp> {
 public:
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(IterateOp op, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+  matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     if (!op.getCrdUsedLvls().empty())
       return rewriter.notifyMatchFailure(
           op, "non-empty coordinates list not implemented.");
@@ -235,14 +248,15 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
       llvm::append_range(ivs, inits);
 
     // Type conversion on iterate op block.
-    OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
+    unsigned numOrigArgs = op.getBody()->getArgumentTypes().size();
+    TypeConverter::SignatureConversion signatureConversion(numOrigArgs);
     if (failed(typeConverter->convertSignatureArgs(
-            op.getBody()->getArgumentTypes(), blockTypeMapping)))
+            op.getBody()->getArgumentTypes(), signatureConversion)))
       return rewriter.notifyMatchFailure(
           op, "failed to convert iterate region argurment types");
-    rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);
 
-    Block *block = op.getBody();
+    Block *block = rewriter.applySignatureConversion(
+        op.getBody(), signatureConversion, getTypeConverter());
     ValueRange ret = genLoopWithIterator(
         rewriter, loc, it.get(), ivs,
         [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
@@ -263,19 +277,17 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
           return result;
         });
 
-    const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
-    rewriter.replaceOp(op, ret, resultMapping);
+    rewriter.replaceOp(op, ret);
     return success();
   }
 };
 
-class SparseCoIterateOpConverter
-    : public OneToNOpConversionPattern<CoIterateOp> {
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+class SparseCoIterateOpConverter : public OpConversionPattern<CoIterateOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CoIterateOp op, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+  matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     assert(op.getSpaceDim() == 1 && "Not implemented");
     Location loc = op.getLoc();
 
@@ -299,18 +311,23 @@ class SparseCoIterateOpConverter
     assert(!needUniv && "Not implemented");
     (void)needUniv;
 
+    SmallVector<Block *> newBlocks;
+    DenseMap<Block *, Block *> newToOldBlockMap;
     for (Region &region : op.getCaseRegions()) {
       // Do a one-shot type conversion on all region blocks, since the same
       // region might be used multiple time.
       Block *block = &region.getBlocks().front();
-      OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
+      TypeConverter::SignatureConversion blockTypeMapping(
+          block->getArgumentTypes().size());
       if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
                                                      blockTypeMapping))) {
         return rewriter.notifyMatchFailure(
             op, "failed to convert coiterate region argurment types");
       }
 
-      rewriter.applySignatureConversion(block, blockTypeMapping);
+      newBlocks.push_back(rewriter.applySignatureConversion(
+          block, blockTypeMapping, getTypeConverter()));
+      newToOldBlockMap[newBlocks.back()] = block;
     }
 
     SmallVector<SparseIterationSpace> spaces;
@@ -343,7 +360,7 @@ class SparseCoIterateOpConverter
 
     // Generates a loop sequence, one loop per case.
     for (auto [r, caseBits] :
-         llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) {
+         llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) {
       assert(caseBits.count() > 0 && "Complement space not implemented");
 
       // Retrives a vector of pointers to the iterators used in the case.
@@ -359,11 +376,17 @@ class SparseCoIterateOpConverter
         // The subcases are never empty, it must contains at least the current
         // region itself.
         // TODO: these cases should be sorted.
-        SmallVector<Region *> subCases = op.getSubCasesOf(r.getRegionNumber());
+        SmallVector<Region *> subCases =
+            op.getSubCasesOf(r->getParent()->getRegionNumber());
+        SmallVector<Block *> newBlocks, oldBlocks;
+        for (Region *r : subCases) {
+          newBlocks.push_back(&r->front());
+          oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]);
+        }
         assert(!subCases.empty());
 
-        ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd,
-                                                iters, subCases, userReduc);
+        ValueRange res = genCoIterateBranchNest(
+            rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc);
 
         SmallVector<Value> nextIterYields(res);
         // 2nd. foward the loop.
@@ -388,7 +411,7 @@ class SparseCoIterateOpConverter
         // This is a simple iteration loop.
         assert(caseBits.count() == 1);
 
-        Block *block = &r.getBlocks().front();
+        Block *block = r;
         ValueRange curResult = genLoopWithIterator(
             rewriter, loc, validIters.front(), userReduc,
             /*bodyBuilder=*/
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 1cac949b68c79dc..153b9b170e5d340 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -172,11 +172,16 @@ struct LowerSparseIterationToSCFPass
     ConversionTarget target(*ctx);
 
     // The actual conversion.
-    target.addIllegalOp<ExtractIterSpaceOp, IterateOp>();
+    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
+                           memref::MemRefDialect, scf::SCFDialect,
+                           sparse_tensor::SparseTensorDialect>();
+    target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
+                        IterateOp>();
+    target.addLegalOp<UnrealizedConversionCastOp>();
     populateLowerSparseIterationToSCFPatterns(converter, patterns);
 
-    if (failed(applyPartialOneToNConversion(getOperation(), converter,
-                                            std::move(patterns))))
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
       signalPassFailure();
   }
 };



More information about the Mlir-commits mailing list