[Mlir-commits] [mlir] c948922 - [mlir][sparse] Factoring out type-based function-name suffixes
wren romano
llvmlistbot at llvm.org
Tue Jan 4 16:18:01 PST 2022
Author: wren romano
Date: 2022-01-04T16:17:55-08:00
New Revision: c9489225678106c21cb8584c08f6003ba3987a5d
URL: https://github.com/llvm/llvm-project/commit/c9489225678106c21cb8584c08f6003ba3987a5d
DIFF: https://github.com/llvm/llvm-project/commit/c9489225678106c21cb8584c08f6003ba3987a5d.diff
LOG: [mlir][sparse] Factoring out type-based function-name suffixes
Depends On D115010
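This changes several places that used to `return failure();` (the `ToPointersOp`, `ToIndicesOp`, `ToValuesOp`, and `CompressOp` conversion patterns) so that an unsupported element type now reaches `llvm_unreachable()` instead. This is expected to be safe because `Transforms/Sparsification.cpp` is supposed to perform the type checks that guarantee those cases are in fact unreachable.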
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D115012
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 0d45ff15e899..ea9be3bddb54 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -34,6 +34,14 @@ OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
llvm_unreachable("Unsupported overhead bitwidth");
}
+OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
+ if (tp.isIndex())
+ return OverheadType::kIndex;
+ if (auto intTp = tp.dyn_cast<IntegerType>())
+ return overheadTypeEncoding(intTp.getWidth());
+ llvm_unreachable("Unknown overhead type");
+}
+
Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
switch (ot) {
case OverheadType::kIndex:
@@ -61,6 +69,26 @@ Type mlir::sparse_tensor::getIndexOverheadType(
return getOverheadType(builder, overheadTypeEncoding(enc.getIndexBitWidth()));
}
+StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
+ switch (ot) {
+ case OverheadType::kIndex:
+ return "";
+ case OverheadType::kU64:
+ return "64";
+ case OverheadType::kU32:
+ return "32";
+ case OverheadType::kU16:
+ return "16";
+ case OverheadType::kU8:
+ return "8";
+ }
+ llvm_unreachable("Unknown OverheadType");
+}
+
+StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
+ return overheadTypeFunctionSuffix(overheadTypeEncoding(tp));
+}
+
PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
if (elemTp.isF64())
return PrimaryType::kF64;
@@ -77,6 +105,28 @@ PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
llvm_unreachable("Unknown primary type");
}
+StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
+ switch (pt) {
+ case PrimaryType::kF64:
+ return "F64";
+ case PrimaryType::kF32:
+ return "F32";
+ case PrimaryType::kI64:
+ return "I64";
+ case PrimaryType::kI32:
+ return "I32";
+ case PrimaryType::kI16:
+ return "I16";
+ case PrimaryType::kI8:
+ return "I8";
+ }
+ llvm_unreachable("Unknown PrimaryType");
+}
+
+StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
+ return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
+}
+
DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding(
SparseTensorEncodingAttr::DimLevelType dlt) {
switch (dlt) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index fd539fe997cf..9286cca808aa 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -32,6 +32,9 @@ namespace sparse_tensor {
/// Converts an overhead storage bitwidth to its internal type-encoding.
OverheadType overheadTypeEncoding(unsigned width);
+/// Converts an overhead storage type to its internal type-encoding.
+OverheadType overheadTypeEncoding(Type tp);
+
/// Converts the internal type-encoding for overhead storage to an mlir::Type.
Type getOverheadType(Builder &builder, OverheadType ot);
@@ -43,9 +46,21 @@ Type getPointerOverheadType(Builder &builder,
Type getIndexOverheadType(Builder &builder,
const SparseTensorEncodingAttr &enc);
+/// Converts an OverheadType to its function-name suffix.
+StringRef overheadTypeFunctionSuffix(OverheadType ot);
+
+/// Converts an overhead storage type to its function-name suffix.
+StringRef overheadTypeFunctionSuffix(Type overheadTp);
+
/// Converts a primary storage type to its internal type-encoding.
PrimaryType primaryTypeEncoding(Type elemTp);
+/// Converts a PrimaryType to its function-name suffix.
+StringRef primaryTypeFunctionSuffix(PrimaryType pt);
+
+/// Converts a primary storage type to its function-name suffix.
+StringRef primaryTypeFunctionSuffix(Type elemTp);
+
/// Converts the IR's dimension level type to its internal type-encoding.
DimLevelType dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 3c6817274b83..a28f9ac70b31 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -260,21 +260,7 @@ static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
Type eltType, Value ptr, Value val, Value ind,
Value perm) {
- StringRef name;
- if (eltType.isF64())
- name = "addEltF64";
- else if (eltType.isF32())
- name = "addEltF32";
- else if (eltType.isInteger(64))
- name = "addEltI64";
- else if (eltType.isInteger(32))
- name = "addEltI32";
- else if (eltType.isInteger(16))
- name = "addEltI16";
- else if (eltType.isInteger(8))
- name = "addEltI8";
- else
- llvm_unreachable("Unknown element type");
+ SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
SmallVector<Value, 4> params{ptr, val, ind, perm};
Type pTp = getOpaquePointerType(rewriter);
createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On);
@@ -287,21 +273,7 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
Value iter, Value ind, Value elemPtr) {
Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
- StringRef name;
- if (elemTp.isF64())
- name = "getNextF64";
- else if (elemTp.isF32())
- name = "getNextF32";
- else if (elemTp.isInteger(64))
- name = "getNextI64";
- else if (elemTp.isInteger(32))
- name = "getNextI32";
- else if (elemTp.isInteger(16))
- name = "getNextI16";
- else if (elemTp.isInteger(8))
- name = "getNextI8";
- else
- llvm_unreachable("Unknown element type");
+ SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
SmallVector<Value, 3> params{iter, ind, elemPtr};
Type i1 = rewriter.getI1Type();
return createFuncCall(rewriter, op, name, i1, params, EmitCInterface::On)
@@ -668,20 +640,8 @@ class SparseTensorToPointersConverter
matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type resType = op.getType();
- Type eltType = resType.cast<ShapedType>().getElementType();
- StringRef name;
- if (eltType.isIndex())
- name = "sparsePointers";
- else if (eltType.isInteger(64))
- name = "sparsePointers64";
- else if (eltType.isInteger(32))
- name = "sparsePointers32";
- else if (eltType.isInteger(16))
- name = "sparsePointers16";
- else if (eltType.isInteger(8))
- name = "sparsePointers8";
- else
- return failure();
+ Type ptrType = resType.cast<ShapedType>().getElementType();
+ SmallString<16> name{"sparsePointers", overheadTypeFunctionSuffix(ptrType)};
replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
EmitCInterface::On);
return success();
@@ -696,20 +656,8 @@ class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type resType = op.getType();
- Type eltType = resType.cast<ShapedType>().getElementType();
- StringRef name;
- if (eltType.isIndex())
- name = "sparseIndices";
- else if (eltType.isInteger(64))
- name = "sparseIndices64";
- else if (eltType.isInteger(32))
- name = "sparseIndices32";
- else if (eltType.isInteger(16))
- name = "sparseIndices16";
- else if (eltType.isInteger(8))
- name = "sparseIndices8";
- else
- return failure();
+ Type indType = resType.cast<ShapedType>().getElementType();
+ SmallString<15> name{"sparseIndices", overheadTypeFunctionSuffix(indType)};
replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
EmitCInterface::On);
return success();
@@ -725,21 +673,7 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
ConversionPatternRewriter &rewriter) const override {
Type resType = op.getType();
Type eltType = resType.cast<ShapedType>().getElementType();
- StringRef name;
- if (eltType.isF64())
- name = "sparseValuesF64";
- else if (eltType.isF32())
- name = "sparseValuesF32";
- else if (eltType.isInteger(64))
- name = "sparseValuesI64";
- else if (eltType.isInteger(32))
- name = "sparseValuesI32";
- else if (eltType.isInteger(16))
- name = "sparseValuesI16";
- else if (eltType.isInteger(8))
- name = "sparseValuesI8";
- else
- return failure();
+ SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
EmitCInterface::On);
return success();
@@ -772,23 +706,8 @@ class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
LogicalResult
matchAndRewrite(LexInsertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type srcType = op.tensor().getType();
- Type eltType = srcType.cast<ShapedType>().getElementType();
- StringRef name;
- if (eltType.isF64())
- name = "lexInsertF64";
- else if (eltType.isF32())
- name = "lexInsertF32";
- else if (eltType.isInteger(64))
- name = "lexInsertI64";
- else if (eltType.isInteger(32))
- name = "lexInsertI32";
- else if (eltType.isInteger(16))
- name = "lexInsertI16";
- else if (eltType.isInteger(8))
- name = "lexInsertI8";
- else
- llvm_unreachable("Unknown element type");
+ Type elemTp = op.tensor().getType().cast<ShapedType>().getElementType();
+ SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
TypeRange noTp;
replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
EmitCInterface::On);
@@ -843,23 +762,8 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
// all-zero/false by only iterating over the set elements, so the
// complexity remains proportional to the sparsity of the expanded
// access pattern.
- Type srcType = op.tensor().getType();
- Type eltType = srcType.cast<ShapedType>().getElementType();
- StringRef name;
- if (eltType.isF64())
- name = "expInsertF64";
- else if (eltType.isF32())
- name = "expInsertF32";
- else if (eltType.isInteger(64))
- name = "expInsertI64";
- else if (eltType.isInteger(32))
- name = "expInsertI32";
- else if (eltType.isInteger(16))
- name = "expInsertI16";
- else if (eltType.isInteger(8))
- name = "expInsertI8";
- else
- return failure();
+ Type elemTp = op.tensor().getType().cast<ShapedType>().getElementType();
+ SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
TypeRange noTp;
replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
EmitCInterface::On);
More information about the Mlir-commits mailing list