[Mlir-commits] [mlir] 255c3f1 - [mlir][sparse] factoring out getRankedTensorType helper function
wren romano
llvmlistbot at llvm.org
Fri Jan 20 19:36:08 PST 2023
Author: wren romano
Date: 2023-01-20T19:36:01-08:00
New Revision: 255c3f11592bf320770108dc700aed47e57419f7
URL: https://github.com/llvm/llvm-project/commit/255c3f11592bf320770108dc700aed47e57419f7
DIFF: https://github.com/llvm/llvm-project/commit/255c3f11592bf320770108dc700aed47e57419f7.diff
LOG: [mlir][sparse] factoring out getRankedTensorType helper function
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D142074
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index f47d3046f6bae..a73d6275e09c4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -565,7 +565,7 @@ Value sparse_tensor::reshapeValuesToLevels(
Value sparse_tensor::genToPointers(OpBuilder &builder, Location loc,
Value tensor, uint64_t d) {
- RankedTensorType srcTp = tensor.getType().cast<RankedTensorType>();
+ RankedTensorType srcTp = getRankedTensorType(tensor);
SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
Type ptrTp = get1DMemRefType(getPointerOverheadType(builder, encSrc),
/*withLayout=*/false);
@@ -575,7 +575,7 @@ Value sparse_tensor::genToPointers(OpBuilder &builder, Location loc,
Value sparse_tensor::genToIndices(OpBuilder &builder, Location loc,
Value tensor, uint64_t d, uint64_t cooStart) {
- RankedTensorType srcTp = tensor.getType().cast<RankedTensorType>();
+ RankedTensorType srcTp = getRankedTensorType(tensor);
SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
Type indTp = get1DMemRefType(getIndexOverheadType(builder, encSrc),
/*withLayout=*/d >= cooStart);
@@ -585,7 +585,7 @@ Value sparse_tensor::genToIndices(OpBuilder &builder, Location loc,
Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
Value tensor) {
- RankedTensorType srcTp = tensor.getType().cast<RankedTensorType>();
+ RankedTensorType srcTp = getRankedTensorType(tensor);
Type valTp = get1DMemRefType(srcTp.getElementType(),
/*withLayout=*/false);
return builder.create<ToValuesOp>(loc, valTp, tensor);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index b07991ef5f64e..8d8b0f8c0ab26 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -78,6 +78,15 @@ StringRef primaryTypeFunctionSuffix(Type elemTp);
// Misc code generators and utilities.
//===----------------------------------------------------------------------===//
+/// Returns the type of `t` cast to `RankedTensorType`. `T` may be any
+/// handle that exposes `getType()`, such as a `Value`, an `OpResult`, or a
+/// single-result op; like any `cast`, this asserts (in builds with
+/// assertions enabled) that the type really is a `RankedTensorType`.
+template <typename T>
+inline RankedTensorType getRankedTensorType(T t) {
+ return t.getType().template cast<RankedTensorType>();
+}
+
/// Generates a 1-valued attribute of the given type. This supports
/// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
/// for unsupported types we raise `llvm_unreachable` rather than
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 41a4c0599c62c..88981fccaf403 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -82,7 +82,7 @@ void LoopEmitter::initialize(ValueRange tensors, StringAttr loopTag,
// a scalar or 0-dimension tensors
if (isZeroRankedTensorOrScalar(t.getType()))
continue;
- auto rtp = t.getType().cast<RankedTensorType>();
+ auto rtp = getRankedTensorType(t);
auto rank = static_cast<size_t>(rtp.getRank());
auto enc = getSparseTensorEncoding(rtp);
// We always treat sparse output tensor as dense so that we always iterate
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 4a1a0c9258610..2ce29e59029d9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -756,8 +756,7 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
return failure();
Location loc = op->getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
- RankedTensorType srcType =
- op.getTensor().getType().cast<RankedTensorType>();
+ auto srcType = getRankedTensorType(op.getTensor());
Type eltType = srcType.getElementType();
Type boolType = rewriter.getIntegerType(1);
Type idxType = rewriter.getIndexType();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index c0b2caa3690e6..22ec4791066f3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -268,7 +268,7 @@ struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
!isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
!isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
return failure();
- auto outputType = op.getResult(0).getType().cast<RankedTensorType>();
+ auto outputType = getRankedTensorType(op.getResult(0));
// Yielding zero on newly allocated (all-zero) sparse tensors can be
// optimized out directly (regardless of dynamic or static size).
if (getSparseTensorEncoding(outputType)) {
@@ -405,8 +405,8 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value srcTensor = op.getSrc();
- auto srcTp = srcTensor.getType().template cast<RankedTensorType>();
- auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
+ auto srcTp = getRankedTensorType(srcTensor);
+ auto dstTp = getRankedTensorType(op.getResult());
SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
if (!encDst || !encSrc) {
@@ -483,8 +483,7 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
return failure();
}
if (encSrc) {
- RankedTensorType rtp =
- op.getSrc().getType().template cast<RankedTensorType>();
+ auto rtp = getRankedTensorType(op.getSrc());
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
@@ -492,8 +491,7 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
return success();
}
if (encDst) {
- RankedTensorType rtp =
- op.getResult().getType().template cast<RankedTensorType>();
+ auto rtp = getRankedTensorType(op.getResult());
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
@@ -511,7 +509,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
LogicalResult matchAndRewrite(ConcatenateOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- auto dstTp = op.getType().cast<RankedTensorType>();
+ auto dstTp = getRankedTensorType(op);
uint64_t conDim = op.getDimension().getZExtValue();
SmallVector<Value> sizes;
concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
@@ -547,7 +545,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
// CSC matrices along column).
if (!allDense && conDim == 0 && encDst.hasIdDimOrdering()) {
for (auto i : op.getInputs()) {
- auto rtp = i.getType().cast<RankedTensorType>();
+ auto rtp = getRankedTensorType(i);
auto srcEnc = getSparseTensorEncoding(rtp);
if (isAllDimOrdered(rtp) && (!srcEnc || srcEnc.hasIdDimOrdering())) {
allOrdered = true;
@@ -623,7 +621,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
// Accumulates the offset. Note that only static-shaped inputs are allowed
// by concatenate op verifier, which saves us from computing the offset
// dynamically.
- int64_t d = input.getType().cast<RankedTensorType>().getShape()[conDim];
+ int64_t d = getRankedTensorType(input).getShape()[conDim];
assert(!ShapedType::isDynamic(d));
offset = rewriter.create<arith::AddIOp>(loc, offset,
constantIndex(rewriter, loc, d));
@@ -699,7 +697,7 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value src = op.getSource();
- RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
+ auto dstTp = getRankedTensorType(op);
SmallVector<Value> sizes;
sizesFromSrc(rewriter, sizes, loc, src);
SmallVector<Value> dynSizes;
@@ -769,9 +767,9 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
LogicalResult sparse2DenseRewrite(ConvertOp op,
PatternRewriter &rewriter) const {
Location loc = op->getLoc();
- RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
+ RankedTensorType dstTp = getRankedTensorType(op);
Value src = op.getSource();
- RankedTensorType srcTp = src.getType().cast<RankedTensorType>();
+ RankedTensorType srcTp = getRankedTensorType(src);
SmallVector<Value> sizes;
sizesForTensor(rewriter, sizes, loc, srcTp, src);
@@ -808,8 +806,8 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
PatternRewriter &rewriter) const {
Location loc = op->getLoc();
Value src = op.getSource();
- RankedTensorType srcTp = src.getType().cast<RankedTensorType>();
- RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
+ RankedTensorType srcTp = getRankedTensorType(src);
+ RankedTensorType dstTp = getRankedTensorType(op);
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
int64_t rank = dstTp.getRank();
@@ -928,7 +926,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
auto loc = op.getLoc();
Value input = op.getTensor();
SmallVector<Value> reduc = op.getInitArgs();
- auto rtp = input.getType().cast<RankedTensorType>();
+ auto rtp = getRankedTensorType(input);
int64_t rank = rtp.getRank();
// Special-case: for each over a sparse constant uses its own rewriting
@@ -1015,7 +1013,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
LogicalResult matchAndRewrite(NewOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
+ auto dstTp = getRankedTensorType(op.getResult());
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
if (!encDst)
return failure();
@@ -1138,7 +1136,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
// Allocate a temporary buffer for storing dimension sizes and indices.
- auto srcTp = src.getType().template cast<RankedTensorType>();
+ auto srcTp = getRankedTensorType(src);
uint64_t rank = srcTp.getRank();
Type indexTp = rewriter.getIndexType();
Value dimSizes = genAlloca(rewriter, loc, rank, indexTp);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 719b1c68ed40d..c31f20e4582fa 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1589,7 +1589,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
// TODO: investigate fusing the conversion with computation,
// especially if it is a direct yield!
//
- auto srcTp = tval.getType().cast<RankedTensorType>();
+ auto srcTp = getRankedTensorType(tval);
auto dstEnc = SparseTensorEncodingAttr::get(
getContext(), srcEnc.getDimLevelType(),
permute(env, env.op().getMatchingIndexingMap(t)), // new order
More information about the Mlir-commits mailing list