[Mlir-commits] [mlir] 9d4df97 - [mlir][sparse] Canonicalizing arguments to genReshapeDstShape and foreachInSparseConstant
wren romano
llvmlistbot at llvm.org
Tue Apr 11 19:12:09 PDT 2023
Author: wren romano
Date: 2023-04-11T19:11:59-07:00
New Revision: 9d4df97ff0ef459316cdb55f17eaea9fe9b1c1d5
URL: https://github.com/llvm/llvm-project/commit/9d4df97ff0ef459316cdb55f17eaea9fe9b1c1d5
DIFF: https://github.com/llvm/llvm-project/commit/9d4df97ff0ef459316cdb55f17eaea9fe9b1c1d5.diff
LOG: [mlir][sparse] Canonicalizing arguments to genReshapeDstShape and foreachInSparseConstant
These functions don't need a `PatternRewriter`; they only need an `OpBuilder`. In addition, the builder should be the first argument, before the `Location`, to match the style used everywhere else in MLIR.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D148059
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 957d41b82d23b..cbf591372f9af 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -244,16 +244,16 @@ Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
}
void mlir::sparse_tensor::genReshapeDstShape(
- Location loc, PatternRewriter &rewriter, SmallVectorImpl<Value> &dstShape,
+ OpBuilder &builder, Location loc, SmallVectorImpl<Value> &dstShape,
ArrayRef<Value> srcShape, ArrayRef<StaticSize> staticDstShape,
ArrayRef<ReassociationIndices> reassociation) {
// Collapse shape.
if (reassociation.size() < srcShape.size()) {
unsigned start = 0;
for (const auto &map : llvm::enumerate(reassociation)) {
- auto dstDim = constantIndex(rewriter, loc, 1);
+ auto dstDim = constantIndex(builder, loc, 1);
for (unsigned i = start; i < start + map.value().size(); i++) {
- dstDim = rewriter.create<arith::MulIOp>(loc, dstDim, srcShape[i]);
+ dstDim = builder.create<arith::MulIOp>(loc, dstDim, srcShape[i]);
}
dstShape.push_back(dstDim);
start = start + map.value().size();
@@ -285,13 +285,13 @@ void mlir::sparse_tensor::genReshapeDstShape(
}
}
// Compute the dynamic dimension size.
- Value productVal = constantIndex(rewriter, loc, product);
+ Value productVal = constantIndex(builder, loc, product);
Value dynamicSize =
- rewriter.create<arith::DivUIOp>(loc, srcDim, productVal);
+ builder.create<arith::DivUIOp>(loc, srcDim, productVal);
dstShape.push_back(dynamicSize);
} else {
// The expanded dimension is statically known.
- dstShape.push_back(constantIndex(rewriter, loc, staticDstShape[j]));
+ dstShape.push_back(constantIndex(builder, loc, staticDstShape[j]));
}
}
start = start + map.size();
@@ -512,8 +512,8 @@ Operation *mlir::sparse_tensor::getTop(Operation *op) {
}
void sparse_tensor::foreachInSparseConstant(
- Location loc, RewriterBase &rewriter, SparseElementsAttr attr,
- AffineMap order, function_ref<void(ArrayRef<Value>, Value)> callback) {
+ OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
+ function_ref<void(ArrayRef<Value>, Value)> callback) {
const Dimension dimRank = getSparseTensorType(attr).getDimRank();
const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
const auto values = attr.getValues().getValues<Attribute>();
@@ -560,17 +560,17 @@ void sparse_tensor::foreachInSparseConstant(
cvs.clear();
for (Dimension d = 0; d < dimRank; d++) {
auto crd = elems[i].first[d].getInt();
- cvs.push_back(rewriter.create<arith::ConstantIndexOp>(loc, crd));
+ cvs.push_back(builder.create<arith::ConstantIndexOp>(loc, crd));
}
// Remap value.
Value val;
if (attr.getElementType().isa<ComplexType>()) {
auto valAttr = elems[i].second.cast<ArrayAttr>();
- val = rewriter.create<complex::ConstantOp>(loc, attr.getElementType(),
- valAttr);
+ val = builder.create<complex::ConstantOp>(loc, attr.getElementType(),
+ valAttr);
} else {
auto valAttr = elems[i].second.cast<TypedAttr>();
- val = rewriter.create<arith::ConstantOp>(loc, valAttr);
+ val = builder.create<arith::ConstantOp>(loc, valAttr);
}
assert(val);
callback(cvs, val);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 6d6351cce6dba..47c581c8d88ca 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -89,7 +89,7 @@ Value genIsNonzero(OpBuilder &builder, Location loc, Value v);
/// Computes the shape of destination tensor of a reshape operator. This is only
/// used when operands have dynamic shape. The shape of the destination is
/// stored into dstShape.
-void genReshapeDstShape(Location loc, PatternRewriter &rewriter,
+void genReshapeDstShape(OpBuilder &builder, Location loc,
SmallVectorImpl<Value> &dstShape,
ArrayRef<Value> srcShape,
ArrayRef<StaticSize> staticDstShape,
@@ -211,8 +211,8 @@ Operation *getTop(Operation *op);
/// %v3 = complex.constant (5.0, 6.0)
/// callback({%c3}, %v3)
void foreachInSparseConstant(
- Location loc, RewriterBase &rewriter, SparseElementsAttr attr,
- AffineMap order, function_ref<void(ArrayRef<Value>, Value)> callback);
+ OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
+ function_ref<void(ArrayRef<Value>, Value)> callback);
/// Loads `size`-many values from the memref, which must have rank-1 and
/// size greater-or-equal to `size`. If the optional `(offsetIdx,offsetVal)`
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index c3c0a4f88c362..8d0c8548097f1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -489,7 +489,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
// Static "shapes" are in fact "sizes".
fillDimShape(rewriter, loc, dstTp, dstDimSizes);
else
- genReshapeDstShape(loc, rewriter, dstDimSizes, srcDimSizes,
+ genReshapeDstShape(rewriter, loc, dstDimSizes, srcDimSizes,
dstTp.getDimShape(), op.getReassociationIndices());
const Value coo =
params.genBuffers(dstTp, dstDimSizes).genNewCall(Action::kEmptyCOO);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 52281bfa94ae9..19ed23108f2f5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -158,7 +158,7 @@ static LogicalResult genForeachOnSparseConstant(ForeachOp op,
// Foreach on constant.
foreachInSparseConstant(
- loc, rewriter, attr, op.getOrder().value_or(AffineMap()),
+ rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
[&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
SmallVector<Value> args;
args.append(cvs.begin(), cvs.end());
@@ -372,7 +372,7 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
dstSizes.push_back(constantIndex(rewriter, loc, d));
} else {
ArrayRef<DynSize> dstShape = dstTp.getDimShape();
- genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
+ genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
op.getReassociationIndices());
for (auto [idx, shape] : llvm::enumerate(dstShape)) {
if (shape == ShapedType::kDynamic)
More information about the Mlir-commits mailing list