[Mlir-commits] [mlir] 63d4fc9 - [mlir][sparse] Factoring out helper functions for generating constants

wren romano llvmlistbot at llvm.org
Wed Oct 13 16:20:03 PDT 2021


Author: wren romano
Date: 2021-10-13T16:19:55-07:00
New Revision: 63d4fc9483774a3b10101bc6a6de9dce1d8bdca2

URL: https://github.com/llvm/llvm-project/commit/63d4fc9483774a3b10101bc6a6de9dce1d8bdca2
DIFF: https://github.com/llvm/llvm-project/commit/63d4fc9483774a3b10101bc6a6de9dce1d8bdca2.diff

LOG: [mlir][sparse] Factoring out helper functions for generating constants

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D111763

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index da3bd92f917eb..c4f5f4dccef08 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -81,6 +81,30 @@ getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
   llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
 }
 
+/// Generates an `arith.constant` holding the zero attribute of type `t`.
+inline static Value constantZero(ConversionPatternRewriter &rewriter,
+                                 Location loc, Type t) {
+  return rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(t));
+}
+
+/// Generates an `index` constant (int64_t `i`, so wide values don't truncate).
+inline static Value constantIndex(ConversionPatternRewriter &rewriter,
+                                  Location loc, int64_t i) {
+  return rewriter.create<arith::ConstantIndexOp>(loc, i);
+}
+
+/// Generates an `arith.constant` of type `i64` with value `i`.
+inline static Value constantI64(ConversionPatternRewriter &rewriter,
+                                Location loc, int64_t i) {
+  return rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(i));
+}
+
+/// Generates an `arith.constant` of type `i32` with value `i`.
+inline static Value constantI32(ConversionPatternRewriter &rewriter,
+                                Location loc, int32_t i) {
+  return rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(i));
+}
+
 /// Returns integers of given width and values as a constant tensor.
 /// We cast the static shape into a dynamic shape to ensure that the
 /// method signature remains uniform across different tensor dimensions.
@@ -161,18 +185,14 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
   unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
   unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
   assert(primary);
-  params.push_back(rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI64IntegerAttr(secPtr)));
-  params.push_back(rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI64IntegerAttr(secInd)));
-  params.push_back(rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI64IntegerAttr(primary)));
+  params.push_back(constantI64(rewriter, loc, secPtr));
+  params.push_back(constantI64(rewriter, loc, secInd));
+  params.push_back(constantI64(rewriter, loc, primary));
   // User action and pointer.
   Type pTp = LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
   if (!ptr)
     ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
-  params.push_back(rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI32IntegerAttr(action)));
+  params.push_back(constantI32(rewriter, loc, action));
   params.push_back(ptr);
   // Generate the call to create new tensor.
   StringRef name = "newSparseTensor";
@@ -182,19 +202,13 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
   return call.getResult(0);
 }
 
-/// Generates a constant zero of the given type.
-static Value getZero(ConversionPatternRewriter &rewriter, Location loc,
-                     Type t) {
-  return rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(t));
-}
-
 /// Generates the comparison `v != 0` where `v` is of numeric type `t`.
 /// For floating types, we use the "unordered" comparator (i.e., returns
 /// true if `v` is NaN).
 static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
                           Value v) {
   Type t = v.getType();
-  Value zero = getZero(rewriter, loc, t);
+  Value zero = constantZero(rewriter, loc, t);
   if (t.isa<FloatType>())
     return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
                                           zero);
@@ -221,8 +235,7 @@ static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter,
   rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
   unsigned i = 0;
   for (auto iv : ivs) {
-    Value idx =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i++));
+    Value idx = constantIndex(rewriter, loc, i++);
     rewriter.create<memref::StoreOp>(loc, iv, ind, idx);
   }
   return val;
@@ -289,8 +302,7 @@ static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
                                        unsigned rank) {
   Location loc = op->getLoc();
   for (unsigned i = 0; i < rank; i++) {
-    Value idx =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i));
+    Value idx = constantIndex(rewriter, loc, i);
     Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                    ValueRange{ivs[0], idx});
     val =
@@ -308,8 +320,7 @@ static Value allocaIndices(ConversionPatternRewriter &rewriter, Location loc,
                            int64_t rank) {
   auto indexTp = rewriter.getIndexType();
   auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
-  Value arg =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(rank));
+  Value arg = constantIndex(rewriter, loc, rank);
   return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
 }
 
@@ -352,8 +363,7 @@ class SparseTensorToDimSizeConverter
     StringRef name = "sparseDimSize";
     SmallVector<Value, 2> params;
     params.push_back(adaptor.getOperands()[0]);
-    params.push_back(rewriter.create<arith::ConstantOp>(
-        op.getLoc(), rewriter.getIndexAttr(idx)));
+    params.push_back(constantIndex(rewriter, op.getLoc(), idx));
     rewriter.replaceOpWithNewOp<CallOp>(
         op, resType, getFunc(op, name, resType, params), params);
     return success();
@@ -437,10 +447,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
     SmallVector<Value> lo;
     SmallVector<Value> hi;
     SmallVector<Value> st;
-    Value zero =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
-    Value one =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
+    Value zero = constantIndex(rewriter, loc, 0);
+    Value one = constantIndex(rewriter, loc, 1);
     auto indicesValues = genSplitSparseConstant(rewriter, op, src);
     bool isCOOConstant = indicesValues.hasValue();
     Value indices;


        


More information about the Mlir-commits mailing list