[Mlir-commits] [mlir] dd33481 - [mlir][sparse] add getPointerType/getIndexType to SparseTensorEncodingAttr.
Peiming Liu
llvmlistbot at llvm.org
Thu Dec 1 14:01:57 PST 2022
Author: Peiming Liu
Date: 2022-12-01T22:01:52Z
New Revision: dd33481f48f264420862d1ee9eae83f2deab7078
URL: https://github.com/llvm/llvm-project/commit/dd33481f48f264420862d1ee9eae83f2deab7078
DIFF: https://github.com/llvm/llvm-project/commit/dd33481f48f264420862d1ee9eae83f2deab7078.diff
LOG: [mlir][sparse] add getPointerType/getIndexType to SparseTensorEncodingAttr.
Add new interfaces to SparseTensorEncodingAttr that construct the pointer/index types based on the pointer/index bitwidths.
Reviewed By: aartbik, wrengr
Differential Revision: https://reviews.llvm.org/D139141
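
For illustration only (not part of the commit), a minimal sketch of the duplicated pattern that the new accessors fold into SparseTensorEncodingAttr; the helper names below are hypothetical:

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Before this commit, each caller re-derived the storage type by hand,
// treating a bitwidth of zero as "use the native index type".
static Type pointerTypeByHand(SparseTensorEncodingAttr enc) {
  unsigned ptrWidth = enc.getPointerBitWidth();
  Type indexType = IndexType::get(enc.getContext());
  return ptrWidth ? IntegerType::get(enc.getContext(), ptrWidth) : indexType;
}

// After this commit, the encoding attribute provides the mapping directly.
static Type pointerTypeViaAccessor(SparseTensorEncodingAttr enc) {
  return enc.getPointerType();
}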
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index e5272a907fc92..5e472d5998d42 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -158,6 +158,14 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
"unsigned":$indexBitWidth
);
+ let extraClassDeclaration = [{
+ /// Returns the type for pointer storage based on pointerBitWidth
+ Type getPointerType() const;
+
+ /// Returns the type for index storage based on indexBitWidth
+ Type getIndexType() const;
+ }];
+
let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 652e0504c5ee1..599de1e5fee3d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -41,6 +41,18 @@ static bool acceptBitWidth(unsigned bitWidth) {
}
}
+Type SparseTensorEncodingAttr::getPointerType() const {
+ unsigned ptrWidth = getPointerBitWidth();
+ Type indexType = IndexType::get(getContext());
+ return ptrWidth ? IntegerType::get(getContext(), ptrWidth) : indexType;
+}
+
+Type SparseTensorEncodingAttr::getIndexType() const {
+ unsigned idxWidth = getIndexBitWidth();
+ Type indexType = IndexType::get(getContext());
+ return idxWidth ? IntegerType::get(getContext(), idxWidth) : indexType;
+}
+
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
if (failed(parser.parseLess()))
return {};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 97f1f952e5bd5..cb4dafbc7b625 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -203,12 +203,10 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
return llvm::None;
// Construct the basic types.
auto *context = type.getContext();
- unsigned idxWidth = enc.getIndexBitWidth();
- unsigned ptrWidth = enc.getPointerBitWidth();
RankedTensorType rType = type.cast<RankedTensorType>();
Type indexType = IndexType::get(context);
- Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType;
- Type ptrType = ptrWidth ? IntegerType::get(context, ptrWidth) : indexType;
+ Type idxType = enc.getIndexType();
+ Type ptrType = enc.getPointerType();
Type eltType = rType.getElementType();
//
// Sparse tensor storage scheme for rank-dimensional tensor is organized
@@ -268,21 +266,20 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
// Append linear x pointers, initialized to zero. Since each compressed
// dimension initially already has a single zero entry, this maintains
// the desired "linear + 1" length property at all times.
- unsigned ptrWidth = getSparseTensorEncoding(rtp).getPointerBitWidth();
- Type indexType = builder.getIndexType();
- Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType;
+ Type ptrType = getSparseTensorEncoding(rtp).getPointerType();
Value ptrZero = constantZero(builder, loc, ptrType);
createPushback(builder, loc, fields, field, ptrZero, linear);
return;
}
if (isSingletonDim(rtp, r)) {
return; // nothing to do
- } // Keep compounding the size, but nothing needs to be initialized
- // at this level. We will eventually reach a compressed level or
- // otherwise the values array for the from-here "all-dense" case.
- assert(isDenseDim(rtp, r));
- Value size = sizeAtStoredDim(builder, loc, rtp, fields, r);
- linear = builder.create<arith::MulIOp>(loc, linear, size);
+ }
+ // Keep compounding the size, but nothing needs to be initialized
+ // at this level. We will eventually reach a compressed level or
+ // otherwise the values array for the from-here "all-dense" case.
+ assert(isDenseDim(rtp, r));
+ Value size = sizeAtStoredDim(builder, loc, rtp, fields, r);
+ linear = builder.create<arith::MulIOp>(loc, linear, size);
}
// Reached values array so prepare for an insertion.
Value valZero = constantZero(builder, loc, rtp.getElementType());
@@ -315,13 +312,10 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
SmallVectorImpl<Value> &fields) {
auto enc = getSparseTensorEncoding(type);
assert(enc);
- // Construct the basic types.
- unsigned idxWidth = enc.getIndexBitWidth();
- unsigned ptrWidth = enc.getPointerBitWidth();
RankedTensorType rtp = type.cast<RankedTensorType>();
Type indexType = builder.getIndexType();
- Type idxType = idxWidth ? builder.getIntegerType(idxWidth) : indexType;
- Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType;
+ Type idxType = enc.getIndexType();
+ Type ptrType = enc.getPointerType();
Type eltType = rtp.getElementType();
auto shape = rtp.getShape();
unsigned rank = shape.size();
@@ -622,9 +616,7 @@ static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
// TODO: avoid cleanup and keep compressed scheme consistent at all times?
//
if (d > 0) {
- unsigned ptrWidth = getSparseTensorEncoding(rtp).getPointerBitWidth();
- Type indexType = builder.getIndexType();
- Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType;
+ Type ptrType = getSparseTensorEncoding(rtp).getPointerType();
Value mz = constantIndex(builder, loc, getMemSizesIndex(field));
Value hi = genLoad(builder, loc, fields[memSizesIdx], mz);
Value zero = constantIndex(builder, loc, 0);
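
For concreteness, a hedged sketch (not part of the commit) of what the accessors yield for an encoding with explicit bitwidths; the textual attribute names below follow the syntax in use around this revision, and the helper name fieldTypesFor is hypothetical:

#include <utility>

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Example encoding (textual form, assumed syntax at this revision):
//   #CSR = #sparse_tensor.encoding<{
//     dimLevelType = [ "dense", "compressed" ],
//     pointerBitWidth = 32,
//     indexBitWidth = 64
//   }>
//
// Mirrors how SparseTensorCodegen.cpp now builds its field types from the
// encoding rather than from the raw bitwidths.
static std::pair<Type, Type> fieldTypesFor(SparseTensorEncodingAttr enc) {
  Type ptrType = enc.getPointerType(); // i32 for the #CSR example above
  Type idxType = enc.getIndexType();   // i64 for the #CSR example above;
                                       // a bitwidth of 0 falls back to 'index'
  return {ptrType, idxType};
}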