[Mlir-commits] [mlir] 9d4df97 - [mlir][sparse] Canonicalizing arguments to genReshapeDstShape and foreachInSparseConstant

wren romano llvmlistbot at llvm.org
Tue Apr 11 19:12:09 PDT 2023


Author: wren romano
Date: 2023-04-11T19:11:59-07:00
New Revision: 9d4df97ff0ef459316cdb55f17eaea9fe9b1c1d5

URL: https://github.com/llvm/llvm-project/commit/9d4df97ff0ef459316cdb55f17eaea9fe9b1c1d5
DIFF: https://github.com/llvm/llvm-project/commit/9d4df97ff0ef459316cdb55f17eaea9fe9b1c1d5.diff

LOG: [mlir][sparse] Canonicalizing arguments to genReshapeDstShape and foreachInSparseConstant

These functions don't need a `PatternRewriter`; they only need an `OpBuilder`. Also, the builder should be the first argument, before the `Location`, to match the style used everywhere else in MLIR.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D148059

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 957d41b82d23b..cbf591372f9af 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -244,16 +244,16 @@ Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
 }
 
 void mlir::sparse_tensor::genReshapeDstShape(
-    Location loc, PatternRewriter &rewriter, SmallVectorImpl<Value> &dstShape,
+    OpBuilder &builder, Location loc, SmallVectorImpl<Value> &dstShape,
     ArrayRef<Value> srcShape, ArrayRef<StaticSize> staticDstShape,
     ArrayRef<ReassociationIndices> reassociation) {
   // Collapse shape.
   if (reassociation.size() < srcShape.size()) {
     unsigned start = 0;
     for (const auto &map : llvm::enumerate(reassociation)) {
-      auto dstDim = constantIndex(rewriter, loc, 1);
+      auto dstDim = constantIndex(builder, loc, 1);
       for (unsigned i = start; i < start + map.value().size(); i++) {
-        dstDim = rewriter.create<arith::MulIOp>(loc, dstDim, srcShape[i]);
+        dstDim = builder.create<arith::MulIOp>(loc, dstDim, srcShape[i]);
       }
       dstShape.push_back(dstDim);
       start = start + map.value().size();
@@ -285,13 +285,13 @@ void mlir::sparse_tensor::genReshapeDstShape(
           }
         }
         // Compute the dynamic dimension size.
-        Value productVal = constantIndex(rewriter, loc, product);
+        Value productVal = constantIndex(builder, loc, product);
         Value dynamicSize =
-            rewriter.create<arith::DivUIOp>(loc, srcDim, productVal);
+            builder.create<arith::DivUIOp>(loc, srcDim, productVal);
         dstShape.push_back(dynamicSize);
       } else {
         // The expanded dimension is statically known.
-        dstShape.push_back(constantIndex(rewriter, loc, staticDstShape[j]));
+        dstShape.push_back(constantIndex(builder, loc, staticDstShape[j]));
       }
     }
     start = start + map.size();
@@ -512,8 +512,8 @@ Operation *mlir::sparse_tensor::getTop(Operation *op) {
 }
 
 void sparse_tensor::foreachInSparseConstant(
-    Location loc, RewriterBase &rewriter, SparseElementsAttr attr,
-    AffineMap order, function_ref<void(ArrayRef<Value>, Value)> callback) {
+    OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
+    function_ref<void(ArrayRef<Value>, Value)> callback) {
   const Dimension dimRank = getSparseTensorType(attr).getDimRank();
   const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
   const auto values = attr.getValues().getValues<Attribute>();
@@ -560,17 +560,17 @@ void sparse_tensor::foreachInSparseConstant(
     cvs.clear();
     for (Dimension d = 0; d < dimRank; d++) {
       auto crd = elems[i].first[d].getInt();
-      cvs.push_back(rewriter.create<arith::ConstantIndexOp>(loc, crd));
+      cvs.push_back(builder.create<arith::ConstantIndexOp>(loc, crd));
     }
     // Remap value.
     Value val;
     if (attr.getElementType().isa<ComplexType>()) {
       auto valAttr = elems[i].second.cast<ArrayAttr>();
-      val = rewriter.create<complex::ConstantOp>(loc, attr.getElementType(),
-                                                 valAttr);
+      val = builder.create<complex::ConstantOp>(loc, attr.getElementType(),
+                                                valAttr);
     } else {
       auto valAttr = elems[i].second.cast<TypedAttr>();
-      val = rewriter.create<arith::ConstantOp>(loc, valAttr);
+      val = builder.create<arith::ConstantOp>(loc, valAttr);
     }
     assert(val);
     callback(cvs, val);

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 6d6351cce6dba..47c581c8d88ca 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -89,7 +89,7 @@ Value genIsNonzero(OpBuilder &builder, Location loc, Value v);
 /// Computes the shape of destination tensor of a reshape operator. This is only
 /// used when operands have dynamic shape. The shape of the destination is
 /// stored into dstShape.
-void genReshapeDstShape(Location loc, PatternRewriter &rewriter,
+void genReshapeDstShape(OpBuilder &builder, Location loc,
                         SmallVectorImpl<Value> &dstShape,
                         ArrayRef<Value> srcShape,
                         ArrayRef<StaticSize> staticDstShape,
@@ -211,8 +211,8 @@ Operation *getTop(Operation *op);
 /// %v3 = complex.constant (5.0, 6.0)
 /// callback({%c3}, %v3)
 void foreachInSparseConstant(
-    Location loc, RewriterBase &rewriter, SparseElementsAttr attr,
-    AffineMap order, function_ref<void(ArrayRef<Value>, Value)> callback);
+    OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
+    function_ref<void(ArrayRef<Value>, Value)> callback);
 
 /// Loads `size`-many values from the memref, which must have rank-1 and
 /// size greater-or-equal to `size`.  If the optional `(offsetIdx,offsetVal)`

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index c3c0a4f88c362..8d0c8548097f1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -489,7 +489,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
     // Static "shapes" are in fact "sizes".
     fillDimShape(rewriter, loc, dstTp, dstDimSizes);
   else
-    genReshapeDstShape(loc, rewriter, dstDimSizes, srcDimSizes,
+    genReshapeDstShape(rewriter, loc, dstDimSizes, srcDimSizes,
                        dstTp.getDimShape(), op.getReassociationIndices());
   const Value coo =
       params.genBuffers(dstTp, dstDimSizes).genNewCall(Action::kEmptyCOO);

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 52281bfa94ae9..19ed23108f2f5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -158,7 +158,7 @@ static LogicalResult genForeachOnSparseConstant(ForeachOp op,
 
   // Foreach on constant.
   foreachInSparseConstant(
-      loc, rewriter, attr, op.getOrder().value_or(AffineMap()),
+      rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
       [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
         SmallVector<Value> args;
         args.append(cvs.begin(), cvs.end());
@@ -372,7 +372,7 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
         dstSizes.push_back(constantIndex(rewriter, loc, d));
     } else {
       ArrayRef<DynSize> dstShape = dstTp.getDimShape();
-      genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
+      genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
                          op.getReassociationIndices());
       for (auto [idx, shape] : llvm::enumerate(dstShape)) {
         if (shape == ShapedType::kDynamic)


        


More information about the Mlir-commits mailing list