[Mlir-commits] [mlir] 1d6eb09 - [mlir] NFC - VectorTransforms use OpBuilder where relevant

Nicolas Vasilache llvmlistbot at llvm.org
Sun May 17 07:21:05 PDT 2020


Author: Nicolas Vasilache
Date: 2020-05-17T10:17:12-04:00
New Revision: 1d6eb09d2225310b1af54856c34fdcd45cd0f9ef

URL: https://github.com/llvm/llvm-project/commit/1d6eb09d2225310b1af54856c34fdcd45cd0f9ef
DIFF: https://github.com/llvm/llvm-project/commit/1d6eb09d2225310b1af54856c34fdcd45cd0f9ef.diff

LOG: [mlir] NFC - VectorTransforms use OpBuilder where relevant

Summary: This allows unrolling to be used outside of rewrite patterns: the unrolling helpers now take an OpBuilder instead of a PatternRewriter.

Differential Revision: https://reviews.llvm.org/D80083

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index a7325ce838cb..337ac75f7cbb 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -65,7 +65,7 @@ namespace vector {
 // This will be extended in the future to support more advanced use cases than
 // simple pointwise ops.
 SmallVector<Value, 1>
-unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op,
+unrollSingleResultOpMatchingType(OpBuilder &builder, Operation *op,
                                  ArrayRef<int64_t> targetShape);
 
 } // namespace vector

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 851b54beb452..af7e5ad86af8 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -68,8 +68,8 @@ static int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis) {
 
 // Clones `op` into a new operations that takes `operands` and returns
 // `resultTypes`.
-static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
-                                              Location loc, Operation *op,
+static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
+                                              Operation *op,
                                               ArrayRef<Value> operands,
                                               ArrayRef<Type> resultTypes) {
   OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
@@ -98,7 +98,7 @@ static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
 static TupleType generateExtractSlicesOpResultType(VectorType vectorType,
                                                    ArrayRef<int64_t> sizes,
                                                    ArrayRef<int64_t> strides,
-                                                   PatternRewriter &builder) {
+                                                   OpBuilder &builder) {
   assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
   assert(static_cast<int64_t>(sizes.size()) == vectorType.getRank());
   assert(static_cast<int64_t>(strides.size()) == vectorType.getRank());
@@ -140,7 +140,7 @@ static void initUnrolledVectorState(VectorType vectorType, Value initValue,
                                     const DenseMap<int64_t, int64_t> &indexMap,
                                     ArrayRef<int64_t> targetShape,
                                     UnrolledVectorState &state,
-                                    PatternRewriter &builder) {
+                                    OpBuilder &builder) {
   // Compute unrolled shape of 'vectorType'.
   state.unrolledShape.resize(vectorType.getRank());
   getMappedElements(indexMap, targetShape, state.unrolledShape);
@@ -183,7 +183,7 @@ getUnrolledVectorLinearIndex(UnrolledVectorState &state,
 static Value getOrCreateUnrolledVectorSlice(
     Location loc, UnrolledVectorState &state, ArrayRef<int64_t> vectorOffsets,
     ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
-    Value initValue, SmallVectorImpl<Value> &cache, PatternRewriter &builder) {
+    Value initValue, SmallVectorImpl<Value> &cache, OpBuilder &builder) {
   // Compute slice offsets.
   SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
   getMappedElements(indexMap, offsets, sliceOffsets);
@@ -275,7 +275,7 @@ static Value unrollSingleResultStructuredOp(Operation *op,
                                             std::vector<VectorState> &vectors,
                                             unsigned resultIndex,
                                             ArrayRef<int64_t> targetShape,
-                                            PatternRewriter &builder) {
+                                            OpBuilder &builder) {
   auto shapedType = op->getResult(0).getType().dyn_cast_or_null<ShapedType>();
   if (!shapedType || !shapedType.hasStaticShape())
     assert(false && "Expected a statically shaped result type");
@@ -426,7 +426,7 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
 
 // Entry point for unrolling declarative pattern rewrites.
 SmallVector<Value, 1> mlir::vector::unrollSingleResultOpMatchingType(
-    PatternRewriter &builder, Operation *op, ArrayRef<int64_t> targetShape) {
+    OpBuilder &builder, Operation *op, ArrayRef<int64_t> targetShape) {
   assert(op->getNumResults() == 1 && "Expected single result operation");
 
   // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'.
@@ -451,12 +451,10 @@ SmallVector<Value, 1> mlir::vector::unrollSingleResultOpMatchingType(
 
 /// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
 /// calls 'fn' with linear index and indices for each slice.
-static void
-generateTransferOpSlices(Type memrefElementType, VectorType vectorType,
-                         TupleType tupleType, ArrayRef<int64_t> sizes,
-                         ArrayRef<int64_t> strides, ArrayRef<Value> indices,
-                         PatternRewriter &rewriter,
-                         function_ref<void(unsigned, ArrayRef<Value>)> fn) {
+static void generateTransferOpSlices(
+    Type memrefElementType, VectorType vectorType, TupleType tupleType,
+    ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, ArrayRef<Value> indices,
+    OpBuilder &builder, function_ref<void(unsigned, ArrayRef<Value>)> fn) {
   // Compute strides w.r.t. to slice counts in each dimension.
   auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
   assert(maybeDimSliceCounts.hasValue());
@@ -484,7 +482,7 @@ generateTransferOpSlices(Type memrefElementType, VectorType vectorType,
   }
   unsigned indexOffset = numSliceIndices - vectorRank;
 
-  auto *ctx = rewriter.getContext();
+  auto *ctx = builder.getContext();
   for (unsigned i = 0; i < numSlices; ++i) {
     auto vectorOffsets = delinearize(sliceStrides, i);
     auto elementOffsets =
@@ -498,7 +496,7 @@ generateTransferOpSlices(Type memrefElementType, VectorType vectorType,
         auto expr = getAffineDimExpr(0, ctx) +
                     getAffineConstantExpr(elementOffsets[j - indexOffset], ctx);
         auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
-        sliceIndices[j] = rewriter.create<AffineApplyOp>(
+        sliceIndices[j] = builder.create<AffineApplyOp>(
             indices[j].getLoc(), map, ArrayRef<Value>(indices[j]));
       }
     }
@@ -1683,8 +1681,13 @@ class ShapeCastOp2DUpCastRewritePattern
 // TODO(andydavis) Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<ShapeCastOpDecomposer, ShapeCastOpFolder, SplitTransferReadOp,
-                  SplitTransferWriteOp, TupleGetFolderOp>(context);
+  // clang-format off
+  patterns.insert<ShapeCastOpDecomposer,
+                  ShapeCastOpFolder,
+                  SplitTransferReadOp,
+                  SplitTransferWriteOp,
+                  TupleGetFolderOp>(context);
+  // clang-format on
 }
 
 void mlir::vector::populateVectorSlicesLoweringPatterns(
@@ -1695,9 +1698,14 @@ void mlir::vector::populateVectorSlicesLoweringPatterns(
 void mlir::vector::populateVectorContractLoweringPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context,
     VectorTransformsOptions parameters) {
-  patterns.insert<ShapeCastOp2DDownCastRewritePattern,
-                  ShapeCastOp2DUpCastRewritePattern, BroadcastOpLowering,
-                  TransposeOpLowering, OuterProductOpLowering,
-                  ConstantMaskOpLowering, CreateMaskOpLowering>(context);
+  // clang-format off
+  patterns.insert<BroadcastOpLowering,
+                  CreateMaskOpLowering,
+                  ConstantMaskOpLowering,
+                  OuterProductOpLowering,
+                  ShapeCastOp2DDownCastRewritePattern,
+                  ShapeCastOp2DUpCastRewritePattern,
+                  TransposeOpLowering>(context);
+  // clang-format on
   patterns.insert<ContractionOpLowering>(parameters, context);
 }


        


More information about the Mlir-commits mailing list