[Mlir-commits] [mlir] d8731bf - [mlir][sparse] Requiring emitCInterface parameter to be explicit

wren romano llvmlistbot at llvm.org
Mon Dec 6 20:50:17 PST 2021


Author: wren romano
Date: 2021-12-06T20:50:08-08:00
New Revision: d8731bfc93c2946de6d1e2242cea25d0a76aeb85

URL: https://github.com/llvm/llvm-project/commit/d8731bfc93c2946de6d1e2242cea25d0a76aeb85
DIFF: https://github.com/llvm/llvm-project/commit/d8731bfc93c2946de6d1e2242cea25d0a76aeb85.diff

LOG: [mlir][sparse] Requiring emitCInterface parameter to be explicit

Depends On D115004

Improves code legibility by requiring the `emitCInterface` parameter to be passed explicitly at all call sites, and by defining named aliases (`EmitCInterface::On` / `EmitCInterface::Off`) for that boolean parameter.

Reviewed By: aartbik, rriddle

Differential Revision: https://reviews.llvm.org/D115005

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index d56fae07d47fc..20dda09ab4c51 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -31,6 +31,10 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
+/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
+/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
+enum class EmitCInterface : bool { Off = false, On = true };
+
 //===----------------------------------------------------------------------===//
 // Helper methods.
 //===----------------------------------------------------------------------===//
@@ -154,7 +158,7 @@ static Type getOpaquePointerType(PatternRewriter &rewriter) {
 /// of ABI complications with passing in and returning MemRefs to C functions.
 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                  TypeRange resultType, ValueRange operands,
-                                 bool emitCInterface) {
+                                 EmitCInterface emitCInterface) {
   MLIRContext *context = op->getContext();
   auto module = op->getParentOfType<ModuleOp>();
   auto result = SymbolRefAttr::get(context, name);
@@ -165,7 +169,7 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
         op->getLoc(), name,
         FunctionType::get(context, operands.getTypes(), resultType));
     func.setPrivate();
-    if (emitCInterface)
+    if (static_cast<bool>(emitCInterface))
       func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
   }
   return result;
@@ -174,7 +178,7 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
 /// Creates a `CallOp` to the function reference returned by `getFunc()`.
 static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name,
                              TypeRange resultType, ValueRange operands,
-                             bool emitCInterface = false) {
+                             EmitCInterface emitCInterface) {
   auto fn = getFunc(op, name, resultType, operands, emitCInterface);
   return builder.create<CallOp>(op->getLoc(), resultType, fn, operands);
 }
@@ -184,7 +188,7 @@ static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name,
 static CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, Operation *op,
                                     StringRef name, TypeRange resultType,
                                     ValueRange operands,
-                                    bool emitCInterface = false) {
+                                    EmitCInterface emitCInterface) {
   auto fn = getFunc(op, name, resultType, operands, emitCInterface);
   return rewriter.replaceOpWithNewOp<CallOp>(op, resultType, fn, operands);
 }
@@ -200,7 +204,8 @@ static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
   StringRef name = "sparseDimSize";
   SmallVector<Value, 2> params{src, constantIndex(rewriter, op->getLoc(), idx)};
   Type iTp = rewriter.getIndexType();
-  return createFuncCall(rewriter, op, name, iTp, params).getResult(0);
+  return createFuncCall(rewriter, op, name, iTp, params, EmitCInterface::Off)
+      .getResult(0);
 }
 
 /// Generates a call into the "swiss army knife" method of the sparse runtime
@@ -209,9 +214,8 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                         ArrayRef<Value> params) {
   StringRef name = "newSparseTensor";
   Type pTp = getOpaquePointerType(rewriter);
-  auto call = createFuncCall(rewriter, op, name, pTp, params,
-                             /*emitCInterface=*/true);
-  return call.getResult(0);
+  return createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On)
+      .getResult(0);
 }
 
 /// Populates given sizes array from type.
@@ -388,7 +392,7 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
     llvm_unreachable("Unknown element type");
   SmallVector<Value, 4> params{ptr, val, ind, perm};
   Type pTp = getOpaquePointerType(rewriter);
-  createFuncCall(rewriter, op, name, pTp, params, /*emitCInterface=*/true);
+  createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On);
 }
 
 /// Generates a call to `iter->getNext()`.  If there is a next element,
@@ -415,9 +419,8 @@ static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
     llvm_unreachable("Unknown element type");
   SmallVector<Value, 3> params{iter, ind, elemPtr};
   Type i1 = rewriter.getI1Type();
-  auto call = createFuncCall(rewriter, op, name, i1, params,
-                             /*emitCInterface=*/true);
-  return call.getResult(0);
+  return createFuncCall(rewriter, op, name, i1, params, EmitCInterface::On)
+      .getResult(0);
 }
 
 /// If the tensor is a sparse constant, generates and returns the pair of
@@ -764,7 +767,8 @@ class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
                   ConversionPatternRewriter &rewriter) const override {
     StringRef name = "delSparseTensor";
     TypeRange noTp;
-    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
+    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
+                   EmitCInterface::Off);
     rewriter.eraseOp(op);
     return success();
   }
@@ -794,7 +798,7 @@ class SparseTensorToPointersConverter
     else
       return failure();
     replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
-                          /*emitCInterface=*/true);
+                          EmitCInterface::On);
     return success();
   }
 };
@@ -822,7 +826,7 @@ class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
     else
       return failure();
     replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
-                          /*emitCInterface=*/true);
+                          EmitCInterface::On);
     return success();
   }
 };
@@ -852,7 +856,7 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
     else
       return failure();
     replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
-                          /*emitCInterface=*/true);
+                          EmitCInterface::On);
     return success();
   }
 };
@@ -868,7 +872,8 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
       // Finalize any pending insertions.
       StringRef name = "endInsert";
       TypeRange noTp;
-      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
+      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
+                     EmitCInterface::Off);
     }
     rewriter.replaceOp(op, adaptor.getOperands());
     return success();
@@ -901,7 +906,7 @@ class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
       llvm_unreachable("Unknown element type");
     TypeRange noTp;
     replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
-                          /*emitCInterface=*/true);
+                          EmitCInterface::On);
     return success();
   }
 };


        


More information about the Mlir-commits mailing list