[Mlir-commits] [mlir] 14fffda - [mlir][sparse] Factoring out allocaIndices()

wren romano llvmlistbot at llvm.org
Fri Oct 1 14:19:02 PDT 2021


Author: wren romano
Date: 2021-10-01T14:18:56-07:00
New Revision: 14fffda979ae7f8c7f6425568d3e9615d3d7732f

URL: https://github.com/llvm/llvm-project/commit/14fffda979ae7f8c7f6425568d3e9615d3d7732f
DIFF: https://github.com/llvm/llvm-project/commit/14fffda979ae7f8c7f6425568d3e9615d3d7732f.diff

LOG: [mlir][sparse] Factoring out allocaIndices()

This is preliminary work towards D110790. Depends On D110882.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D110883

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index b5015931c44b..154fe3fb9be9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -295,6 +295,18 @@ static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
   return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
 }
 
+/// Generates code to stack-allocate a `memref<?xindex>` whose dynamic size
+/// is `rank`, materialized as an index constant at `loc`. The result serves
+/// as a reusable buffer holding the indices of one tensor element, so that
+/// loop bodies need not allocate. Returns the `memref::AllocaOp` value.
+static Value allocaIndices(ConversionPatternRewriter &rewriter, Location loc,
+                           int64_t rank) {
+  auto indexTp = rewriter.getIndexType();
+  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
+  Value arg = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(rank));
+  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
+}
+
 //===----------------------------------------------------------------------===//
 // Conversion rules.
 //===----------------------------------------------------------------------===//
@@ -413,13 +425,9 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
     // loop is generated by genAddElt().
     Location loc = op->getLoc();
     ShapedType shape = resType.cast<ShapedType>();
-    auto memTp =
-        MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType());
     Value perm;
     Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
-    Value arg = rewriter.create<ConstantOp>(
-        loc, rewriter.getIndexAttr(shape.getRank()));
-    Value ind = rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
+    Value ind = allocaIndices(rewriter, loc, shape.getRank());
     SmallVector<Value> lo;
     SmallVector<Value> hi;
     SmallVector<Value> st;


        


More information about the Mlir-commits mailing list