[Mlir-commits] [mlir] [mlir][sparse] implement sparse_tensor.crd_translate operation (PR #69653)
Peiming Liu
llvmlistbot at llvm.org
Thu Oct 19 15:47:24 PDT 2023
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/69653
None
>From 1b505522e6103f6e993df89fc57749a4104d6268 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 19 Oct 2023 22:42:55 +0000
Subject: [PATCH] [mlir][sparse] implement sparse_tensor.crd_translate
operation
---
.../SparseTensor/IR/SparseTensorAttrDefs.td | 5 ++
.../Dialect/SparseTensor/Transforms/Passes.td | 1 +
.../SparseTensor/IR/SparseTensorDialect.cpp | 16 +++++++
.../Transforms/SparseTensorRewriting.cpp | 46 +++++++++++++------
4 files changed, 54 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index b0fbbd747b76604..2dd7f8e961929cf 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -429,6 +429,14 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
std::optional<uint64_t> getStaticLvlSliceSize(::mlir::sparse_tensor::Level lvl) const;
std::optional<uint64_t> getStaticLvlSliceStride(::mlir::sparse_tensor::Level lvl) const;
+ //
+ // Helper function to build IR related to the encoding.
+ //
+ /// Translates `crds` along `dir` by building a `sparse_tensor.crd_translate`
+ /// op. The returned (non-owning) range aliases `crds` itself when there is no
+ /// encoding, and the results of the newly created op otherwise.
+ ::mlir::ValueRange translateCrds(::mlir::OpBuilder &builder, ::mlir::Location loc, ::mlir::ValueRange crds, ::mlir::sparse_tensor::CrdTransDirectionKind dir) const;
+
//
// Printing methods.
//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index 3081f07b7bfe1c6..73ecf5061fa16ca 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -143,6 +143,7 @@ def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp">
}];
let constructor = "mlir::createPostSparsificationRewritePass()";
let dependentDialects = [
+ "affine::AffineDialect",
"arith::ArithDialect",
"bufferization::BufferizationDialect",
"linalg::LinalgDialect",
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index c6e7bfaf47d04d3..f601b8f6bbc283f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -415,6 +415,26 @@ SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
return getStaticDimSliceStride(toOrigDim(*this, lvl));
}
+/// Builds a `sparse_tensor.crd_translate` op mapping `crds` along `dir`.
+/// Without an encoding no translation is needed and `crds` is returned as-is;
+/// otherwise the (non-owning) range over the new op's results is returned.
+ValueRange
+SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
+ ValueRange crds,
+ CrdTransDirectionKind dir) const {
+ if (!getImpl())
+ return crds;
+
+ // dim2lvl produces one coordinate per level and lvl2dim one per dimension,
+ // so the number of results must follow the translation direction.
+ const uint64_t outRank = dir == CrdTransDirectionKind::dim2lvl
+ ? getLvlRank()
+ : getDimRank();
+ SmallVector<Type> retType(outRank, builder.getIndexType());
+ auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
+ return transOp.getOutCrds();
+}
+
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
#define RETURN_ON_FAIL(stmt) \
if (failed(stmt)) { \
@@ -1155,6 +1167,10 @@ LogicalResult CrdTranslateOp::verify() {
LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
+ if (getEncoder().isIdentity()) {
+ results.assign(getInCrds().begin(), getInCrds().end());
+ return success();
+ }
if (getEncoder().isPermutation()) {
AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
? getEncoder().getDimToLvl()
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index e50b14975e83d63..c69c3036e621f64 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -13,6 +13,7 @@
#include "CodegenUtils.h"
#include "LoopEmitter.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -865,6 +866,41 @@ struct TensorLike {
Value val;
};
+/// Lowers `sparse_tensor.crd_translate` into one `affine.apply` per result of
+/// the direction-specific map (dimToLvl for dim2lvl, lvlToDim for lvl2dim).
+struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(CrdTranslateOp op,
+ PatternRewriter &rewriter) const override {
+ // Identity encodings need no translation: forward the inputs unchanged,
+ // exactly as CrdTranslateOp::fold does.
+ if (op.getEncoder().isIdentity()) {
+ rewriter.replaceOp(op, op.getInCrds());
+ return success();
+ }
+ AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
+ ? op.getEncoder().getDimToLvl()
+ : op.getEncoder().getLvlToDim();
+ // Bail out instead of dereferencing a null AffineMap when the encoding
+ // provides no map for the requested direction.
+ if (!map)
+ return rewriter.notifyMatchFailure(op, "missing coordinate translation map");
+ SmallVector<Value> outCrds;
+ outCrds.reserve(map.getNumResults());
+ for (AffineExpr result : map.getResults()) {
+ // TODO: we should probably expand the affine map to IR using our own
+ // rules, since affine.apply assumes signed values, while the coordinates
+ // we provide must always be signless.
+ Value trans = rewriter.create<affine::AffineApplyOp>(
+ op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
+ op.getInCrds());
+ outCrds.push_back(trans);
+ }
+ rewriter.replaceOp(op, outCrds);
+ return success();
+ }
+};
+
struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ConcatenateOp op,
@@ -999,13 +1022,9 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
ValueRange reduc) {
// Enters the loop, update the SSA value for insertion chain.
dstBuf.val = reduc.front();
- const Dimension dimRank = dstStt.getDimRank();
- const Level lvlRank = dstStt.getLvlRank();
- SmallVector<Value> lcvs(lvlRank);
- for (Dimension d = 0; d < dimRank; d++) {
- // FIXME: `toStoredDim` is deprecated
- lcvs[toStoredDim(dstStt.getEncoding(), d)] = dcvs[d];
- }
+
+ ValueRange lcvs = dstStt.getEncoding().translateCrds(
+ builder, loc, dcvs, CrdTransDirectionKind::dim2lvl);
if (!skipZeroCheck) {
Value cond = genIsNonzero(builder, loc, v);
@@ -1101,12 +1120,10 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
Block *srcBlock = op.getBody();
// Remap coordinates.
- SmallVector<Value> args;
- for (Dimension d = 0; d < dimRank; d++) {
- // FIXME: `toStoredDim` is deprecated
- Value dimCrd = lcvs[toStoredDim(enc, d)];
- args.push_back(dimCrd);
- }
+ ValueRange dimCrds =
+ enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
+
+ SmallVector<Value> args(dimCrds.begin(), dimCrds.end());
// Remap value.
args.push_back(val);
// Remap reduction variables.
@@ -1249,7 +1266,8 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
bool enableRT,
bool enableForeach,
bool enableConvert) {
- patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
+ patterns.add<ConcatenateRewriter, CrdTranslateRewriter,
+ ReshapeRewriter<tensor::ExpandShapeOp>,
ReshapeRewriter<tensor::CollapseShapeOp>,
Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
More information about the Mlir-commits
mailing list