[Mlir-commits] [mlir] 14fffda - [mlir][sparse] Factoring out allocaIndices()
wren romano
llvmlistbot at llvm.org
Fri Oct 1 14:19:02 PDT 2021
Author: wren romano
Date: 2021-10-01T14:18:56-07:00
New Revision: 14fffda979ae7f8c7f6425568d3e9615d3d7732f
URL: https://github.com/llvm/llvm-project/commit/14fffda979ae7f8c7f6425568d3e9615d3d7732f
DIFF: https://github.com/llvm/llvm-project/commit/14fffda979ae7f8c7f6425568d3e9615d3d7732f.diff
LOG: [mlir][sparse] Factoring out allocaIndices()
This is preliminary work towards D110790. Depends On D110882.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D110883
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index b5015931c44b..154fe3fb9be9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -295,6 +295,18 @@ static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}
+/// Generates code to stack-allocate a `memref<?xindex>` of length `rank`;
+/// the length is passed as an `index` constant operand, so the type stays
+/// dynamic. This array is intended to serve as a reusable buffer for the
+/// indices of a single tensor element, to avoid allocation inside loops.
+static Value allocaIndices(ConversionPatternRewriter &rewriter, Location loc,
+ int64_t rank) {
+ auto indexTp = rewriter.getIndexType();
+ auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
+ Value arg = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(rank));
+ return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
+}
+
//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//
@@ -413,13 +425,9 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
// loop is generated by genAddElt().
Location loc = op->getLoc();
ShapedType shape = resType.cast<ShapedType>();
- auto memTp =
- MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType());
Value perm;
Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
- Value arg = rewriter.create<ConstantOp>(
- loc, rewriter.getIndexAttr(shape.getRank()));
- Value ind = rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
+ Value ind = allocaIndices(rewriter, loc, shape.getRank());
SmallVector<Value> lo;
SmallVector<Value> hi;
SmallVector<Value> st;
More information about the Mlir-commits
mailing list