[Mlir-commits] [mlir] [mlir][sparse] implement sparse_tensor.crd_translate operation (PR #69653)

Peiming Liu llvmlistbot at llvm.org
Thu Oct 19 15:47:24 PDT 2023


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/69653

None

>From 1b505522e6103f6e993df89fc57749a4104d6268 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 19 Oct 2023 22:42:55 +0000
Subject: [PATCH] [mlir][sparse] implement sparse_tensor.crd_translate
 operation

---
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  5 ++
 .../Dialect/SparseTensor/Transforms/Passes.td |  1 +
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 16 +++++++
 .../Transforms/SparseTensorRewriting.cpp      | 46 +++++++++++++------
 4 files changed, 54 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index b0fbbd747b76604..2dd7f8e961929cf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -429,6 +429,11 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     std::optional<uint64_t> getStaticLvlSliceSize(::mlir::sparse_tensor::Level lvl) const;
     std::optional<uint64_t> getStaticLvlSliceStride(::mlir::sparse_tensor::Level lvl) const;
 
+    //
+    // Helper function to build IR related to the encoding.
+    //
+    ValueRange translateCrds(::mlir::OpBuilder &builder, ::mlir::Location loc, ::mlir::ValueRange crds, ::mlir::sparse_tensor::CrdTransDirectionKind) const;
+
     //
     // Printing methods.
     //
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 3081f07b7bfe1c6..73ecf5061fa16ca 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -143,6 +143,7 @@ def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp">
   }];
   let constructor = "mlir::createPostSparsificationRewritePass()";
   let dependentDialects = [
+    "affine::AffineDialect",
     "arith::ArithDialect",
     "bufferization::BufferizationDialect",
     "linalg::LinalgDialect",
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index c6e7bfaf47d04d3..f601b8f6bbc283f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -415,6 +415,18 @@ SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
   return getStaticDimSliceStride(toOrigDim(*this, lvl));
 }
 
+/// Builds a `sparse_tensor.crd_translate` op that maps `crds` in direction
+/// `dir` under this encoding and returns the op's output coordinates. A null
+/// encoding (no underlying impl) performs no translation: `crds` is returned
+/// unchanged and no IR is created.
+ValueRange
+SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
+                                        ValueRange crds,
+                                        CrdTransDirectionKind dir) const {
+  if (!getImpl())
+    return crds;
+
+  // The result count follows the target space: dim2lvl yields one coordinate
+  // per level, lvl2dim one per dimension. Always using the dimension rank
+  // produces the wrong number of results whenever the two ranks differ.
+  const bool toLvl = dir == CrdTransDirectionKind::dim2lvl;
+  SmallVector<Type> retType(toLvl ? getLvlRank() : getDimRank(),
+                            builder.getIndexType());
+  auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
+  return transOp.getOutCrds();
+}
+
 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
 #define RETURN_ON_FAIL(stmt)                                                   \
   if (failed(stmt)) {                                                          \
@@ -1155,6 +1167,10 @@ LogicalResult CrdTranslateOp::verify() {
 
 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                    SmallVectorImpl<OpFoldResult> &results) {
+  if (getEncoder().isIdentity()) {
+    results.assign(getInCrds().begin(), getInCrds().end());
+    return success();
+  }
   if (getEncoder().isPermutation()) {
     AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                          ? getEncoder().getDimToLvl()
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index e50b14975e83d63..c69c3036e621f64 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -13,6 +13,7 @@
 #include "CodegenUtils.h"
 #include "LoopEmitter.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -865,6 +866,28 @@ struct TensorLike {
   Value val;
 };
 
+/// Rewrites a `CrdTranslateOp` into one `affine.apply` per result expression
+/// of the encoder's dim-to-lvl or lvl-to-dim map (chosen by the op's
+/// direction), then replaces the op with the computed coordinates.
+struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(CrdTranslateOp op,
+                                PatternRewriter &rewriter) const override {
+    // Select the map that matches the requested translation direction.
+    AffineMap transMap;
+    if (op.getDirection() == CrdTransDirectionKind::dim2lvl)
+      transMap = op.getEncoder().getDimToLvl();
+    else
+      transMap = op.getEncoder().getLvlToDim();
+
+    SmallVector<Value> translated;
+    for (AffineExpr expr : transMap.getResults()) {
+      // TODO: we should probably expand the affine map to IR using our own
+      // rules, since affine.apply assumes signed values, while the
+      // coordinates we provide must always be signless.
+      //
+      // Each result expression becomes its own single-result map over the
+      // same inputs, applied to all incoming coordinates.
+      AffineMap singleResult =
+          AffineMap::get(transMap.getNumDims(), 0, expr);
+      Value crd = rewriter.create<affine::AffineApplyOp>(
+          op.getLoc(), singleResult, op.getInCrds());
+      translated.push_back(crd);
+    }
+    rewriter.replaceOp(op, translated);
+    return success();
+  }
+};
+
 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(ConcatenateOp op,
@@ -999,13 +1022,9 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
             ValueRange reduc) {
           // Enters the loop, update the SSA value for insertion chain.
           dstBuf.val = reduc.front();
-          const Dimension dimRank = dstStt.getDimRank();
-          const Level lvlRank = dstStt.getLvlRank();
-          SmallVector<Value> lcvs(lvlRank);
-          for (Dimension d = 0; d < dimRank; d++) {
-            // FIXME: `toStoredDim` is deprecated
-            lcvs[toStoredDim(dstStt.getEncoding(), d)] = dcvs[d];
-          }
+
+          ValueRange lcvs = dstStt.getEncoding().translateCrds(
+              builder, loc, dcvs, CrdTransDirectionKind::dim2lvl);
 
           if (!skipZeroCheck) {
             Value cond = genIsNonzero(builder, loc, v);
@@ -1101,12 +1120,10 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
     Block *srcBlock = op.getBody();
 
     // Remap coordinates.
-    SmallVector<Value> args;
-    for (Dimension d = 0; d < dimRank; d++) {
-      // FIXME: `toStoredDim` is deprecated
-      Value dimCrd = lcvs[toStoredDim(enc, d)];
-      args.push_back(dimCrd);
-    }
+    ValueRange dimCrds =
+        enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
+
+    SmallVector<Value> args(dimCrds.begin(), dimCrds.end());
     // Remap value.
     args.push_back(val);
     // Remap reduction variables.
@@ -1249,7 +1266,8 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
                                                bool enableRT,
                                                bool enableForeach,
                                                bool enableConvert) {
-  patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
+  patterns.add<ConcatenateRewriter, CrdTranslateRewriter,
+               ReshapeRewriter<tensor::ExpandShapeOp>,
                ReshapeRewriter<tensor::CollapseShapeOp>,
                Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
                Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,



More information about the Mlir-commits mailing list