[Mlir-commits] [mlir] f527fdf - [mlir][sparse] Code cleanup for SparseTensorConversion
wren romano
llvmlistbot at llvm.org
Mon Dec 6 14:13:43 PST 2021
Author: wren romano
Date: 2021-12-06T14:13:35-08:00
New Revision: f527fdf51e7763b516f5d4ebc8c4ffb8808cd281
URL: https://github.com/llvm/llvm-project/commit/f527fdf51e7763b516f5d4ebc8c4ffb8808cd281
DIFF: https://github.com/llvm/llvm-project/commit/f527fdf51e7763b516f5d4ebc8c4ffb8808cd281.diff
LOG: [mlir][sparse] Code cleanup for SparseTensorConversion
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D115004
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 0dd611df0a514..d56fae07d47fc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -142,13 +142,19 @@ constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc,
return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2));
}
+/// Returns the equivalent of `void*` for opaque arguments to the
+/// execution engine.
+static Type getOpaquePointerType(PatternRewriter &rewriter) {
+ return LLVM::LLVMPointerType::get(rewriter.getI8Type());
+}
+
/// Returns a function reference (first hit also inserts into module). Sets
/// the "_emit_c_interface" on the function declaration when requested,
/// so that LLVM lowering generates a wrapper function that takes care
/// of ABI complications with passing in and returning MemRefs to C functions.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
TypeRange resultType, ValueRange operands,
- bool emitCInterface = false) {
+ bool emitCInterface) {
MLIRContext *context = op->getContext();
auto module = op->getParentOfType<ModuleOp>();
auto result = SymbolRefAttr::get(context, name);
@@ -165,6 +171,24 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
return result;
}
+/// Creates a `CallOp` to the function reference returned by `getFunc()`.
+static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name,
+ TypeRange resultType, ValueRange operands,
+ bool emitCInterface = false) {
+ auto fn = getFunc(op, name, resultType, operands, emitCInterface);
+ return builder.create<CallOp>(op->getLoc(), resultType, fn, operands);
+}
+
+/// Replaces the `op` with a `CallOp` to the function reference returned
+/// by `getFunc()`.
+static CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, Operation *op,
+ StringRef name, TypeRange resultType,
+ ValueRange operands,
+ bool emitCInterface = false) {
+ auto fn = getFunc(op, name, resultType, operands, emitCInterface);
+ return rewriter.replaceOpWithNewOp<CallOp>(op, resultType, fn, operands);
+}
+
/// Generates dimension size call.
static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
SparseTensorEncodingAttr &enc, Value src,
@@ -173,25 +197,20 @@ static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
if (AffineMap p = enc.getDimOrdering())
idx = p.getPermutedPosition(idx);
// Generate the call.
- Location loc = op->getLoc();
StringRef name = "sparseDimSize";
- SmallVector<Value, 2> params;
- params.push_back(src);
- params.push_back(constantIndex(rewriter, loc, idx));
+ SmallVector<Value, 2> params{src, constantIndex(rewriter, op->getLoc(), idx)};
Type iTp = rewriter.getIndexType();
- auto fn = getFunc(op, name, iTp, params);
- return rewriter.create<CallOp>(loc, iTp, fn, params).getResult(0);
+ return createFuncCall(rewriter, op, name, iTp, params).getResult(0);
}
/// Generates a call into the "swiss army knife" method of the sparse runtime
/// support library for materializing sparse tensors into the computation.
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
ArrayRef<Value> params) {
- Location loc = op->getLoc();
StringRef name = "newSparseTensor";
- Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
- auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
- auto call = rewriter.create<CallOp>(loc, pTp, fn, params);
+ Type pTp = getOpaquePointerType(rewriter);
+ auto call = createFuncCall(rewriter, op, name, pTp, params,
+ /*emitCInterface=*/true);
return call.getResult(0);
}
@@ -210,8 +229,8 @@ static void sizesFromType(ConversionPatternRewriter &rewriter,
static void sizesFromSrc(ConversionPatternRewriter &rewriter,
SmallVector<Value, 4> &sizes, Location loc,
Value src) {
- ShapedType stp = src.getType().cast<ShapedType>();
- for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
+ unsigned rank = src.getType().cast<ShapedType>().getRank();
+ for (unsigned i = 0; i < rank; i++)
sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
}
@@ -221,12 +240,13 @@ static void sizesFromPtr(ConversionPatternRewriter &rewriter,
SmallVector<Value, 4> &sizes, Operation *op,
SparseTensorEncodingAttr &enc, ShapedType stp,
Value src) {
+ Location loc = op->getLoc();
auto shape = stp.getShape();
for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
if (shape[i] == ShapedType::kDynamicSize)
sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
else
- sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
+ sizes.push_back(constantIndex(rewriter, loc, shape[i]));
}
/// Generates an uninitialized temporary buffer of the given size and
@@ -293,16 +313,15 @@ static void newParams(ConversionPatternRewriter &rewriter,
}
params.push_back(genBuffer(rewriter, loc, rev));
// Secondary and primary types encoding.
- ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
+ Type elemTp = op->getResult(0).getType().cast<ShapedType>().getElementType();
params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
- params.push_back(
- constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
- // User action and pointer.
- Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
- if (!ptr)
- ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
+ params.push_back(constantPrimaryTypeEncoding(rewriter, loc, elemTp));
+ // User action.
params.push_back(constantAction(rewriter, loc, action));
+ // Payload pointer.
+ if (!ptr)
+ ptr = rewriter.create<LLVM::NullOp>(loc, getOpaquePointerType(rewriter));
params.push_back(ptr);
}
@@ -352,7 +371,6 @@ static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
Type eltType, Value ptr, Value val, Value ind,
Value perm) {
- Location loc = op->getLoc();
StringRef name;
if (eltType.isF64())
name = "addEltF64";
@@ -368,14 +386,9 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
name = "addEltI8";
else
llvm_unreachable("Unknown element type");
- SmallVector<Value, 8> params;
- params.push_back(ptr);
- params.push_back(val);
- params.push_back(ind);
- params.push_back(perm);
- Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
- auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
- rewriter.create<CallOp>(loc, pTp, fn, params);
+ SmallVector<Value, 4> params{ptr, val, ind, perm};
+ Type pTp = getOpaquePointerType(rewriter);
+ createFuncCall(rewriter, op, name, pTp, params, /*emitCInterface=*/true);
}
/// Generates a call to `iter->getNext()`. If there is a next element,
@@ -384,7 +397,6 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
/// the memory for `iter` is freed and the return value is false.
static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
Value iter, Value ind, Value elemPtr) {
- Location loc = op->getLoc();
Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
StringRef name;
if (elemTp.isF64())
@@ -401,13 +413,10 @@ static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
name = "getNextI8";
else
llvm_unreachable("Unknown element type");
- SmallVector<Value, 3> params;
- params.push_back(iter);
- params.push_back(ind);
- params.push_back(elemPtr);
+ SmallVector<Value, 3> params{iter, ind, elemPtr};
Type i1 = rewriter.getI1Type();
- auto fn = getFunc(op, name, i1, params, /*emitCInterface=*/true);
- auto call = rewriter.create<CallOp>(loc, i1, fn, params);
+ auto call = createFuncCall(rewriter, op, name, i1, params,
+ /*emitCInterface=*/true);
return call.getResult(0);
}
@@ -461,7 +470,7 @@ static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc,
}
Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes);
Value zero = constantZero(rewriter, loc, elemTp);
- rewriter.create<linalg::FillOp>(loc, zero, mem).result();
+ rewriter.create<linalg::FillOp>(loc, zero, mem);
return mem;
}
@@ -754,9 +763,8 @@ class SparseTensorReleaseConverter : public OpConversionPattern<ReleaseOp> {
matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
StringRef name = "delSparseTensor";
- TypeRange none;
- auto fn = getFunc(op, name, none, adaptor.getOperands());
- rewriter.create<CallOp>(op.getLoc(), none, fn, adaptor.getOperands());
+ TypeRange noTp;
+ createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
rewriter.eraseOp(op);
return success();
}
@@ -785,9 +793,8 @@ class SparseTensorToPointersConverter
name = "sparsePointers8";
else
return failure();
- auto fn = getFunc(op, name, resType, adaptor.getOperands(),
- /*emitCInterface=*/true);
- rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
+ replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
+ /*emitCInterface=*/true);
return success();
}
};
@@ -814,9 +821,8 @@ class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
name = "sparseIndices8";
else
return failure();
- auto fn = getFunc(op, name, resType, adaptor.getOperands(),
- /*emitCInterface=*/true);
- rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
+ replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
+ /*emitCInterface=*/true);
return success();
}
};
@@ -845,9 +851,8 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
name = "sparseValuesI8";
else
return failure();
- auto fn = getFunc(op, name, resType, adaptor.getOperands(),
- /*emitCInterface=*/true);
- rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
+ replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
+ /*emitCInterface=*/true);
return success();
}
};
@@ -863,8 +868,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
// Finalize any pending insertions.
StringRef name = "endInsert";
TypeRange noTp;
- auto fn = getFunc(op, name, noTp, adaptor.getOperands());
- rewriter.create<CallOp>(op.getLoc(), noTp, fn, adaptor.getOperands());
+ createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
}
rewriter.replaceOp(op, adaptor.getOperands());
return success();
@@ -896,9 +900,8 @@ class SparseTensorLexInsertConverter : public OpConversionPattern<LexInsertOp> {
else
llvm_unreachable("Unknown element type");
TypeRange noTp;
- auto fn =
- getFunc(op, name, noTp, adaptor.getOperands(), /*emitCInterface=*/true);
- rewriter.replaceOpWithNewOp<CallOp>(op, noTp, fn, adaptor.getOperands());
+ replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
+ /*emitCInterface=*/true);
return success();
}
};
More information about the Mlir-commits mailing list