[Mlir-commits] [mlir] 63d4fc9 - [mlir][sparse] Factoring out helper functions for generating constants
wren romano
llvmlistbot at llvm.org
Wed Oct 13 16:20:03 PDT 2021
Author: wren romano
Date: 2021-10-13T16:19:55-07:00
New Revision: 63d4fc9483774a3b10101bc6a6de9dce1d8bdca2
URL: https://github.com/llvm/llvm-project/commit/63d4fc9483774a3b10101bc6a6de9dce1d8bdca2
DIFF: https://github.com/llvm/llvm-project/commit/63d4fc9483774a3b10101bc6a6de9dce1d8bdca2.diff
LOG: [mlir][sparse] Factoring out helper functions for generating constants
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D111763
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index da3bd92f917eb..c4f5f4dccef08 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -81,6 +81,30 @@ getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
}
+/// Generates a constant zero of the given type.
+inline static Value constantZero(ConversionPatternRewriter &rewriter,
+ Location loc, Type t) {
+ return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t));
+}
+
+/// Generates a constant of `index` type (int64_t param: no silent narrowing).
+inline static Value constantIndex(ConversionPatternRewriter &rewriter,
+ Location loc, int64_t i) {
+ return rewriter.create<arith::ConstantIndexOp>(loc, i);
+}
+
+/// Generates a constant of `i64` type.
+inline static Value constantI64(ConversionPatternRewriter &rewriter,
+ Location loc, int64_t i) {
+ return rewriter.create<arith::ConstantIntOp>(loc, i, 64);
+}
+
+/// Generates a constant of `i32` type.
+inline static Value constantI32(ConversionPatternRewriter &rewriter,
+ Location loc, int32_t i) {
+ return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
+}
+
/// Returns integers of given width and values as a constant tensor.
/// We cast the static shape into a dynamic shape to ensure that the
/// method signature remains uniform across different tensor dimensions.
@@ -161,18 +185,14 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
assert(primary);
- params.push_back(rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI64IntegerAttr(secPtr)));
- params.push_back(rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI64IntegerAttr(secInd)));
- params.push_back(rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI64IntegerAttr(primary)));
+ params.push_back(constantI64(rewriter, loc, secPtr));
+ params.push_back(constantI64(rewriter, loc, secInd));
+ params.push_back(constantI64(rewriter, loc, primary));
// User action and pointer.
Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
if (!ptr)
ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
- params.push_back(rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(action)));
+ params.push_back(constantI32(rewriter, loc, action));
params.push_back(ptr);
// Generate the call to create new tensor.
StringRef name = "newSparseTensor";
@@ -182,19 +202,13 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
return call.getResult(0);
}
-/// Generates a constant zero of the given type.
-static Value getZero(ConversionPatternRewriter &rewriter, Location loc,
- Type t) {
- return rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(t));
-}
-
/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
/// For floating types, we use the "unordered" comparator (i.e., returns
/// true if `v` is NaN).
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
Value v) {
Type t = v.getType();
- Value zero = getZero(rewriter, loc, t);
+ Value zero = constantZero(rewriter, loc, t);
if (t.isa<FloatType>())
return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
zero);
@@ -221,8 +235,7 @@ static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
unsigned i = 0;
for (auto iv : ivs) {
- Value idx =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i++));
+ Value idx = constantIndex(rewriter, loc, i++);
rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
}
return val;
@@ -289,8 +302,7 @@ static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
unsigned rank) {
Location loc = op->getLoc();
for (unsigned i = 0; i < rank; i++) {
- Value idx =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i));
+ Value idx = constantIndex(rewriter, loc, i);
Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
ValueRange{ivs[0], idx});
val =
@@ -308,8 +320,7 @@ static Value allocaIndices(ConversionPatternRewriter &rewriter, Location loc,
int64_t rank) {
auto indexTp = rewriter.getIndexType();
auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
- Value arg =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(rank));
+ Value arg = constantIndex(rewriter, loc, rank);
return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
}
@@ -352,8 +363,7 @@ class SparseTensorToDimSizeConverter
StringRef name = "sparseDimSize";
SmallVector<Value, 2> params;
params.push_back(adaptor.getOperands()[0]);
- params.push_back(rewriter.create<arith::ConstantOp>(
- op.getLoc(), rewriter.getIndexAttr(idx)));
+ params.push_back(constantIndex(rewriter, op.getLoc(), idx));
rewriter.replaceOpWithNewOp<CallOp>(
op, resType, getFunc(op, name, resType, params), params);
return success();
@@ -437,10 +447,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
SmallVector<Value> lo;
SmallVector<Value> hi;
SmallVector<Value> st;
- Value zero =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
- Value one =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
+ Value zero = constantIndex(rewriter, loc, 0);
+ Value one = constantIndex(rewriter, loc, 1);
auto indicesValues = genSplitSparseConstant(rewriter, op, src);
bool isCOOConstant = indicesValues.hasValue();
Value indices;
More information about the Mlir-commits mailing list