[Mlir-commits] [mlir] 938f419 - [mlir][sparse] Avoid generating DimOp in conversion passes.

Peiming Liu llvmlistbot at llvm.org
Fri Sep 9 11:08:13 PDT 2022


Author: Peiming Liu
Date: 2022-09-09T18:08:05Z
New Revision: 938f419cf1910b79388896c9694e58efc5325cba

URL: https://github.com/llvm/llvm-project/commit/938f419cf1910b79388896c9694e58efc5325cba
DIFF: https://github.com/llvm/llvm-project/commit/938f419cf1910b79388896c9694e58efc5325cba.diff

LOG: [mlir][sparse] Avoid generating DimOp in conversion passes.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133592

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 9ad37bf498f16..4ac2d1775a811 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -76,6 +76,31 @@ static void flattenOperands(ValueRange operands,
   }
 }
 
+/// Gets the dimension size for the given sparse tensor at the given dim.
+/// Returns None if no sparse encoding is attached to the tensor type.
+static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
+                                           ShapedType tensorTp,
+                                           Value adaptedValue, unsigned dim) {
+  auto enc = getSparseTensorEncoding(tensorTp);
+  if (!enc)
+    return llvm::None;
+
+  // Access into static dimension can query original type directly.
+  // Note that this is typically already done by DimOp's folding.
+  auto shape = tensorTp.getShape();
+  if (!ShapedType::isDynamic(shape[dim]))
+    return constantIndex(rewriter, loc, shape[dim]);
+
+  // Any other query can consult the dimSizes array at field 0 using,
+  // accounting for the reordering applied to the sparse storage.
+  auto tuple =
+      llvm::cast<UnrealizedConversionCastOp>(adaptedValue.getDefiningOp());
+  return rewriter
+      .create<memref::LoadOp>(loc, tuple.getInputs().front(),
+                              constantIndex(rewriter, loc, toStored(enc, dim)))
+      .getResult();
+}
+
 /// Maps a sparse tensor type to the appropriate compounded buffers.
 static Optional<LogicalResult>
 convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
@@ -344,28 +369,17 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
   LogicalResult
   matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Only rewrite annotated DimOp with constant index.
-    auto enc = getSparseTensorEncoding(op.getSource().getType());
-    if (!enc)
-      return failure();
     Optional<int64_t> index = op.getConstantIndex();
     if (!index)
       return failure();
-    // Access into static dimension can query original type directly.
-    // Note that this is typically already done by DimOp's folding.
-    Location loc = op->getLoc();
-    auto shape = op.getSource().getType().cast<RankedTensorType>().getShape();
-    if (!ShapedType::isDynamic(shape[*index])) {
-      rewriter.replaceOp(op, constantIndex(rewriter, loc, shape[*index]));
-      return success();
-    }
-    // Any other query can consult the dimSizes array at field 0 using,
-    // accounting for the reordering applied to the sparse storage.
-    auto tuple = llvm::cast<UnrealizedConversionCastOp>(
-        adaptor.getSource().getDefiningOp());
-    rewriter.replaceOpWithNewOp<memref::LoadOp>(
-        op, tuple.getInputs().front(),
-        constantIndex(rewriter, loc, toStored(enc, *index)));
+    auto sz =
+        sizeFromTensorAtDim(rewriter, op.getLoc(),
+                            op.getSource().getType().cast<RankedTensorType>(),
+                            adaptor.getSource(), *index);
+    if (!sz)
+      return failure();
+
+    rewriter.replaceOp(op, *sz);
     return success();
   }
 };
@@ -496,11 +510,13 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
     unsigned innerDim = srcType.getRank() - 1;
     if (AffineMap p = enc.getDimOrdering())
       innerDim = p.getDimPosition(innerDim);
-    Value sz = rewriter.create<tensor::DimOp>(loc, op.getTensor(), innerDim);
+    auto sz = sizeFromTensorAtDim(rewriter, loc, srcType, adaptor.getTensor(),
+                                  innerDim);
+    assert(sz); // This for sure is a sparse tensor
     // Generate a memref for `sz` elements of type `t`.
     auto genAlloc = [&](Type t) {
       auto memTp = MemRefType::get({ShapedType::kDynamicSize}, t);
-      return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
+      return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{*sz});
     };
     // Allocate temporary buffers for values, filled-switch, and indices.
     // We do not use stack buffers for this, since the expanded size may
@@ -590,5 +606,5 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                SparseTensorAllocConverter, SparseTensorDeallocConverter,
                SparseToPointersConverter, SparseToIndicesConverter,
                SparseToValuesConverter, SparseTensorLoadConverter>(
-               typeConverter, patterns.getContext());
+      typeConverter, patterns.getContext());
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index f2967c33705bf..d6fe145e77610 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -564,7 +564,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
       encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
   SmallVector<Value, 4> sizes;
   SmallVector<Value, 8> params;
-  sizesFromSrc(rewriter, sizes, loc, op.getSrc());
+  sizesFromPtr(rewriter, sizes, loc, encSrc, srcTp, adaptor.getSrc());
   newParams(rewriter, params, loc, srcTp, noPerm, Action::kToIterator, sizes,
             adaptor.getSrc());
   Value iter = genNewCall(rewriter, loc, params);
@@ -1168,13 +1168,13 @@ class SparseTensorExpandConverter : public OpConversionPattern<ExpandOp> {
     // All initialization should be done on entry of the loop nest.
     rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
     // Determine the size for access expansion (always the innermost stored
-    // dimension size, translated back to original dimension). Note that we
-    // recursively rewrite the new DimOp on the **original** tensor.
+    // dimension size, translated back to original dimension).
     auto enc = getSparseTensorEncoding(srcType);
     unsigned innerDim = srcType.getRank() - 1;
     if (AffineMap p = enc.getDimOrdering())
       innerDim = p.getDimPosition(innerDim);
-    Value sz = rewriter.create<tensor::DimOp>(loc, op.getTensor(), innerDim);
+    auto sz = sizeFromPtrAtDim(rewriter, loc, enc, srcType, adaptor.getTensor(),
+                               innerDim);
     // Allocate temporary buffers for values, filled-switch, and indices.
     // We do not use stack buffers for this, since the expanded size may
     // be rather large (as it envelops a single expanded dense dimension).


        


More information about the Mlir-commits mailing list