[Mlir-commits] [mlir] d7d7ffe - [mlir][sparse] Adding wrappers for constantOverheadTypeEncoding
wren romano
llvmlistbot at llvm.org
Tue Nov 23 18:30:13 PST 2021
Author: wren romano
Date: 2021-11-23T18:30:06-08:00
New Revision: d7d7ffe254d53cf0860126ab4c3f5db18c927892
URL: https://github.com/llvm/llvm-project/commit/d7d7ffe254d53cf0860126ab4c3f5db18c927892
DIFF: https://github.com/llvm/llvm-project/commit/d7d7ffe254d53cf0860126ab4c3f5db18c927892.diff
LOG: [mlir][sparse] Adding wrappers for constantOverheadTypeEncoding
Minor code cleanup
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D114392
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 3633ff02d83fa..88c0561521d7a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -85,6 +85,22 @@ static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter,
return constantI32(rewriter, loc, static_cast<uint32_t>(sec));
}
+/// Generates a constant of the internal type encoding for pointer overhead
+/// storage, taken from the pointer bit width of the (by-value) attribute.
+static Value constantPointerTypeEncoding(ConversionPatternRewriter &rewriter,
+                                         Location loc,
+                                         SparseTensorEncodingAttr enc) {
+  return constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth());
+}
+
+/// Generates a constant of the internal type encoding for index overhead
+/// storage, taken from the index bit width of the (by-value) attribute.
+static Value constantIndexTypeEncoding(ConversionPatternRewriter &rewriter,
+                                       Location loc,
+                                       SparseTensorEncodingAttr enc) {
+  return constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth());
+}
+
/// Generates a constant of the internal type encoding for primary storage.
static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter,
Location loc, Type tp) {
@@ -277,10 +293,8 @@ static void newParams(ConversionPatternRewriter &rewriter,
params.push_back(genBuffer(rewriter, loc, rev));
// Secondary and primary types encoding.
ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
- params.push_back(
- constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth()));
- params.push_back(
- constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth()));
+ params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
+ params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
params.push_back(
constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
// User action and pointer.
@@ -598,10 +612,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src);
Value coo = genNewCall(rewriter, op, params);
- params[3] = constantOverheadTypeEncoding(rewriter, loc,
- encDst.getPointerBitWidth());
- params[4] = constantOverheadTypeEncoding(rewriter, loc,
- encDst.getIndexBitWidth());
+ params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
+ params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
params[7] = coo;
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
More information about the Mlir-commits mailing list