[Mlir-commits] [mlir] c948922 - [mlir][sparse] Factoring out type-based function-name suffixes

wren romano llvmlistbot at llvm.org
Tue Jan 4 16:18:01 PST 2022


Author: wren romano
Date: 2022-01-04T16:17:55-08:00
New Revision: c9489225678106c21cb8584c08f6003ba3987a5d

URL: https://github.com/llvm/llvm-project/commit/c9489225678106c21cb8584c08f6003ba3987a5d
DIFF: https://github.com/llvm/llvm-project/commit/c9489225678106c21cb8584c08f6003ba3987a5d.diff

LOG: [mlir][sparse] Factoring out type-based function-name suffixes

Depends On D115010

This changes a couple of places that previously did `return failure();` so that they now hit `llvm_unreachable()` instead. This should be safe, however, because `Transforms/Sparsification.cpp` should already be doing the necessary type checks to ensure that those cases are in fact unreachable.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D115012

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 0d45ff15e899..ea9be3bddb54 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -34,6 +34,14 @@ OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
   llvm_unreachable("Unsupported overhead bitwidth");
 }
 
+OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
+  if (tp.isIndex())
+    return OverheadType::kIndex;
+  if (auto intTp = tp.dyn_cast<IntegerType>())
+    return overheadTypeEncoding(intTp.getWidth());
+  llvm_unreachable("Unknown overhead type");
+}
+
 Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
   switch (ot) {
   case OverheadType::kIndex:
@@ -61,6 +69,26 @@ Type mlir::sparse_tensor::getIndexOverheadType(
   return getOverheadType(builder, overheadTypeEncoding(enc.getIndexBitWidth()));
 }
 
+StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
+  switch (ot) {
+  case OverheadType::kIndex:
+    return "";
+  case OverheadType::kU64:
+    return "64";
+  case OverheadType::kU32:
+    return "32";
+  case OverheadType::kU16:
+    return "16";
+  case OverheadType::kU8:
+    return "8";
+  }
+  llvm_unreachable("Unknown OverheadType");
+}
+
+StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
+  return overheadTypeFunctionSuffix(overheadTypeEncoding(tp));
+}
+
 PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
   if (elemTp.isF64())
     return PrimaryType::kF64;
@@ -77,6 +105,28 @@ PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
   llvm_unreachable("Unknown primary type");
 }
 
+StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
+  switch (pt) {
+  case PrimaryType::kF64:
+    return "F64";
+  case PrimaryType::kF32:
+    return "F32";
+  case PrimaryType::kI64:
+    return "I64";
+  case PrimaryType::kI32:
+    return "I32";
+  case PrimaryType::kI16:
+    return "I16";
+  case PrimaryType::kI8:
+    return "I8";
+  }
+  llvm_unreachable("Unknown PrimaryType");
+}
+
+StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
+  return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
+}
+
 DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding(
     SparseTensorEncodingAttr::DimLevelType dlt) {
   switch (dlt) {

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index fd539fe997cf..9286cca808aa 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -32,6 +32,9 @@ namespace sparse_tensor {
 /// Converts an overhead storage bitwidth to its internal type-encoding.
 OverheadType overheadTypeEncoding(unsigned width);
 
+/// Converts an overhead storage type to its internal type-encoding.
+OverheadType overheadTypeEncoding(Type tp);
+
 /// Converts the internal type-encoding for overhead storage to an mlir::Type.
 Type getOverheadType(Builder &builder, OverheadType ot);
 
@@ -43,9 +46,21 @@ Type getPointerOverheadType(Builder &builder,
 Type getIndexOverheadType(Builder &builder,
                           const SparseTensorEncodingAttr &enc);
 
+/// Convert OverheadType to its function-name suffix.
+StringRef overheadTypeFunctionSuffix(OverheadType ot);
+
+/// Converts an overhead storage type to its function-name suffix.
+StringRef overheadTypeFunctionSuffix(Type overheadTp);
+
 /// Converts a primary storage type to its internal type-encoding.
 PrimaryType primaryTypeEncoding(Type elemTp);
 
+/// Convert PrimaryType to its function-name suffix.
+StringRef primaryTypeFunctionSuffix(PrimaryType pt);
+
+/// Converts a primary storage type to its function-name suffix.
+StringRef primaryTypeFunctionSuffix(Type elemTp);
+
 /// Converts the IR's dimension level type to its internal type-encoding.
 DimLevelType dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt);
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 3c6817274b83..a28f9ac70b31 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -260,21 +260,7 @@ static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
 static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                           Type eltType, Value ptr, Value val, Value ind,
                           Value perm) {
-  StringRef name;
-  if (eltType.isF64())
-    name = "addEltF64";
-  else if (eltType.isF32())
-    name = "addEltF32";
-  else if (eltType.isInteger(64))
-    name = "addEltI64";
-  else if (eltType.isInteger(32))
-    name = "addEltI32";
-  else if (eltType.isInteger(16))
-    name = "addEltI16";
-  else if (eltType.isInteger(8))
-    name = "addEltI8";
-  else
-    llvm_unreachable("Unknown element type");
+  SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
   SmallVector<Value, 4> params{ptr, val, ind, perm};
   Type pTp = getOpaquePointerType(rewriter);
   createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On);
@@ -287,21 +273,7 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
 static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
                             Value iter, Value ind, Value elemPtr) {
   Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
-  StringRef name;
-  if (elemTp.isF64())
-    name = "getNextF64";
-  else if (elemTp.isF32())
-    name = "getNextF32";
-  else if (elemTp.isInteger(64))
-    name = "getNextI64";
-  else if (elemTp.isInteger(32))
-    name = "getNextI32";
-  else if (elemTp.isInteger(16))
-    name = "getNextI16";
-  else if (elemTp.isInteger(8))
-    name = "getNextI8";
-  else
-    llvm_unreachable("Unknown element type");
+  SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
   SmallVector<Value, 3> params{iter, ind, elemPtr};
   Type i1 = rewriter.getI1Type();
   return createFuncCall(rewriter, op, name, i1, params, EmitCInterface::On)
@@ -668,20 +640,8 @@ class SparseTensorToPointersConverter
   matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type resType = op.getType();
-    Type eltType = resType.cast<ShapedType>().getElementType();
-    StringRef name;
-    if (eltType.isIndex())
-      name = "sparsePointers";
-    else if (eltType.isInteger(64))
-      name = "sparsePointers64";
-    else if (eltType.isInteger(32))
-      name = "sparsePointers32";
-    else if (eltType.isInteger(16))
-      name = "sparsePointers16";
-    else if (eltType.isInteger(8))
-      name = "sparsePointers8";
-    else
-      return failure();
+    Type ptrType = resType.cast<ShapedType>().getElementType();
+    SmallString<16> name{"sparsePointers", overheadTypeFunctionSuffix(ptrType)};
     replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                           EmitCInterface::On);
     return success();
@@ -696,20 +656,8 @@ class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
   matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type resType = op.getType();
-    Type eltType = resType.cast<ShapedType>().getElementType();
-    StringRef name;
-    if (eltType.isIndex())
-      name = "sparseIndices";
-    else if (eltType.isInteger(64))
-      name = "sparseIndices64";
-    else if (eltType.isInteger(32))
-      name = "sparseIndices32";
-    else if (eltType.isInteger(16))
-      name = "sparseIndices16";
-    else if (eltType.isInteger(8))
-      name = "sparseIndices8";
-    else
-      return failure();
+    Type indType = resType.cast<ShapedType>().getElementType();
+    SmallString<15> name{"sparseIndices", overheadTypeFunctionSuffix(indType)};
     replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                           EmitCInterface::On);
     return success();
@@ -725,21 +673,7 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Type resType = op.getType();
     Type eltType = resType.cast<ShapedType>().getElementType();
-    StringRef name;
-    if (eltType.isF64())
-      name = "sparseValuesF64";
-    else if (eltType.isF32())
-      name = "sparseValuesF32";
-    else if (eltType.isInteger(64))
-      name = "sparseValuesI64";
-    else if (eltType.isInteger(32))
-      name = "sparseValuesI32";
-    else if (eltType.isInteger(16))
-      name = "sparseValuesI16";
-    else if (eltType.isInteger(8))
-      name = "sparseValuesI8";
-    else
-      return failure();
+    SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
     replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
                           EmitCInterface::On);
     return success();
@@ -772,23 +706,8 @@ class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
   LogicalResult
   matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type srcType = op.tensor().getType();
-    Type eltType = srcType.cast<ShapedType>().getElementType();
-    StringRef name;
-    if (eltType.isF64())
-      name = "lexInsertF64";
-    else if (eltType.isF32())
-      name = "lexInsertF32";
-    else if (eltType.isInteger(64))
-      name = "lexInsertI64";
-    else if (eltType.isInteger(32))
-      name = "lexInsertI32";
-    else if (eltType.isInteger(16))
-      name = "lexInsertI16";
-    else if (eltType.isInteger(8))
-      name = "lexInsertI8";
-    else
-      llvm_unreachable("Unknown element type");
+    Type elemTp = op.tensor().getType().cast<ShapedType>().getElementType();
+    SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
     TypeRange noTp;
     replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                           EmitCInterface::On);
@@ -843,23 +762,8 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
     // all-zero/false by only iterating over the set elements, so the
     // complexity remains proportional to the sparsity of the expanded
     // access pattern.
-    Type srcType = op.tensor().getType();
-    Type eltType = srcType.cast<ShapedType>().getElementType();
-    StringRef name;
-    if (eltType.isF64())
-      name = "expInsertF64";
-    else if (eltType.isF32())
-      name = "expInsertF32";
-    else if (eltType.isInteger(64))
-      name = "expInsertI64";
-    else if (eltType.isInteger(32))
-      name = "expInsertI32";
-    else if (eltType.isInteger(16))
-      name = "expInsertI16";
-    else if (eltType.isInteger(8))
-      name = "expInsertI8";
-    else
-      return failure();
+    Type elemTp = op.tensor().getType().cast<ShapedType>().getElementType();
+    SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
     TypeRange noTp;
     replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
                           EmitCInterface::On);


        


More information about the Mlir-commits mailing list