[Mlir-commits] [mlir] [mlir][sparse] implement sparse_tensor.crd_translate operation (PR #69653)

Peiming Liu llvmlistbot at llvm.org
Thu Oct 19 16:09:10 PDT 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/69653

>From 1b505522e6103f6e993df89fc57749a4104d6268 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 19 Oct 2023 22:42:55 +0000
Subject: [PATCH 1/3] [mlir][sparse] implement sparse_tensor.crd_translate
 operation

---
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  5 ++
 .../Dialect/SparseTensor/Transforms/Passes.td |  1 +
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 16 +++++++
 .../Transforms/SparseTensorRewriting.cpp      | 46 +++++++++++++------
 4 files changed, 54 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index b0fbbd747b76604..2dd7f8e961929cf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -429,6 +429,11 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     std::optional<uint64_t> getStaticLvlSliceSize(::mlir::sparse_tensor::Level lvl) const;
     std::optional<uint64_t> getStaticLvlSliceStride(::mlir::sparse_tensor::Level lvl) const;
 
+    //
+    // Helper function to build IR related to the encoding.
+    //
+    ValueRange translateCrds(::mlir::OpBuilder &builder, ::mlir::Location loc, ::mlir::ValueRange crds, ::mlir::sparse_tensor::CrdTransDirectionKind) const;
+
     //
     // Printing methods.
     //
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 3081f07b7bfe1c6..73ecf5061fa16ca 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -143,6 +143,7 @@ def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp">
   }];
   let constructor = "mlir::createPostSparsificationRewritePass()";
   let dependentDialects = [
+    "affine::AffineDialect",
     "arith::ArithDialect",
     "bufferization::BufferizationDialect",
     "linalg::LinalgDialect",
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index c6e7bfaf47d04d3..f601b8f6bbc283f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -415,6 +415,18 @@ SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
   return getStaticDimSliceStride(toOrigDim(*this, lvl));
 }
 
+ValueRange
+SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
+                                        ValueRange crds,
+                                        CrdTransDirectionKind dir) const {
+  if (!getImpl())
+    return crds;
+
+  SmallVector<Type> retType(getDimRank(), builder.getIndexType());
+  auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
+  return transOp.getOutCrds();
+}
+
 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
 #define RETURN_ON_FAIL(stmt)                                                   \
   if (failed(stmt)) {                                                          \
@@ -1155,6 +1167,10 @@ LogicalResult CrdTranslateOp::verify() {
 
 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                    SmallVectorImpl<OpFoldResult> &results) {
+  if (getEncoder().isIdentity()) {
+    results.assign(getInCrds().begin(), getInCrds().end());
+    return success();
+  }
   if (getEncoder().isPermutation()) {
     AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                          ? getEncoder().getDimToLvl()
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index e50b14975e83d63..c69c3036e621f64 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -13,6 +13,7 @@
 #include "CodegenUtils.h"
 #include "LoopEmitter.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -865,6 +866,28 @@ struct TensorLike {
   Value val;
 };
 
+struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(CrdTranslateOp op,
+                                PatternRewriter &rewriter) const override {
+    AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
+                        ? op.getEncoder().getDimToLvl()
+                        : op.getEncoder().getLvlToDim();
+    SmallVector<Value> outCrds;
+    for (AffineExpr result : map.getResults()) {
+      // TODO: we should probably expand the affine map to IR using our own
+      // rules, since affine.apply assumes signed values, while the coordinates
+      // we provide must always be signless.
+      Value trans = rewriter.create<affine::AffineApplyOp>(
+          op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
+          op.getInCrds());
+      outCrds.push_back(trans);
+    }
+    rewriter.replaceOp(op, outCrds);
+    return success();
+  }
+};
+
 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(ConcatenateOp op,
@@ -999,13 +1022,9 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
             ValueRange reduc) {
           // Enters the loop, update the SSA value for insertion chain.
           dstBuf.val = reduc.front();
-          const Dimension dimRank = dstStt.getDimRank();
-          const Level lvlRank = dstStt.getLvlRank();
-          SmallVector<Value> lcvs(lvlRank);
-          for (Dimension d = 0; d < dimRank; d++) {
-            // FIXME: `toStoredDim` is deprecated
-            lcvs[toStoredDim(dstStt.getEncoding(), d)] = dcvs[d];
-          }
+
+          ValueRange lcvs = dstStt.getEncoding().translateCrds(
+              builder, loc, dcvs, CrdTransDirectionKind::dim2lvl);
 
           if (!skipZeroCheck) {
             Value cond = genIsNonzero(builder, loc, v);
@@ -1101,12 +1120,10 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
     Block *srcBlock = op.getBody();
 
     // Remap coordinates.
-    SmallVector<Value> args;
-    for (Dimension d = 0; d < dimRank; d++) {
-      // FIXME: `toStoredDim` is deprecated
-      Value dimCrd = lcvs[toStoredDim(enc, d)];
-      args.push_back(dimCrd);
-    }
+    ValueRange dimCrds =
+        enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
+
+    SmallVector<Value> args(dimCrds.begin(), dimCrds.end());
     // Remap value.
     args.push_back(val);
     // Remap reduction variables.
@@ -1249,7 +1266,8 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
                                                bool enableRT,
                                                bool enableForeach,
                                                bool enableConvert) {
-  patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
+  patterns.add<ConcatenateRewriter, CrdTranslateRewriter,
+               ReshapeRewriter<tensor::ExpandShapeOp>,
                ReshapeRewriter<tensor::CollapseShapeOp>,
                Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
                Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,

>From aff94cdb4d23b86f5bf1388d917064e661a7fb3b Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 19 Oct 2023 23:05:58 +0000
Subject: [PATCH 2/3] small fixes

---
 .../Dialect/SparseTensor/IR/SparseTensorType.h   | 16 +++++++++++++---
 .../Transforms/SparseTensorRewriting.cpp         |  5 ++---
 2 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index b5dbc67781ce004..c3e967fdcd90fc0 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -176,6 +176,10 @@ class SparseTensorType {
   /// Returns the encoding (or the null-attribute for dense-tensors).
   SparseTensorEncodingAttr getEncoding() const { return enc; }
 
+  //
+  // SparseTensorEncodingAttr delegators
+  //
+
   /// Returns true for tensors which have an encoding, and false for
   /// those which do not.  Therefore tensors with an all-dense encoding
   /// return true.
@@ -189,14 +193,20 @@ class SparseTensorType {
   /// (This is always true for dense-tensors.)
   bool isAllOrdered() const { return enc.isAllOrdered(); }
 
-  /// Returns true if the dimToLvl mapping is the identity.
-  /// (This is always true for dense-tensors.)
-  bool isIdentity() const { return !dimToLvl; }
+  /// Translates between level / dimension coordinate space.
+  ValueRange translateCrds(OpBuilder &builder, Location loc, ValueRange crds,
+                           CrdTransDirectionKind dir) const {
+    return enc.translateCrds(builder, loc, crds, dir);
+  }
 
   /// Returns true if the dimToLvl mapping is a permutation.
   /// (This is always true for dense-tensors.)
   bool isPermutation() const { return enc.isPermutation(); }
 
+  /// Returns true if the dimToLvl mapping is the identity.
+  /// (This is always true for dense-tensors.)
+  bool isIdentity() const { return enc.isIdentity(); }
+
   /// Returns the dimToLvl mapping (or the null-map for the identity).
   /// If you intend to compare the results of this method for equality,
   /// see `hasSameDimToLvl` instead.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index c69c3036e621f64..c168f10dadbbfb5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1023,7 +1023,7 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
           // Enters the loop, update the SSA value for insertion chain.
           dstBuf.val = reduc.front();
 
-          ValueRange lcvs = dstStt.getEncoding().translateCrds(
+          ValueRange lcvs = dstStt.translateCrds(
               builder, loc, dcvs, CrdTransDirectionKind::dim2lvl);
 
           if (!skipZeroCheck) {
@@ -1120,10 +1120,9 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
     Block *srcBlock = op.getBody();
 
     // Remap coordinates.
-    ValueRange dimCrds =
+    SmallVector<Value> args =
         enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
 
-    SmallVector<Value> args(dimCrds.begin(), dimCrds.end());
     // Remap value.
     args.push_back(val);
     // Remap reduction variables.

>From 0d680840db6c26df84317a204ddca453ff782173 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 19 Oct 2023 23:08:57 +0000
Subject: [PATCH 3/3] fix bugs

---
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index f601b8f6bbc283f..fe2eaebbfa45d5a 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -422,7 +422,9 @@ SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
   if (!getImpl())
     return crds;
 
-  SmallVector<Type> retType(getDimRank(), builder.getIndexType());
+  SmallVector<Type> retType(
+      dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
+      builder.getIndexType());
   auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
   return transOp.getOutCrds();
 }



More information about the Mlir-commits mailing list