[Mlir-commits] [mlir] 1d6eb09 - [mlir] NFC - VectorTransforms use OpBuilder where relevant
Nicolas Vasilache
llvmlistbot at llvm.org
Sun May 17 07:21:05 PDT 2020
Author: Nicolas Vasilache
Date: 2020-05-17T10:17:12-04:00
New Revision: 1d6eb09d2225310b1af54856c34fdcd45cd0f9ef
URL: https://github.com/llvm/llvm-project/commit/1d6eb09d2225310b1af54856c34fdcd45cd0f9ef
DIFF: https://github.com/llvm/llvm-project/commit/1d6eb09d2225310b1af54856c34fdcd45cd0f9ef.diff
LOG: [mlir] NFC - VectorTransforms use OpBuilder where relevant
Summary: This will allow unrolling to be used outside of rewrite patterns only.
Differential Revision: https://reviews.llvm.org/D80083
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/lib/Dialect/Vector/VectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index a7325ce838cb..337ac75f7cbb 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -65,7 +65,7 @@ namespace vector {
// This will be extended in the future to support more advanced use cases than
// simple pointwise ops.
SmallVector<Value, 1>
-unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op,
+unrollSingleResultOpMatchingType(OpBuilder &builder, Operation *op,
ArrayRef<int64_t> targetShape);
} // namespace vector
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 851b54beb452..af7e5ad86af8 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -68,8 +68,8 @@ static int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis) {
// Clones `op` into a new operations that takes `operands` and returns
// `resultTypes`.
-static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
- Location loc, Operation *op,
+static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
+ Operation *op,
ArrayRef<Value> operands,
ArrayRef<Type> resultTypes) {
OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
@@ -98,7 +98,7 @@ static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
static TupleType generateExtractSlicesOpResultType(VectorType vectorType,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides,
- PatternRewriter &builder) {
+ OpBuilder &builder) {
assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
assert(static_cast<int64_t>(sizes.size()) == vectorType.getRank());
assert(static_cast<int64_t>(strides.size()) == vectorType.getRank());
@@ -140,7 +140,7 @@ static void initUnrolledVectorState(VectorType vectorType, Value initValue,
const DenseMap<int64_t, int64_t> &indexMap,
ArrayRef<int64_t> targetShape,
UnrolledVectorState &state,
- PatternRewriter &builder) {
+ OpBuilder &builder) {
// Compute unrolled shape of 'vectorType'.
state.unrolledShape.resize(vectorType.getRank());
getMappedElements(indexMap, targetShape, state.unrolledShape);
@@ -183,7 +183,7 @@ getUnrolledVectorLinearIndex(UnrolledVectorState &state,
static Value getOrCreateUnrolledVectorSlice(
Location loc, UnrolledVectorState &state, ArrayRef<int64_t> vectorOffsets,
ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
- Value initValue, SmallVectorImpl<Value> &cache, PatternRewriter &builder) {
+ Value initValue, SmallVectorImpl<Value> &cache, OpBuilder &builder) {
// Compute slice offsets.
SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
getMappedElements(indexMap, offsets, sliceOffsets);
@@ -275,7 +275,7 @@ static Value unrollSingleResultStructuredOp(Operation *op,
std::vector<VectorState> &vectors,
unsigned resultIndex,
ArrayRef<int64_t> targetShape,
- PatternRewriter &builder) {
+ OpBuilder &builder) {
auto shapedType = op->getResult(0).getType().dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape())
assert(false && "Expected a statically shaped result type");
@@ -426,7 +426,7 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
// Entry point for unrolling declarative pattern rewrites.
SmallVector<Value, 1> mlir::vector::unrollSingleResultOpMatchingType(
- PatternRewriter &builder, Operation *op, ArrayRef<int64_t> targetShape) {
+ OpBuilder &builder, Operation *op, ArrayRef<int64_t> targetShape) {
assert(op->getNumResults() == 1 && "Expected single result operation");
// Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
@@ -451,12 +451,10 @@ SmallVector<Value, 1> mlir::vector::unrollSingleResultOpMatchingType(
/// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
/// calls 'fn' with linear index and indices for each slice.
-static void
-generateTransferOpSlices(Type memrefElementType, VectorType vectorType,
- TupleType tupleType, ArrayRef<int64_t> sizes,
- ArrayRef<int64_t> strides, ArrayRef<Value> indices,
- PatternRewriter &rewriter,
- function_ref<void(unsigned, ArrayRef<Value>)> fn) {
+static void generateTransferOpSlices(
+ Type memrefElementType, VectorType vectorType, TupleType tupleType,
+ ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<Value> indices,
+ OpBuilder &builder, function_ref<void(unsigned, ArrayRef<Value>)> fn) {
// Compute strides w.r.t. to slice counts in each dimension.
auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
assert(maybeDimSliceCounts.hasValue());
@@ -484,7 +482,7 @@ generateTransferOpSlices(Type memrefElementType, VectorType vectorType,
}
unsigned indexOffset = numSliceIndices - vectorRank;
- auto *ctx = rewriter.getContext();
+ auto *ctx = builder.getContext();
for (unsigned i = 0; i < numSlices; ++i) {
auto vectorOffsets = delinearize(sliceStrides, i);
auto elementOffsets =
@@ -498,7 +496,7 @@ generateTransferOpSlices(Type memrefElementType, VectorType vectorType,
auto expr = getAffineDimExpr(0, ctx) +
getAffineConstantExpr(elementOffsets[j - indexOffset], ctx);
auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
- sliceIndices[j] = rewriter.create<AffineApplyOp>(
+ sliceIndices[j] = builder.create<AffineApplyOp>(
indices[j].getLoc(), map, ArrayRef<Value>(indices[j]));
}
}
@@ -1683,8 +1681,13 @@ class ShapeCastOp2DUpCastRewritePattern
// TODO(andydavis) Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<ShapeCastOpDecomposer, ShapeCastOpFolder, SplitTransferReadOp,
- SplitTransferWriteOp, TupleGetFolderOp>(context);
+ // clang-format off
+ patterns.insert<ShapeCastOpDecomposer,
+ ShapeCastOpFolder,
+ SplitTransferReadOp,
+ SplitTransferWriteOp,
+ TupleGetFolderOp>(context);
+ // clang-format on
}
void mlir::vector::populateVectorSlicesLoweringPatterns(
@@ -1695,9 +1698,14 @@ void mlir::vector::populateVectorSlicesLoweringPatterns(
void mlir::vector::populateVectorContractLoweringPatterns(
OwningRewritePatternList &patterns, MLIRContext *context,
VectorTransformsOptions parameters) {
- patterns.insert<ShapeCastOp2DDownCastRewritePattern,
- ShapeCastOp2DUpCastRewritePattern, BroadcastOpLowering,
- TransposeOpLowering, OuterProductOpLowering,
- ConstantMaskOpLowering, CreateMaskOpLowering>(context);
+ // clang-format off
+ patterns.insert<BroadcastOpLowering,
+ CreateMaskOpLowering,
+ ConstantMaskOpLowering,
+ OuterProductOpLowering,
+ ShapeCastOp2DDownCastRewritePattern,
+ ShapeCastOp2DUpCastRewritePattern,
+ TransposeOpLowering>(context);
+ // clang-format on
patterns.insert<ContractionOpLowering>(parameters, context);
}
More information about the Mlir-commits
mailing list