[Mlir-commits] [mlir] f527fdf - [mlir][sparse] Code cleanup for SparseTensorConversion

wren romano llvmlistbot at llvm.org
Mon Dec 6 14:13:43 PST 2021


Author: wren romano
Date: 2021-12-06T14:13:35-08:00
New Revision: f527fdf51e7763b516f5d4ebc8c4ffb8808cd281

URL: https://github.com/llvm/llvm-project/commit/f527fdf51e7763b516f5d4ebc8c4ffb8808cd281
DIFF: https://github.com/llvm/llvm-project/commit/f527fdf51e7763b516f5d4ebc8c4ffb8808cd281.diff

LOG: [mlir][sparse] Code cleanup for SparseTensorConversion

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D115004

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 0dd611df0a514..d56fae07d47fc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -142,13 +142,19 @@ constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc,
   return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2));
 }
 
+/// Returns the equivalent of `void*` for opaque arguments to the
+/// execution engine.
+static Type getOpaquePointerType(PatternRewriter &rewriter) {
+  return LLVM::LLVMPointerType::get(rewriter.getI8Type());
+}
+
 /// Returns a function reference (first hit also inserts into module). Sets
 /// the "_emit_c_interface" on the function declaration when requested,
 /// so that LLVM lowering generates a wrapper function that takes care
 /// of ABI complications with passing in and returning MemRefs to C functions.
 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                  TypeRange resultType, ValueRange operands,
-                                 bool emitCInterface = false) {
+                                 bool emitCInterface) {
   MLIRContext *context = op->getContext();
   auto module = op->getParentOfType<ModuleOp>();
   auto result = SymbolRefAttr::get(context, name);
@@ -165,6 +171,24 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
   return result;
 }
 
+/// Creates a `CallOp` to the function reference returned by `getFunc()`.
+static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name,
+                             TypeRange resultType, ValueRange operands,
+                             bool emitCInterface = false) {
+  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
+  return builder.create<CallOp>(op->getLoc(), resultType, fn, operands);
+}
+
+/// Replaces the `op` with a `CallOp` to the function reference returned
+/// by `getFunc()`.
+static CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, Operation *op,
+                                    StringRef name, TypeRange resultType,
+                                    ValueRange operands,
+                                    bool emitCInterface = false) {
+  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
+  return rewriter.replaceOpWithNewOp<CallOp>(op, resultType, fn, operands);
+}
+
 /// Generates dimension size call.
 static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
                             SparseTensorEncodingAttr &enc, Value src,
@@ -173,25 +197,20 @@ static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
   if (AffineMap p = enc.getDimOrdering())
     idx = p.getPermutedPosition(idx);
   // Generate the call.
-  Location loc = op->getLoc();
   StringRef name = "sparseDimSize";
-  SmallVector<Value, 2> params;
-  params.push_back(src);
-  params.push_back(constantIndex(rewriter, loc, idx));
+  SmallVector<Value, 2> params{src, constantIndex(rewriter, op->getLoc(), idx)};
   Type iTp = rewriter.getIndexType();
-  auto fn = getFunc(op, name, iTp, params);
-  return rewriter.create<CallOp>(loc, iTp, fn, params).getResult(0);
+  return createFuncCall(rewriter, op, name, iTp, params).getResult(0);
 }
 
 /// Generates a call into the "swiss army knife" method of the sparse runtime
 /// support library for materializing sparse tensors into the computation.
 static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                         ArrayRef<Value> params) {
-  Location loc = op->getLoc();
   StringRef name = "newSparseTensor";
-  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
-  auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
-  auto call = rewriter.create<CallOp>(loc, pTp, fn, params);
+  Type pTp = getOpaquePointerType(rewriter);
+  auto call = createFuncCall(rewriter, op, name, pTp, params,
+                             /*emitCInterface=*/true);
   return call.getResult(0);
 }
 
@@ -210,8 +229,8 @@ static void sizesFromType(ConversionPatternRewriter &rewriter,
 static void sizesFromSrc(ConversionPatternRewriter &rewriter,
                          SmallVector<Value, 4> &sizes, Location loc,
                          Value src) {
-  ShapedType stp = src.getType().cast<ShapedType>();
-  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
+  unsigned rank = src.getType().cast<ShapedType>().getRank();
+  for (unsigned i = 0; i < rank; i++)
     sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
 }
 
@@ -221,12 +240,13 @@ static void sizesFromPtr(ConversionPatternRewriter &rewriter,
                          SmallVector<Value, 4> &sizes, Operation *op,
                          SparseTensorEncodingAttr &enc, ShapedType stp,
                          Value src) {
+  Location loc = op->getLoc();
   auto shape = stp.getShape();
   for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
     if (shape[i] == ShapedType::kDynamicSize)
       sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
     else
-      sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
+      sizes.push_back(constantIndex(rewriter, loc, shape[i]));
 }
 
 /// Generates an uninitialized temporary buffer of the given size and
@@ -293,16 +313,15 @@ static void newParams(ConversionPatternRewriter &rewriter,
   }
   params.push_back(genBuffer(rewriter, loc, rev));
   // Secondary and primary types encoding.
-  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
+  Type elemTp = op->getResult(0).getType().cast<ShapedType>().getElementType();
   params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
   params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
-  params.push_back(
-      constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
-  // User action and pointer.
-  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
-  if (!ptr)
-    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
+  params.push_back(constantPrimaryTypeEncoding(rewriter, loc, elemTp));
+  // User action.
   params.push_back(constantAction(rewriter, loc, action));
+  // Payload pointer.
+  if (!ptr)
+    ptr = rewriter.create<LLVM::NullOp>(loc, getOpaquePointerType(rewriter));
   params.push_back(ptr);
 }
 
@@ -352,7 +371,6 @@ static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
 static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                           Type eltType, Value ptr, Value val, Value ind,
                           Value perm) {
-  Location loc = op->getLoc();
   StringRef name;
   if (eltType.isF64())
     name = "addEltF64";
@@ -368,14 +386,9 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
     name = "addEltI8";
   else
     llvm_unreachable("Unknown element type");
-  SmallVector<Value, 8> params;
-  params.push_back(ptr);
-  params.push_back(val);
-  params.push_back(ind);
-  params.push_back(perm);
-  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
-  auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
-  rewriter.create<CallOp>(loc, pTp, fn, params);
+  SmallVector<Value, 4> params{ptr, val, ind, perm};
+  Type pTp = getOpaquePointerType(rewriter);
+  createFuncCall(rewriter, op, name, pTp, params, /*emitCInterface=*/true);
 }
 
 /// Generates a call to `iter->getNext()`.  If there is a next element,
@@ -384,7 +397,6 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
 /// the memory for `iter` is freed and the return value is false.
 static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
                             Value iter, Value ind, Value elemPtr) {
-  Location loc = op->getLoc();
   Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
   StringRef name;
   if (elemTp.isF64())
@@ -401,13 +413,10 @@ static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
     name = "getNextI8";
   else
     llvm_unreachable("Unknown element type");
-  SmallVector<Value, 3> params;
-  params.push_back(iter);
-  params.push_back(ind);
-  params.push_back(elemPtr);
+  SmallVector<Value, 3> params{iter, ind, elemPtr};
   Type i1 = rewriter.getI1Type();
-  auto fn = getFunc(op, name, i1, params, /*emitCInterface=*/true);
-  auto call = rewriter.create<CallOp>(loc, i1, fn, params);
+  auto call = createFuncCall(rewriter, op, name, i1, params,
+                             /*emitCInterface=*/true);
   return call.getResult(0);
 }
 
@@ -461,7 +470,7 @@ static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc,
   }
   Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes);
   Value zero = constantZero(rewriter, loc, elemTp);
-  rewriter.create<linalg::FillOp>(loc, zero, mem).result();
+  rewriter.create<linalg::FillOp>(loc, zero, mem);
   return mem;
 }
 
@@ -754,9 +763,8 @@ class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
   matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     StringRef name = "delSparseTensor";
-    TypeRange none;
-    auto fn = getFunc(op, name, none, adaptor.getOperands());
-    rewriter.create<CallOp>(op.getLoc(), none, fn, adaptor.getOperands());
+    TypeRange noTp;
+    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
     rewriter.eraseOp(op);
     return success();
   }
@@ -785,9 +793,8 @@ class SparseTensorToPointersConverter
       name = "sparsePointers8";
     else
       return failure();
-    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
-                      /*emitCInterface=*/true);
-    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
+    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
+                          /*emitCInterface=*/true);
     return success();
   }
 };
@@ -814,9 +821,8 @@ class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
       name = "sparseIndices8";
     else
       return failure();
-    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
-                      /*emitCInterface=*/true);
-    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
+    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
+                          /*emitCInterface=*/true);
     return success();
   }
 };
@@ -845,9 +851,8 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
       name = "sparseValuesI8";
     else
       return failure();
-    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
-                      /*emitCInterface=*/true);
-    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
+    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
+                          /*emitCInterface=*/true);
     return success();
   }
 };
@@ -863,8 +868,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
       // Finalize any pending insertions.
       StringRef name = "endInsert";
       TypeRange noTp;
-      auto fn = getFunc(op, name, noTp, adaptor.getOperands());
-      rewriter.create<CallOp>(op.getLoc(), noTp, fn, adaptor.getOperands());
+      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
     }
     rewriter.replaceOp(op, adaptor.getOperands());
     return success();
@@ -896,9 +900,8 @@ class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
     else
       llvm_unreachable("Unknown element type");
     TypeRange noTp;
-    auto fn =
-        getFunc(op, name, noTp, adaptor.getOperands(), /*emitCInterface=*/true);
-    rewriter.replaceOpWithNewOp<CallOp>(op, noTp, fn, adaptor.getOperands());
+    replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
+                          /*emitCInterface=*/true);
     return success();
   }
 };

