[Mlir-commits] [mlir] d8731bf - [mlir][sparse] Requiring emitCInterface parameter to be explicit
wren romano
llvmlistbot at llvm.org
Mon Dec 6 20:50:17 PST 2021
Author: wren romano
Date: 2021-12-06T20:50:08-08:00
New Revision: d8731bfc93c2946de6d1e2242cea25d0a76aeb85
URL: https://github.com/llvm/llvm-project/commit/d8731bfc93c2946de6d1e2242cea25d0a76aeb85
DIFF: https://github.com/llvm/llvm-project/commit/d8731bfc93c2946de6d1e2242cea25d0a76aeb85.diff
LOG: [mlir][sparse] Requiring emitCInterface parameter to be explicit
Depends On D115004
Improves code legibility by requiring the `emitCInterface` parameter to be passed explicitly at all call-sites, and by defining a `bool`-backed `EmitCInterface` enum (with `On`/`Off` aliases) to use for that parameter instead of bare boolean literals.
Reviewed By: aartbik, rriddle
Differential Revision: https://reviews.llvm.org/D115005
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index d56fae07d47fc..20dda09ab4c51 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -31,6 +31,10 @@ using namespace mlir::sparse_tensor;
namespace {
+/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
+/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
+enum class EmitCInterface : bool { Off = false, On = true };
+
//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//
@@ -154,7 +158,7 @@ static Type getOpaquePointerType(PatternRewriter &rewriter) {
/// of ABI complications with passing in and returning MemRefs to C functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
TypeRange resultType, ValueRange operands,
- bool emitCInterface) {
+ EmitCInterface emitCInterface) {
MLIRContext *context = op->getContext();
auto module = op->getParentOfType<ModuleOp>();
auto result = SymbolRefAttr::get(context, name);
@@ -165,7 +169,7 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
op->getLoc(), name,
FunctionType::get(context, operands.getTypes(), resultType));
func.setPrivate();
- if (emitCInterface)
+ if (static_cast<bool>(emitCInterface))
func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
}
return result;
@@ -174,7 +178,7 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
/// Creates a `CallOp` to the function reference returned by `getFunc()`.
static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name,
TypeRange resultType, ValueRange operands,
- bool emitCInterface = false) {
+ EmitCInterface emitCInterface) {
auto fn = getFunc(op, name, resultType, operands, emitCInterface);
return builder.create<CallOp>(op->getLoc(), resultType, fn, operands);
}
@@ -184,7 +188,7 @@ static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name,
static CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, Operation *op,
StringRef name, TypeRange resultType,
ValueRange operands,
- bool emitCInterface = false) {
+ EmitCInterface emitCInterface) {
auto fn = getFunc(op, name, resultType, operands, emitCInterface);
return rewriter.replaceOpWithNewOp<CallOp>(op, resultType, fn, operands);
}
@@ -200,7 +204,8 @@ static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
StringRef name = "sparseDimSize";
SmallVector<Value, 2> params{src, constantIndex(rewriter, op->getLoc(), idx)};
Type iTp = rewriter.getIndexType();
- return createFuncCall(rewriter, op, name, iTp, params).getResult(0);
+ return createFuncCall(rewriter, op, name, iTp, params, EmitCInterface::Off)
+ .getResult(0);
}
/// Generates a call into the "swiss army knife" method of the sparse runtime
@@ -209,9 +214,8 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
ArrayRef<Value> params) {
StringRef name = "newSparseTensor";
Type pTp = getOpaquePointerType(rewriter);
- auto call = createFuncCall(rewriter, op, name, pTp, params,
- /*emitCInterface=*/true);
- return call.getResult(0);
+ return createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On)
+ .getResult(0);
}
/// Populates given sizes array from type.
@@ -388,7 +392,7 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
llvm_unreachable("Unknown element type");
SmallVector<Value, 4> params{ptr, val, ind, perm};
Type pTp = getOpaquePointerType(rewriter);
- createFuncCall(rewriter, op, name, pTp, params, /*emitCInterface=*/true);
+ createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On);
}
/// Generates a call to `iter->getNext()`. If there is a next element,
@@ -415,9 +419,8 @@ static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
llvm_unreachable("Unknown element type");
SmallVector<Value, 3> params{iter, ind, elemPtr};
Type i1 = rewriter.getI1Type();
- auto call = createFuncCall(rewriter, op, name, i1, params,
- /*emitCInterface=*/true);
- return call.getResult(0);
+ return createFuncCall(rewriter, op, name, i1, params, EmitCInterface::On)
+ .getResult(0);
}
/// If the tensor is a sparse constant, generates and returns the pair of
@@ -764,7 +767,8 @@ class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
ConversionPatternRewriter &rewriter) const override {
StringRef name = "delSparseTensor";
TypeRange noTp;
- createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
+ createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
+ EmitCInterface::Off);
rewriter.eraseOp(op);
return success();
}
@@ -794,7 +798,7 @@ class SparseTensorToPointersConverter
else
return failure();
replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
- /*emitCInterface=*/true);
+ EmitCInterface::On);
return success();
}
};
@@ -822,7 +826,7 @@ class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
else
return failure();
replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
- /*emitCInterface=*/true);
+ EmitCInterface::On);
return success();
}
};
@@ -852,7 +856,7 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
else
return failure();
replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
- /*emitCInterface=*/true);
+ EmitCInterface::On);
return success();
}
};
@@ -868,7 +872,8 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
// Finalize any pending insertions.
StringRef name = "endInsert";
TypeRange noTp;
- createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
+ createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
+ EmitCInterface::Off);
}
rewriter.replaceOp(op, adaptor.getOperands());
return success();
@@ -901,7 +906,7 @@ class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
llvm_unreachable("Unknown element type");
TypeRange noTp;
replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
- /*emitCInterface=*/true);
+ EmitCInterface::On);
return success();
}
};
More information about the Mlir-commits
mailing list