[Mlir-commits] [mlir] [mlir][sparse] Implement rewriters to reinterpret maps on foreach (PR #70868)

Peiming Liu llvmlistbot at llvm.org
Wed Nov 1 11:53:50 PDT 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/70868

>From 236ce30cc54febeb14b98ba25f0085c6b18a755b Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 31 Oct 2023 00:04:22 +0000
Subject: [PATCH 1/7] [mlir][sparse] add helper class to implement common
 rewriters to re/demap sparse tensors.

---
 .../Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp    | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index a822effbb2ab78c..b301943c8732dab 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -142,7 +142,7 @@ struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
 };
 
 //===----------------------------------------------------------------------===//
-// Rewriting rules for operations other than linalg generic ops.
+// Reinterpret Map Rewriters for operations other than linalg.generics
 //===----------------------------------------------------------------------===//
 
 // CRTP to help implementing a rewriter that demaps all its inputs and remaps

>From f00f63901eaa508857e46ba1b4aa7dc44845927b Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 31 Oct 2023 22:34:58 +0000
Subject: [PATCH 2/7] [mlir][sparse] Implement rewriters to reinterpret maps on
 foreach operands.

---
 .../Transforms/SparseReinterpretMap.cpp       | 190 ++++++++++++------
 .../Transforms/SparseTensorRewriting.cpp      |  27 ++-
 .../Dialect/Affine/decompose-affine-ops.mlir  |  10 +-
 .../SparseTensor/sparse_reinterpret_map.mlir  |  48 +++--
 4 files changed, 200 insertions(+), 75 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index b301943c8732dab..2d50205f1e96b8b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -25,8 +25,8 @@ namespace {
 //===----------------------------------------------------------------------===//
 
 // Translates a "simple" map according to an identity lvl-map.
-static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
-                              AffineMap map) {
+AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
+                       AffineMap map) {
   unsigned lvlRank = stt.getLvlRank();
   AffineMap lvl2dim = stt.getLvlToDim();
   assert(lvl2dim.getNumInputs() == lvlRank);
@@ -39,18 +39,37 @@ static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
 }
 
 // Generates a "de"mapping reinterpretation of the map.
-static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
-                      Value val) {
+Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val) {
   return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
                                           val);
 }
 
 // Generates a "re"mapping reinterpretation of the map.
-static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
-                      Value val) {
+Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val) {
   return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
 }
 
+SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
+                                   ValueRange outs) {
+  SmallVector<Value> ret(outs);
+  assert(outs.size() == types.size());
+  for (auto [r, t] : llvm::zip(ret, types))
+    if (r.getType() != t)
+      r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
+  return ret;
+}
+
+/// Whether the operation has any sparse tensor with non-identity dim2lvl maps.
+bool hasNonIdentityOperandsOrResults(Operation *op) {
+  auto hasNonIdentityMap = [](Value v) {
+    auto stt = tryGetSparseTensorType(v);
+    return stt && !stt->isIdentity();
+  };
+
+  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
+         llvm::any_of(op->getResults(), hasNonIdentityMap);
+}
+
 // Generates a clone of the given linalg generic operation, but with
 // remapped arguments, index maps, and iteration types.
 //
@@ -141,10 +160,6 @@ struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
   }
 };
 
-//===----------------------------------------------------------------------===//
-// Reinterpret Map Rewriters for operations other than linalg.generics
-//===----------------------------------------------------------------------===//
-
 // CRTP to help implementing a rewriter that demaps all its inputs and remaps
 // all its outputs.
 template <typename SubClass, typename SourceOp>
@@ -154,9 +169,6 @@ struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
 
   LogicalResult matchAndRewrite(SourceOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!static_cast<const SubClass *>(this)->matchOp(op))
-      return failure();
-
     Location loc = op.getLoc();
     // Demaps non-trivial inputs.
     SmallVector<Value> deMappedIns(op->getOperands());
@@ -166,61 +178,125 @@ struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
 
     // CRTP call.
     OpAdaptor adaptor(deMappedIns);
-    ValueRange outs =
-        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
-    assert(outs.size() == op->getResults().size());
-
-    // Remap  outputs.
-    SmallVector<Value> reMappedOuts(outs);
-    for (auto [r, a] : llvm::zip(reMappedOuts, op->getResults()))
-      if (r.getType() != a.getType())
-        r = rewriter.create<ReinterpretMapOp>(loc, a.getType(), r);
-
-    rewriter.replaceOp(op, reMappedOuts);
-    return success();
+    return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
+                                                          rewriter);
   }
 };
 
-struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(CrdTranslateOp op,
-                                PatternRewriter &rewriter) const override {
-    AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
-                        ? op.getEncoder().getDimToLvl()
-                        : op.getEncoder().getLvlToDim();
-
-    SmallVector<Value> outCrds;
-    for (AffineExpr result : map.getResults()) {
-      // TODO: we should probably expand the affine map to IR using our own
-      // rules, since affine.apply assume signed value, while the cooridinates
-      // we provided must always be signless.
-      Value trans = rewriter.create<affine::AffineApplyOp>(
-          op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
-          op.getInCrds());
-      outCrds.push_back(trans);
-    }
-    rewriter.replaceOp(op, outCrds);
-    return success();
-  }
-};
+//===----------------------------------------------------------------------===//
+// Reinterpret Map Rewriters for operations other than linalg.generics
+//===----------------------------------------------------------------------===//
 
-struct TensorInsertRewriter
-    : public DemapInsRemapOutsRewriter<TensorInsertRewriter, tensor::InsertOp> {
+struct TensorInsertDemapper
+    : public DemapInsRemapOutsRewriter<TensorInsertDemapper, tensor::InsertOp> {
   using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
+  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
+                          PatternRewriter &rewriter) const {
+    if (!hasAnySparseResult(op))
+      return failure();
 
-  bool matchOp(tensor::InsertOp op) const {
-    return op.getResult().getType().getEncoding() != nullptr;
-  }
-
-  ValueRange rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
-                       PatternRewriter &rewriter) const {
     Location loc = op.getLoc();
     auto stt = getSparseTensorType(op.getResult());
     ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                           CrdTransDirectionKind::dim2lvl);
     Operation *insertOp = rewriter.create<sparse_tensor::InsertOp>(
         loc, op.getScalar(), adaptor.getDest(), lvlCrd);
-    return insertOp->getResults();
+
+    SmallVector<Value> outs(insertOp->getResults());
+    remapValueRange(rewriter, op->getResultTypes(), outs);
+    rewriter.replaceOp(op, outs);
+    return success();
+  }
+};
+
+struct ForeachOpDemapper
+    : public DemapInsRemapOutsRewriter<ForeachOpDemapper, ForeachOp> {
+  using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
+  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
+                          PatternRewriter &rewriter) const {
+    // Only handles operations with sparse input/output.
+    if (!hasNonIdentityOperandsOrResults(op))
+      return failure();
+
+    // TODO: demap constant as well.
+    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
+      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
+        return failure();
+
+    Location loc = op.getLoc();
+    // Cache the type information since we update the foreach op in-place.
+    auto srcStt = getSparseTensorType(op.getTensor());
+    SmallVector<Type> prevRetTps(op.getResultTypes());
+
+    rewriter.startRootUpdate(op);
+    op.getTensorMutable().assign(adaptor.getTensor());
+    op.getInitArgsMutable().assign(adaptor.getInitArgs());
+    // Update results' types.
+    for (auto r : op.getResults())
+      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
+        r.setType(stt->getDemappedType());
+
+    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
+    // Update the foreach body.
+    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
+    blockArgTps.push_back(srcStt.getElementType());
+    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
+                       adaptor.getInitArgs().getTypes().end());
+    Block *body = op.getBody();
+    // Block Args: [dimCrd, val, initArgs]
+    unsigned preArgNum = body->getNumArguments();
+    for (Type t : blockArgTps)
+      body->addArgument(t, loc);
+
+    // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
+    rewriter.setInsertionPointToStart(body);
+    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);
+
+    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
+                                              CrdTransDirectionKind::lvl2dim);
+    rewriter.replaceAllUsesWith(
+        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
+    body->eraseArguments(0, srcStt.getDimRank());
+    // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
+    unsigned numInitArgs = op.getInitArgs().size();
+    rewriter.replaceAllUsesWith(body->getArgument(0),
+                                body->getArgument(lvlRank + numInitArgs + 1));
+    body->eraseArgument(0);
+    // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
+    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
+    SmallVector<Value> dstArgs(body->getArguments().take_back(numInitArgs));
+    // Remap back before replacement;
+    for (auto [s, d] : llvm::zip(srcArgs, dstArgs))
+      if (s.getType() != d.getType())
+        d = rewriter.create<ReinterpretMapOp>(loc, s.getType(), d);
+    rewriter.replaceAllUsesWith(srcArgs, dstArgs);
+    body->eraseArguments(0, numInitArgs);
+    // Block Args: [lvlCrds, DemappedArgs] and we are done.
+
+    // Update yield operations.
+    if (numInitArgs != 0) {
+      rewriter.setInsertionPointToEnd(body);
+      auto yield = llvm::cast<YieldOp>(body->getTerminator());
+      if (auto stt = tryGetSparseTensorType(yield.getResult());
+          stt && !stt->isIdentity()) {
+        Value y = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(),
+                                                    yield.getResult());
+        rewriter.create<YieldOp>(loc, y);
+        rewriter.eraseOp(yield);
+      }
+    }
+    rewriter.finalizeRootUpdate(op);
+
+    rewriter.setInsertionPointAfter(op);
+    SmallVector<Value> outs(op.getResults());
+    remapValueRange(rewriter, prevRetTps, outs);
+
+    // Replace all the uses of the foreach results, expect the use in
+    // reinterpret_map used to remap the output.
+    for (auto [from, to] : llvm::zip(op.getResults(), outs))
+      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());
+
+    return success();
   }
 };
 
@@ -234,7 +310,7 @@ void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
   }
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kExceptGeneric) {
-    patterns.add<CrdTranslateRewriter, TensorInsertRewriter>(
+    patterns.add<TensorInsertDemapper, ForeachOpDemapper>(
         patterns.getContext());
   }
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 02796bc9a7e7df6..c00f19916e49fbe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1063,6 +1063,29 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
   }
 };
 
+struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(CrdTranslateOp op,
+                                PatternRewriter &rewriter) const override {
+    AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
+                        ? op.getEncoder().getDimToLvl()
+                        : op.getEncoder().getLvlToDim();
+
+    SmallVector<Value> outCrds;
+    for (AffineExpr result : map.getResults()) {
+      // TODO: we should probably expand the affine map to IR using our own
+      // rules, since affine.apply assumes signed values, while the coordinates
+      // we provide must always be signless.
+      Value trans = rewriter.create<affine::AffineApplyOp>(
+          op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
+          op.getInCrds());
+      outCrds.push_back(trans);
+    }
+    rewriter.replaceOp(op, outCrds);
+    return success();
+  }
+};
+
 /// Sparse rewriting rule for the foreach operator.
 struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
 public:
@@ -1284,5 +1307,7 @@ void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
 }
 
 void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
-  patterns.add<ForeachRewriter>(patterns.getContext());
+  // Run CrdTranslateRewriter later in the pipeline so that operations can be
+  // folded before they are lowered to affine.apply.
+  patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Affine/decompose-affine-ops.mlir b/mlir/test/Dialect/Affine/decompose-affine-ops.mlir
index 6acdc436fe6774a..654b73a0bb6df09 100644
--- a/mlir/test/Dialect/Affine/decompose-affine-ops.mlir
+++ b/mlir/test/Dialect/Affine/decompose-affine-ops.mlir
@@ -108,12 +108,12 @@ func.func @larger_test(%0: index, %1: index, %2: index, %lb: index, %ub: index,
         // The hoisted part is %b.
         %a = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 16 + s2 * 32 + s3 * 32 + s0 floordiv 4)>()[%0, %1, %2, %i]
 
-        // Gets completely hoisted 
+        // Gets completely hoisted
         %b = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
 
-        // Gets completely hoisted 
+        // Gets completely hoisted
         %c = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
- 
+
         // 32 * %j + %c remains here, the rest is hoisted.
         // CHECK-DAG: %[[R10:.*]] = affine.apply #[[$times32]]()[%[[j]]]
         // CHECK-DAG: %[[d:.*]] = affine.apply #[[$add]]()[%[[c]], %[[R10]]]
@@ -134,7 +134,7 @@ func.func @larger_test(%0: index, %1: index, %2: index, %lb: index, %ub: index,
 
           // CHECK-NEXT: %[[g:.*]] = affine.apply #[[$add]]()[%[[b]], %[[idk]]]
           %g = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s2 * 16 + s3 * 32 + s1 floordiv 4)>()[%k, %0, %1, %2]
-          
+
           // CHECK-NEXT: "some_side_effecting_consumer"(%[[a]]) : (index) -> ()
           "some_side_effecting_consumer"(%a) : (index) -> ()
           // CHECK-NEXT: "some_side_effecting_consumer"(%[[b]]) : (index) -> ()
@@ -151,6 +151,6 @@ func.func @larger_test(%0: index, %1: index, %2: index, %lb: index, %ub: index,
           "some_side_effecting_consumer"(%g) : (index) -> ()
         }
     }
-  }   
+  }
   return
 }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
index 149c0bc46e25118..be3ab37e9cbd182 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
@@ -1,15 +1,4 @@
-// RUN: mlir-opt %s -split-input-file  --sparse-reinterpret-map | FileCheck %s
-
-#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
-
-// CHECK-LABEL: func @sparse_nop(
-//  CHECK-SAME: %[[A0:.*]]: tensor<?xf64, #sparse_tensor.encoding<{{{.*}}}>>)
-//       CHECK: return %[[A0]]
-func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
-  return %arg0 : tensor<?xf64, #SparseVector>
-}
-
-// -----
+// RUN: mlir-opt %s -split-input-file --sparse-reinterpret-map | FileCheck %s
 
 #trait_mul = {
   indexing_maps = [
@@ -55,3 +44,38 @@ func.func @mul(%arg0: tensor<32x32xf32>,
   return %0 : tensor<32x32xf32, #BSR>
 }
 
+// -----
+
+#BSR = #sparse_tensor.encoding<{
+   map = ( i, j ) ->
+      ( i floordiv 2 : dense,
+        j floordiv 2 : compressed,
+        i mod 2      : dense,
+        j mod 2      : dense
+      )
+}>
+
+// CHECK-LABEL:   func.func @sparse_foreach_reinterpret_map(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<2x4xf64
+// CHECK:           %[[VAL_1:.*]] = bufferization.alloc_tensor() : tensor<2x4xf64
+// CHECK:           %[[VAL_2:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]] : tensor<2x4xf64
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.reinterpret_map %[[VAL_1]] : tensor<2x4xf64
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.foreach in %[[VAL_2]] init(%[[VAL_3]])
+// CHECK:           ^bb0(%[[VAL_5:.*]]: index, %[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: f64, %[[VAL_10:.*]]: tensor<1x2x2x2xf64
+// CHECK:             %[[VAL_11:.*]] = sparse_tensor.insert %[[VAL_9]] into %[[VAL_10]]{{\[}}%[[VAL_5]], %[[VAL_6]], %[[VAL_7]], %[[VAL_8]]] : tensor<1x2x2x2xf64
+// CHECK:             sparse_tensor.yield %[[VAL_11]] : tensor<1x2x2x2xf64
+// CHECK:           }
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.reinterpret_map %[[VAL_4]] : tensor<1x2x2x2xf64
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.load %[[VAL_12]] hasInserts : tensor<2x4xf64
+// CHECK:           return %[[VAL_13]] : tensor<2x4xf64
+// CHECK:         }
+func.func @sparse_foreach_reinterpret_map(%6 : tensor<2x4xf64, #BSR>) -> tensor<2x4xf64, #BSR> {
+  %7 = bufferization.alloc_tensor() : tensor<2x4xf64, #BSR>
+  %8 = sparse_tensor.foreach in %6 init(%7) : tensor<2x4xf64, #BSR>, tensor<2x4xf64, #BSR> -> tensor<2x4xf64, #BSR> do {
+    ^bb0(%arg0: index, %arg1: index, %arg2: f64, %arg3: tensor<2x4xf64, #BSR>):
+      %inserted = tensor.insert %arg2 into %arg3[%arg0, %arg1] : tensor<2x4xf64, #BSR>
+      sparse_tensor.yield %inserted : tensor<2x4xf64, #BSR>
+  }
+  %9 = sparse_tensor.load %8 hasInserts : tensor<2x4xf64, #BSR>
+  return %9 : tensor<2x4xf64, #BSR>
+}

>From 862684df2ee601785c3c0e3aa462832d4b1380f3 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 31 Oct 2023 22:45:09 +0000
Subject: [PATCH 3/7] revert unintended change

---
 mlir/test/Dialect/Affine/decompose-affine-ops.mlir | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/test/Dialect/Affine/decompose-affine-ops.mlir b/mlir/test/Dialect/Affine/decompose-affine-ops.mlir
index 654b73a0bb6df09..6acdc436fe6774a 100644
--- a/mlir/test/Dialect/Affine/decompose-affine-ops.mlir
+++ b/mlir/test/Dialect/Affine/decompose-affine-ops.mlir
@@ -108,12 +108,12 @@ func.func @larger_test(%0: index, %1: index, %2: index, %lb: index, %ub: index,
         // The hoisted part is %b.
         %a = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 16 + s2 * 32 + s3 * 32 + s0 floordiv 4)>()[%0, %1, %2, %i]
 
-        // Gets completely hoisted
+        // Gets completely hoisted 
         %b = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
 
-        // Gets completely hoisted
+        // Gets completely hoisted 
         %c = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
-
+ 
         // 32 * %j + %c remains here, the rest is hoisted.
         // CHECK-DAG: %[[R10:.*]] = affine.apply #[[$times32]]()[%[[j]]]
         // CHECK-DAG: %[[d:.*]] = affine.apply #[[$add]]()[%[[c]], %[[R10]]]
@@ -134,7 +134,7 @@ func.func @larger_test(%0: index, %1: index, %2: index, %lb: index, %ub: index,
 
           // CHECK-NEXT: %[[g:.*]] = affine.apply #[[$add]]()[%[[b]], %[[idk]]]
           %g = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s2 * 16 + s3 * 32 + s1 floordiv 4)>()[%k, %0, %1, %2]
-
+          
           // CHECK-NEXT: "some_side_effecting_consumer"(%[[a]]) : (index) -> ()
           "some_side_effecting_consumer"(%a) : (index) -> ()
           // CHECK-NEXT: "some_side_effecting_consumer"(%[[b]]) : (index) -> ()
@@ -151,6 +151,6 @@ func.func @larger_test(%0: index, %1: index, %2: index, %lb: index, %ub: index,
           "some_side_effecting_consumer"(%g) : (index) -> ()
         }
     }
-  }
+  }   
   return
 }

>From e7b3c2190e4da0a3498c0cb0287cff2bb6b7d516 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 1 Nov 2023 18:38:59 +0000
Subject: [PATCH 4/7] fix errors caused by merging

---
 .../Transforms/SparseReinterpretMap.cpp       | 43 ++++++++++---------
 1 file changed, 22 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 2d50205f1e96b8b..71d93fb3050aeea 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -18,15 +18,13 @@
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
-namespace {
-
 //===----------------------------------------------------------------------===//
-// Helper methods.
+// File Local Helper methods.
 //===----------------------------------------------------------------------===//
 
 // Translates a "simple" map according to an identity lvl-map.
-AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
-                       AffineMap map) {
+static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
+                              AffineMap map) {
   unsigned lvlRank = stt.getLvlRank();
   AffineMap lvl2dim = stt.getLvlToDim();
   assert(lvl2dim.getNumInputs() == lvlRank);
@@ -39,18 +37,20 @@ AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
 }
 
 // Generates a "de"mapping reinterpretation of the map.
-Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val) {
+static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
+                      Value val) {
   return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
                                           val);
 }
 
 // Generates a "re"mapping reinterpretation of the map.
-Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val) {
+static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
+                      Value val) {
   return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
 }
 
-SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
-                                   ValueRange outs) {
+static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
+                                          ValueRange outs) {
   SmallVector<Value> ret(outs);
   assert(outs.size() == types.size());
   for (auto [r, t] : llvm::zip(ret, types))
@@ -60,7 +60,7 @@ SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
 }
 
 /// Whether the operation has any sparse tensor with non-identity dim2lvl maps.
-bool hasNonIdentityOperandsOrResults(Operation *op) {
+static bool hasNonIdentityOperandsOrResults(Operation *op) {
   auto hasNonIdentityMap = [](Value v) {
     auto stt = tryGetSparseTensorType(v);
     return stt && !stt->isIdentity();
@@ -105,6 +105,8 @@ static linalg::GenericOp genGenericLinalg(PatternRewriter &rewriter,
   return newOp;
 }
 
+namespace {
+
 //===----------------------------------------------------------------------===//
 // Rewriting rules for linalg generic ops.
 //===----------------------------------------------------------------------===//
@@ -163,7 +165,7 @@ struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
 // CRTP to help implementing a rewriter that demaps all its inputs and remaps
 // all its outputs.
 template <typename SubClass, typename SourceOp>
-struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
+struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
   using OpRewritePattern<SourceOp>::OpRewritePattern;
   using OpAdaptor = typename SourceOp::Adaptor;
 
@@ -188,8 +190,8 @@ struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
 //===----------------------------------------------------------------------===//
 
 struct TensorInsertDemapper
-    : public DemapInsRemapOutsRewriter<TensorInsertDemapper, tensor::InsertOp> {
-  using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
+    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
+  using DemapInsRewriter::DemapInsRewriter;
   LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                           PatternRewriter &rewriter) const {
     if (!hasAnySparseResult(op))
@@ -199,19 +201,18 @@ struct TensorInsertDemapper
     auto stt = getSparseTensorType(op.getResult());
     ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                           CrdTransDirectionKind::dim2lvl);
-    Operation *insertOp = rewriter.create<sparse_tensor::InsertOp>(
+    auto insertOp = rewriter.create<sparse_tensor::InsertOp>(
         loc, op.getScalar(), adaptor.getDest(), lvlCrd);
 
-    SmallVector<Value> outs(insertOp->getResults());
-    remapValueRange(rewriter, op->getResultTypes(), outs);
-    rewriter.replaceOp(op, outs);
+    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
+    rewriter.replaceOp(op, out);
     return success();
   }
 };
 
 struct ForeachOpDemapper
-    : public DemapInsRemapOutsRewriter<ForeachOpDemapper, ForeachOp> {
-  using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
+    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
+  using DemapInsRewriter::DemapInsRewriter;
   LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                           PatternRewriter &rewriter) const {
     // Only handles operations with sparse input/output.
@@ -288,8 +289,8 @@ struct ForeachOpDemapper
     rewriter.finalizeRootUpdate(op);
 
     rewriter.setInsertionPointAfter(op);
-    SmallVector<Value> outs(op.getResults());
-    remapValueRange(rewriter, prevRetTps, outs);
+    SmallVector<Value> outs =
+        remapValueRange(rewriter, prevRetTps, op.getResults());
 
     // Replace all the uses of the foreach results, expect the use in
     // reinterpret_map used to remap the output.

>From c8b915c64a0faf61ad368f7d959dbec0da291b0e Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 1 Nov 2023 18:40:55 +0000
Subject: [PATCH 5/7] address comments.

---
 .../Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp   | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 71d93fb3050aeea..fbbb73e434dd22c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -215,7 +215,8 @@ struct ForeachOpDemapper
   using DemapInsRewriter::DemapInsRewriter;
   LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                           PatternRewriter &rewriter) const {
-    // Only handles operations with sparse input/output.
+    // Only handle operations that have sparse inputs/outputs with non-identity
+    // dim2lvl maps.
     if (!hasNonIdentityOperandsOrResults(op))
       return failure();
 

>From 4e8e7a69359bdb8e4548bb1440b57c6b9d260809 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 1 Nov 2023 18:47:38 +0000
Subject: [PATCH 6/7] use utils available after merging

---
 .../Transforms/SparseReinterpretMap.cpp            | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index fbbb73e434dd22c..a2a3bfbf5720322 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -266,12 +266,11 @@ struct ForeachOpDemapper
     body->eraseArgument(0);
     // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
     ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
-    SmallVector<Value> dstArgs(body->getArguments().take_back(numInitArgs));
-    // Remap back before replacement;
-    for (auto [s, d] : llvm::zip(srcArgs, dstArgs))
-      if (s.getType() != d.getType())
-        d = rewriter.create<ReinterpretMapOp>(loc, s.getType(), d);
-    rewriter.replaceAllUsesWith(srcArgs, dstArgs);
+    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
+    // Remap back before replacement.
+    SmallVector<Value> reMappedArgs =
+        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
+    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
     body->eraseArguments(0, numInitArgs);
     // Block Args: [lvlCrds, DemappedArgs] and we are done.
 
@@ -281,8 +280,7 @@ struct ForeachOpDemapper
       auto yield = llvm::cast<YieldOp>(body->getTerminator());
       if (auto stt = tryGetSparseTensorType(yield.getResult());
           stt && !stt->isIdentity()) {
-        Value y = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(),
-                                                    yield.getResult());
+        Value y = genDemap(rewriter, stt->getEncoding(), yield.getResult());
         rewriter.create<YieldOp>(loc, y);
         rewriter.eraseOp(yield);
       }

>From ce6d3f741c6102c4cfd3b5cef2f2c6c793827c71 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 1 Nov 2023 18:53:32 +0000
Subject: [PATCH 7/7] move section banner above DemapInsRewriter and update its comment

---
 .../SparseTensor/Transforms/SparseReinterpretMap.cpp  | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index a2a3bfbf5720322..d14df6db8ee6b3f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -162,8 +162,11 @@ struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
   }
 };
 
-// CRTP to help implementing a rewriter that demaps all its inputs and remaps
-// all its outputs.
+//===----------------------------------------------------------------------===//
+// Reinterpret Map Rewriters for operations other than linalg.generics
+//===----------------------------------------------------------------------===//
+
+// CRTP to help implement a rewriter that demaps all its inputs.
 template <typename SubClass, typename SourceOp>
 struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
   using OpRewritePattern<SourceOp>::OpRewritePattern;
@@ -185,10 +188,6 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
   }
 };
 
-//===----------------------------------------------------------------------===//
-// Reinterpret Map Rewriters for operations other than linalg.generics
-//===----------------------------------------------------------------------===//
-
 struct TensorInsertDemapper
     : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
   using DemapInsRewriter::DemapInsRewriter;



More information about the Mlir-commits mailing list