[Mlir-commits] [mlir] d7d7ffe - [mlir][sparse] Adding wrappers for constantOverheadTypeEncoding

wren romano llvmlistbot at llvm.org
Tue Nov 23 18:30:13 PST 2021


Author: wren romano
Date: 2021-11-23T18:30:06-08:00
New Revision: d7d7ffe254d53cf0860126ab4c3f5db18c927892

URL: https://github.com/llvm/llvm-project/commit/d7d7ffe254d53cf0860126ab4c3f5db18c927892
DIFF: https://github.com/llvm/llvm-project/commit/d7d7ffe254d53cf0860126ab4c3f5db18c927892.diff

LOG: [mlir][sparse] Adding wrappers for constantOverheadTypeEncoding

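Minor code cleanup: factor the pointer/index overhead-type encoding calls into `constantPointerTypeEncoding` and `constantIndexTypeEncoding` wrappers.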

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D114392

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 3633ff02d83fa..88c0561521d7a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -85,6 +85,22 @@ static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter,
   return constantI32(rewriter, loc, static_cast<uint32_t>(sec));
 }
 
+/// Generates a constant of the internal type encoding for the pointer
+/// overhead storage of `enc` (attribute handle is passed by value).
+static Value constantPointerTypeEncoding(ConversionPatternRewriter &rewriter,
+                                         Location loc,
+                                         SparseTensorEncodingAttr enc) {
+  return constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth());
+}
+
+/// Generates a constant of the internal type encoding for the index
+/// overhead storage of `enc` (attribute handle is passed by value).
+static Value constantIndexTypeEncoding(ConversionPatternRewriter &rewriter,
+                                       Location loc,
+                                       SparseTensorEncodingAttr enc) {
+  return constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth());
+}
+
 /// Generates a constant of the internal type encoding for primary storage.
 static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter,
                                          Location loc, Type tp) {
@@ -277,10 +293,8 @@ static void newParams(ConversionPatternRewriter &rewriter,
   params.push_back(genBuffer(rewriter, loc, rev));
   // Secondary and primary types encoding.
   ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
-  params.push_back(
-      constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth()));
-  params.push_back(
-      constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth()));
+  params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
+  params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
   params.push_back(
       constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
   // User action and pointer.
@@ -598,10 +612,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
           encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
       newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src);
       Value coo = genNewCall(rewriter, op, params);
-      params[3] = constantOverheadTypeEncoding(rewriter, loc,
-                                               encDst.getPointerBitWidth());
-      params[4] = constantOverheadTypeEncoding(rewriter, loc,
-                                               encDst.getIndexBitWidth());
+      params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
+      params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
       params[6] = constantAction(rewriter, loc, Action::kFromCOO);
       params[7] = coo;
       rewriter.replaceOp(op, genNewCall(rewriter, op, params));


        


More information about the Mlir-commits mailing list