[Mlir-commits] [mlir] 495acf9 - [mlir][Linalg] NFC - Purge OpBuilder uses in favor of RewriterBase in places unrelated to op definitions

Nicolas Vasilache llvmlistbot at llvm.org
Fri Dec 2 08:13:01 PST 2022


Author: Nicolas Vasilache
Date: 2022-12-02T08:06:29-08:00
New Revision: 495acf98da4b37bbcf58d1f14175a01d85d2b8f5

URL: https://github.com/llvm/llvm-project/commit/495acf98da4b37bbcf58d1f14175a01d85d2b8f5
DIFF: https://github.com/llvm/llvm-project/commit/495acf98da4b37bbcf58d1f14175a01d85d2b8f5.diff

LOG: [mlir][Linalg] NFC - Purge OpBuilder uses in favor of RewriterBase in places unrelated to op definitions

RewriterBase is the proper builder to use so one can listen to IR modifications (i.e. not just creation).

Differential Revision: https://reviews.llvm.org/D137922

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index 1780085b920c8..1297e87714f79 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -30,6 +30,7 @@ namespace mlir {
 class OpBuilder;
 class TypeRange;
 class ValueRange;
+class RewriterBase;
 
 /// Tests whether the given maps describe a row major matmul. The test is
 /// permutation-invariant. Note that this only checks the affine maps from an
@@ -81,8 +82,8 @@ class StructuredGenerator {
     Red() : IteratorType(IteratorTypeT::reduction) {}
   };
 
-  StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
-      : builder(builder), ctx(op.getContext()), loc(op.getLoc()),
+  StructuredGenerator(RewriterBase &rewriter, StructuredOpInterface op)
+      : rewriter(rewriter), ctx(op.getContext()), loc(op.getLoc()),
         iterators(op.getIteratorTypesArray()), maps(op.getIndexingMapsArray()),
         op(op) {}
 
@@ -102,7 +103,7 @@ class StructuredGenerator {
   }
 
 protected:
-  OpBuilder &builder;
+  RewriterBase &rewriter;
   MLIRContext *ctx;
   Location loc;
   SmallVector<IteratorTypeT> iterators;
@@ -112,10 +113,12 @@ class StructuredGenerator {
 
 // Clone the current operation with the operands. This is used to abstract away
 // the optional underlying region creation.
+// Note: this is a true builder that notifies the OpBuilder listener.
 Operation *clone(OpBuilder &b, Operation *op, TypeRange newResultTypes,
                  ValueRange newOperands);
 
 // Clone the current operation with the operands but leave the regions empty.
+// Note: this is a true builder that notifies the OpBuilder listener.
 Operation *cloneWithoutRegions(OpBuilder &b, Operation *op,
                                TypeRange newResultTypes,
                                ValueRange newOperands);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 82b191cdb4b7f..090c4d4a081ea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -46,7 +46,7 @@ using namespace mlir::linalg;
 #define LDBG(X) LLVM_DEBUG(DBGS() << X)
 
 /// Try to vectorize `convOp` as a convolution.
-static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b,
+static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
                                                    LinalgOp convOp);
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
@@ -174,14 +174,18 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
       vector::BroadcastableToResult::Success)
     return value;
   Location loc = b.getInsertionPoint()->getLoc();
-  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
+  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType,
+                                                    value);
 }
 
 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
 /// assumes that `reductionOp` has two operands and one of them is the reduction
-/// initial value.
-static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
-                                      Value valueToReduce, Value acc,
+/// initial value.
+// Note: this is a true builder that notifies the OpBuilder listener.
+// TODO: Consider moving as a static helper on the ReduceOp.
+static Operation *buildMultiDimReduce(OpBuilder &b,
+                                      Operation *reduceOp, Value valueToReduce,
+                                      Value acc,
                                       const SmallVector<bool> &reductionMask) {
   auto maybeKind = getCombinerOpKind(reduceOp);
   assert(maybeKind && "Failed precondition: could not get reduction kind");
@@ -198,6 +202,8 @@ static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
 /// to all `0`; where `outputOperand` is an output operand of the LinalgOp
 /// currently being vectorized. If `dest` has null rank, build an memref.store.
 /// Return the produced value or null if no value is produced.
+// Note: this is a true builder that notifies the OpBuilder listener.
+// TODO: Consider moving as a static helper on the TransferWriteOp.
 static Value buildVectorWrite(OpBuilder &b, Value value,
                               OpOperand *outputOperand) {
   Operation *write;
@@ -217,14 +223,14 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                                b.create<arith::ConstantIndexOp>(loc, 0));
     value = broadcastIfNeeded(b, value, vectorType.getShape());
-    write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
-                                              indices, map);
+    write = b.create<vector::TransferWriteOp>(
+        loc, value, outputOperand->get(), indices, map);
   } else {
     if (!value.getType().isa<VectorType>())
       value = b.create<vector::BroadcastOp>(loc, vectorType, value);
     assert(value.getType() == vectorType && "incorrect type");
-    write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
-                                              ValueRange{});
+    write = b.create<vector::TransferWriteOp>(
+        loc, value, outputOperand->get(), ValueRange{});
   }
   LDBG("vectorized op: " << *write);
   if (!write->getResults().empty())
@@ -233,7 +239,7 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
 }
 
 // Custom vectorization precondition function type. This is intented to be used
-// with CustomVectorizationHook. Returns success if the correpsonding custom
+// with CustomVectorizationHook. Returns success if the corresponding custom
 // hook can vectorize the op.
 using CustomVectorizationPrecondition =
     std::function<LogicalResult(Operation *)>;
@@ -248,11 +254,11 @@ using CustomVectorizationHook = std::function<VectorizationResult(
 /// vector values are appended to `newResults`. Return
 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
 /// should not try to map produced operations and instead return the results
-/// using the `newResults` vector making them available to the
-/// vectorization algorithm for RAUW. This function is meant to be used as a
+/// using the `newResults` vector making them available to the vectorization
+/// algorithm for RAUW. This function is meant to be used as a
 /// CustomVectorizationHook.
 static VectorizationResult
-vectorizeLinalgYield(OpBuilder &b, Operation *op,
+vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                      const BlockAndValueMapping &bvm, LinalgOp linalgOp,
                      SmallVectorImpl<Value> &newResults) {
   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
@@ -263,7 +269,7 @@ vectorizeLinalgYield(OpBuilder &b, Operation *op,
     // TODO: use a map.
     Value vectorValue = bvm.lookup(outputs.value());
     Value newResult = buildVectorWrite(
-        b, vectorValue, linalgOp.getDpsInitOperand(outputs.index()));
+        rewriter, vectorValue, linalgOp.getDpsInitOperand(outputs.index()));
     if (newResult)
       newResults.push_back(newResult);
   }
@@ -274,8 +280,8 @@ vectorizeLinalgYield(OpBuilder &b, Operation *op,
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
-                                                LinalgOp linalgOp) {
+static VectorizationResult
+vectorizeLinalgIndex(RewriterBase &rewriter, Operation *op, LinalgOp linalgOp) {
   IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
   if (!indexOp)
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
@@ -285,8 +291,8 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
   // Compute a one-dimensional index vector for the index op dimension.
   SmallVector<int64_t> constantSeq =
       llvm::to_vector<16>(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
-  auto constantOp =
-      b.create<arith::ConstantOp>(loc, b.getIndexVectorAttr(constantSeq));
+  auto constantOp = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIndexVectorAttr(constantSeq));
   // Return the one-dimensional index vector if it lives in the trailing
   // dimension of the iteration space since the vectorization algorithm in this
   // case can handle the broadcast.
@@ -296,13 +302,13 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
   // broadcast the one-dimensional index vector to the permuted shape, and
   // finally transpose the broadcasted index vector to undo the permutation.
   std::swap(targetShape[indexOp.getDim()], targetShape.back());
-  auto broadCastOp = b.create<vector::BroadcastOp>(
-      loc, VectorType::get(targetShape, b.getIndexType()), constantOp);
+  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
+      loc, VectorType::get(targetShape, rewriter.getIndexType()), constantOp);
   SmallVector<int64_t> transposition =
       llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
   std::swap(transposition.back(), transposition[indexOp.getDim()]);
   auto transposeOp =
-      b.create<vector::TransposeOp>(loc, broadCastOp, transposition);
+      rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
   return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
 }
 
@@ -334,7 +340,7 @@ static LogicalResult tensorExtractVectorizationPrecondition(Operation *op) {
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
 static VectorizationResult
-vectorizeTensorExtract(OpBuilder &b, Operation *op, LinalgOp linalgOp,
+vectorizeTensorExtract(RewriterBase &rewriter, Operation *op, LinalgOp linalgOp,
                        const BlockAndValueMapping &bvm) {
   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
   if (!extractOp)
@@ -350,19 +356,19 @@ vectorizeTensorExtract(OpBuilder &b, Operation *op, LinalgOp linalgOp,
   auto targetShape = linalgOp.computeStaticLoopSizes();
 
   SmallVector<Value> gatherIndices;
-  gatherIndices.push_back(b.create<arith::ConstantIndexOp>(loc, 0));
+  gatherIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
-  auto maskConstantOp = b.create<arith::ConstantOp>(
-      loc,
-      DenseIntElementsAttr::get(VectorType::get(targetShape, b.getI1Type()),
-                                /*value=*/true));
+  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
+      loc, DenseIntElementsAttr::get(
+               VectorType::get(targetShape, rewriter.getI1Type()),
+               /*value=*/true));
 
   auto resultType =
       VectorType::get(targetShape, extractOp.getResult().getType());
   auto passThruConstantOp =
-      b.create<arith::ConstantOp>(loc, b.getZeroAttr(resultType));
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
 
-  auto gatherOp = b.create<vector::GatherOp>(
+  auto gatherOp = rewriter.create<vector::GatherOp>(
       loc, resultType, extractOp.getTensor(), gatherIndices, indexVec,
       maskConstantOp, passThruConstantOp);
 
@@ -371,8 +377,11 @@ vectorizeTensorExtract(OpBuilder &b, Operation *op, LinalgOp linalgOp,
 
 /// Emit reduction operations if the shapes of the value to reduce is different
 /// that the result shape.
-static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
-                                 Value reduceValue, Value initialValue,
+// Note: this is a true builder that notifies the OpBuilder listener.
+// TODO: Consider moving as a static helper on the ReduceOp.
+static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp,
+                                 Operation *op, Value reduceValue,
+                                 Value initialValue,
                                  const BlockAndValueMapping &bvm) {
   Value reduceVec = bvm.lookup(reduceValue);
   Value outputVec = bvm.lookup(initialValue);
@@ -402,12 +411,12 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
 ///   otherwise, it means one of the `customVectorizationHooks` is incorrect.
 ///
 /// This function assumes all operands of `op` have been vectorized and are in
-/// the `bvm` mapping. As a consequence, this function is meant to be called on
+/// the `bvm` mapping. As a consequence, this function is meant to be called on
 /// a topologically-sorted list of ops.
 /// This function does not update `bvm` but returns a VectorizationStatus that
 /// instructs the caller what `bvm` update needs to occur.
 static VectorizationResult
-vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
+vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op,
                const BlockAndValueMapping &bvm,
                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
   LDBG("vectorize op " << *op);
@@ -425,7 +434,7 @@ vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
   // 2. Constant ops don't get vectorized but rather broadcasted at their users.
   // Clone so that the constant is not confined to the linalgOp block .
   if (isa<arith::ConstantOp, func::ConstantOp>(op))
-    return VectorizationResult{VectorizationStatus::NewOp, b.clone(*op)};
+    return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
 
   // 3. Only ElementwiseMappable are allowed in the generic vectorization.
   if (!OpTrait::hasElementwiseMappableTraits(op))
@@ -448,7 +457,7 @@ vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
   if (!reductionOperands.empty()) {
     assert(reductionOperands.size() == 1);
     Operation *reduceOp =
-        reduceIfNeeded(b, linalgOp, op, reductionOperands[0].first,
+        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                        reductionOperands[0].second, bvm);
     if (reduceOp)
       return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
@@ -462,11 +471,12 @@ vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
     if (vt && firstMaxRankedShape.size() < vt.getShape().size())
       firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
   }
-  //   b. broadcast each op if needed.
+  //   b. broadcast each op if needed.
   auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
     return firstMaxRankedShape.empty()
                ? bvm.lookup(v)
-               : broadcastIfNeeded(b, bvm.lookup(v), firstMaxRankedShape);
+               : broadcastIfNeeded(rewriter, bvm.lookup(v),
+                                   firstMaxRankedShape);
   });
   //   c. for elementwise, the result is the vector with the firstMaxRankedShape
   auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
@@ -478,9 +488,9 @@ vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
   // Build and return the new op.
   return VectorizationResult{
       VectorizationStatus::NewOp,
-      b.create(op->getLoc(), op->getName().getIdentifier(),
-               llvm::to_vector<4>(vectorizedOperands),
-               llvm::to_vector<4>(returnTypes), op->getAttrs())};
+      rewriter.create(op->getLoc(), op->getName().getIdentifier(),
+                      llvm::to_vector<4>(vectorizedOperands),
+                      llvm::to_vector<4>(returnTypes), op->getAttrs())};
 }
 
 /// Generic vectorization function that rewrites the body of a `linalgOp` into
@@ -492,8 +502,8 @@ vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
 ///   load).
 ///   TODO: Reuse opportunities for RAR dependencies.
 ///   4a. Register CustomVectorizationHook for YieldOp to capture the results.
-///   4b. Register CustomVectorizationHook for IndexOp to access the iteration
-///   indices.
+///   4b. Register CustomVectorizationHook for IndexOp to access the
+///   iteration indices.
 ///   5. Iteratively call vectorizeOneOp on the region operations.
 ///
 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
@@ -506,7 +516,7 @@ vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
 /// This is not deemed a problem as we expect canonicalizations and foldings to
 /// aggressively clean up the useless work.
 static LogicalResult
-vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
+vectorizeAsLinalgGeneric(RewriterBase &rewriter, LinalgOp linalgOp,
                          SmallVectorImpl<Value> &newResults) {
   Block *block = linalgOp.getBlock();
 
@@ -527,7 +537,7 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
 
   // 3. Turn all BBArgs into vector.transfer_read / load.
   Location loc = linalgOp.getLoc();
-  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
     BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
     if (linalgOp.isScalar(opOperand)) {
@@ -555,12 +565,12 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
 
     auto shape = linalgOp.getShape(opOperand);
     SmallVector<Value> indices(shape.size(), zero);
-    Value readValue = b.create<vector::TransferReadOp>(
+    Value readValue = rewriter.create<vector::TransferReadOp>(
         loc, readType, opOperand->get(), indices, map);
     // Not all ops support 0-d vectors, extract the scalar for now.
     // TODO: remove this.
     if (readValue.getType().cast<VectorType>().getRank() == 0)
-      readValue = b.create<vector::ExtractElementOp>(loc, readValue);
+      readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
 
     LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue);
     bvm.map(bbarg, readValue);
@@ -572,15 +582,15 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
   CustomVectorizationHook vectorizeYield =
       [&](Operation *op,
           const BlockAndValueMapping &bvm) -> VectorizationResult {
-    return vectorizeLinalgYield(b, op, bvm, linalgOp, newResults);
+    return vectorizeLinalgYield(rewriter, op, bvm, linalgOp, newResults);
   };
   hooks.push_back(vectorizeYield);
 
-  // 4b. Register CustomVectorizationHook for indexOp.
+  // 4b. Register CustomVectorizationHook for indexOp.
   CustomVectorizationHook vectorizeIndex =
       [&](Operation *op,
           const BlockAndValueMapping &bvm) -> VectorizationResult {
-    return vectorizeLinalgIndex(b, op, linalgOp);
+    return vectorizeLinalgIndex(rewriter, op, linalgOp);
   };
   hooks.push_back(vectorizeIndex);
 
@@ -588,13 +598,14 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
   CustomVectorizationHook vectorizeExtract =
       [&](Operation *op,
           const BlockAndValueMapping &bvm) -> VectorizationResult {
-    return vectorizeTensorExtract(b, op, linalgOp, bvm);
+    return vectorizeTensorExtract(rewriter, op, linalgOp, bvm);
   };
   hooks.push_back(vectorizeExtract);
 
   // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
   for (Operation &op : block->getOperations()) {
-    VectorizationResult result = vectorizeOneOp(b, linalgOp, &op, bvm, hooks);
+    VectorizationResult result =
+        vectorizeOneOp(rewriter, linalgOp, &op, bvm, hooks);
     if (result.status == VectorizationStatus::Failure) {
       LDBG("failed to vectorize: " << op);
       return failure();
@@ -760,14 +771,14 @@ static int64_t getIntFromAttr(Attribute attr) {
 /// Given an ArrayRef of OpFoldResults, return a vector of Values.
 /// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
 /// not supported.
-static SmallVector<Value> ofrToIndexValues(OpBuilder &builder, Location loc,
+static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
                                            ArrayRef<OpFoldResult> ofrs) {
   SmallVector<Value> result;
   for (auto o : ofrs) {
     if (auto val = o.template dyn_cast<Value>()) {
       result.push_back(val);
     } else {
-      result.push_back(builder.create<arith::ConstantIndexOp>(
+      result.push_back(rewriter.create<arith::ConstantIndexOp>(
           loc, getIntFromAttr(o.template get<Attribute>())));
     }
   }
@@ -1415,9 +1426,9 @@ namespace {
 /// kw is unrolled, w is unrolled iff dilationW > 1.
 struct Conv1DGenerator
     : public StructuredGenerator<LinalgOp, utils::IteratorType> {
-  Conv1DGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
+  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
                   int dilationW)
-      : StructuredGenerator<LinalgOp, utils::IteratorType>(builder, linalgOp),
+      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
         strideW(strideW), dilationW(dilationW) {
     // Determine whether `linalgOp` can be generated with this generator
     if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
@@ -1481,8 +1492,7 @@ struct Conv1DGenerator
   /// > 1.
   FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
     if (!valid)
-      return IRRewriter(builder).notifyMatchFailure(op,
-                                                    "unvectorizable 1-D conv");
+      return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv");
 
     int64_t nSize, wSize, cSize, kwSize, fSize;
     SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
@@ -1519,7 +1529,7 @@ struct Conv1DGenerator
     }
 
     vector::TransferWriteOp write;
-    Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
 
     // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
     // When strideW == 1, we can batch the contiguous loads and avoid
@@ -1534,13 +1544,13 @@ struct Conv1DGenerator
     auto resType = VectorType::get(resShape, resEltType);
     // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0,
     // 0].
-    Value lhs = builder.create<vector::TransferReadOp>(
+    Value lhs = rewriter.create<vector::TransferReadOp>(
         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
     // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
-    Value rhs = builder.create<vector::TransferReadOp>(
+    Value rhs = rewriter.create<vector::TransferReadOp>(
         loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
     // Read res slice of size {n, w, f} @ [0, 0, 0].
-    Value res = builder.create<vector::TransferReadOp>(
+    Value res = rewriter.create<vector::TransferReadOp>(
         loc, resType, resShaped, ValueRange{zero, zero, zero});
 
     // The base vectorization case is input: {n,w,c}, weight: {kw,c,f}, output:
@@ -1554,13 +1564,13 @@ struct Conv1DGenerator
       // To match base vectorization case, we pre-transpose current case.
       // ncw -> nwc
       static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
-      lhs = builder.create<vector::TransposeOp>(loc, lhs, permLhs);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
       // fcw -> wcf
       static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
-      rhs = builder.create<vector::TransposeOp>(loc, rhs, permRhs);
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
       // nfw -> nwf
       static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
-      res = builder.create<vector::TransposeOp>(loc, res, permRes);
+      res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
       break;
     }
     }
@@ -1573,7 +1583,7 @@ struct Conv1DGenerator
     // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
             loc, lhs,
             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
             /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
@@ -1582,12 +1592,12 @@ struct Conv1DGenerator
     }
     // Extract rhs slice of size {c, f} @ [kw].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
-      rhsVals.push_back(builder.create<vector::ExtractOp>(
+      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
           loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
     }
     // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
           loc, res,
           /*offsets=*/ArrayRef<int64_t>{0, w, 0},
           /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
@@ -1602,14 +1612,14 @@ struct Conv1DGenerator
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         resVals[w] = conv1dSliceAsContraction(
-            builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
+            rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
       }
     }
 
     // Write back res slice: {n, wSizeStep, f} @ [0, w, 0].
     // This does not depend on kw.
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      res = builder.create<vector::InsertStridedSliceOp>(
+      res = rewriter.create<vector::InsertStridedSliceOp>(
           loc, resVals[w], res,
           /*offsets=*/ArrayRef<int64_t>{0, w, 0},
           /*strides=*/ArrayRef<int64_t>{1, 1, 1});
@@ -1628,26 +1638,26 @@ struct Conv1DGenerator
     case Conv1DOpOrder::Ncw: {
       // nwf -> nfw
       static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
-      res = builder.create<vector::TransposeOp>(loc, res, perm);
+      res = rewriter.create<vector::TransposeOp>(loc, res, perm);
       break;
     }
     }
 
     // Write back res slice of size {n, w, f} @ [0, 0, 0].
-    return builder
+    return rewriter
         .create<vector::TransferWriteOp>(loc, res, resShaped,
                                          ValueRange{zero, zero, zero})
         .getOperation();
   }
 
   // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
-  Value conv1dSliceAsContraction(OpBuilder &b, Location loc, Value lhs,
-                                 Value rhs, Value res) {
+  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
+                                 Value lhs, Value rhs, Value res) {
     vector::IteratorType par = vector::IteratorType::parallel;
     vector::IteratorType red = vector::IteratorType::reduction;
     AffineExpr n, w, f, c;
     bindDims(ctx, n, w, f, c);
-    return builder.create<vector::ContractionOp>(
+    return rewriter.create<vector::ContractionOp>(
         loc, lhs, rhs, res,
         /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
         /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
@@ -1664,8 +1674,7 @@ struct Conv1DGenerator
   /// > 1.
   FailureOr<Operation *> depthwiseConv() {
     if (!valid)
-      return IRRewriter(builder).notifyMatchFailure(
-          op, "unvectorizable depthwise conv");
+      return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
 
     int64_t nSize, wSize, cSize, kwSize;
     // kernel{kw, c}
@@ -1674,7 +1683,7 @@ struct Conv1DGenerator
     bindShapeDims(resShapedType, nSize, wSize);
 
     vector::TransferWriteOp write;
-    Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
 
     // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
     // When strideW == 1, we can batch the contiguous loads and avoid
@@ -1696,13 +1705,13 @@ struct Conv1DGenerator
 
     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
     // 0].
-    Value lhs = builder.create<vector::TransferReadOp>(
+    Value lhs = rewriter.create<vector::TransferReadOp>(
         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
     // Read rhs slice of size {kw, c} @ [0, 0].
-    Value rhs = builder.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
-                                                       ValueRange{zero, zero});
+    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+                                                        ValueRange{zero, zero});
     // Read res slice of size {n, w, c} @ [0, 0, 0].
-    Value res = builder.create<vector::TransferReadOp>(
+    Value res = rewriter.create<vector::TransferReadOp>(
         loc, resType, resShaped, ValueRange{zero, zero, zero});
 
     //===------------------------------------------------------------------===//
@@ -1714,7 +1723,7 @@ struct Conv1DGenerator
     //   @ [0, sw * w + dw * kw, 0].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
             loc, lhs,
             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
             /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
@@ -1723,12 +1732,12 @@ struct Conv1DGenerator
     }
     // Extract rhs slice of size {c} @ [kw].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
-      rhsVals.push_back(builder.create<vector::ExtractOp>(
+      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
           loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
     }
     // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
+      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
           loc, res,
           /*offsets=*/ArrayRef<int64_t>{0, w, 0},
           /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
@@ -1743,18 +1752,23 @@ struct Conv1DGenerator
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         resVals[w] = depthwiseConv1dSliceAsMulAcc(
-            builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
+            rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
       }
     }
 
-    // Its possible we failed to create the Fma
-    if (!llvm::all_of(resVals, [](Value v) { return v; }))
-      return IRRewriter(builder).notifyMatchFailure(op, "failed to create FMA");
+    // Its possible we failed to create the Fma.
+    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
+      // Manually revert (in reverse order) to avoid leaving a bad IR state.
+      for (auto &collection : {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
+        for (Value v : collection)
+          rewriter.eraseOp(v.getDefiningOp());
+      return rewriter.notifyMatchFailure(op, "failed to create FMA");
+    }
 
     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
     // This does not depend on kw.
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      res = builder.create<vector::InsertStridedSliceOp>(
+      res = rewriter.create<vector::InsertStridedSliceOp>(
           loc, resVals[w], res,
           /*offsets=*/ArrayRef<int64_t>{0, w, 0},
           /*strides=*/ArrayRef<int64_t>{1, 1, 1});
@@ -1764,14 +1778,14 @@ struct Conv1DGenerator
     //===------------------------------------------------------------------===//
 
     // Write back res slice of size {n, w, c} @ [0, 0, 0].
-    return builder
+    return rewriter
         .create<vector::TransferWriteOp>(loc, res, resShaped,
                                          ValueRange{zero, zero, zero})
         .getOperation();
   }
 
   // Take a value of element type T and widen to the destination type.
-  Value promote(OpBuilder &b, Location loc, Value val, Type ty) {
+  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
     if (val.getType() == ty)
       return val;
 
@@ -1780,35 +1794,35 @@ struct Conv1DGenerator
     const int64_t destWidth = getElementTypeOrSelf(ty).getIntOrFloatBitWidth();
 
     if (getElementTypeOrSelf(ty).isa<FloatType>() && srcWidth < destWidth)
-      return builder.create<arith::ExtFOp>(loc, ty, val);
+      return rewriter.create<arith::ExtFOp>(loc, ty, val);
 
     if (getElementTypeOrSelf(ty).isa<IntegerType>() && srcWidth < destWidth)
-      return builder.create<arith::ExtSIOp>(loc, ty, val);
+      return rewriter.create<arith::ExtSIOp>(loc, ty, val);
 
     return nullptr;
   }
 
   /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
-  Value depthwiseConv1dSliceAsMulAcc(OpBuilder &b, Location loc, Value lhs,
-                                     Value rhs, Value res) {
+  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
+                                     Value lhs, Value rhs, Value res) {
     auto rhsTy = rhs.getType().cast<ShapedType>();
     auto resTy = res.getType().cast<ShapedType>();
 
     // TODO(suderman): Change this to use a vector.ima intrinsic.
-    lhs = promote(b, loc, lhs, resTy);
+    lhs = promote(rewriter, loc, lhs, resTy);
 
-    rhs = builder.create<vector::BroadcastOp>(
+    rhs = rewriter.create<vector::BroadcastOp>(
         loc, resTy.clone(rhsTy.getElementType()), rhs);
-    rhs = promote(b, loc, rhs, resTy);
+    rhs = promote(rewriter, loc, rhs, resTy);
 
     if (!lhs || !rhs)
       return nullptr;
 
     if (resTy.getElementType().isa<FloatType>())
-      return b.create<vector::FMAOp>(loc, lhs, rhs, res);
+      return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
 
-    auto mul = b.create<arith::MulIOp>(loc, lhs, rhs);
-    return b.create<arith::AddIOp>(loc, mul, res);
+    auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
+    return rewriter.create<arith::AddIOp>(loc, mul, res);
   }
 
   /// Entry point that transposes into the common form:
@@ -1817,7 +1831,7 @@ struct Conv1DGenerator
     AffineExpr n, w, f, kw, c;
     bindDims(ctx, n, w, f, kw, c);
     if (!iters({Par(), Par(), Par(), Red(), Red()}))
-      return IRRewriter(builder).notifyMatchFailure(
+      return rewriter.notifyMatchFailure(
           op, "failed to match conv::Nwc 3-par 2-red");
 
     // No transposition needed.
@@ -1825,7 +1839,7 @@ struct Conv1DGenerator
                 /*rhsIndex*/ {kw, c, f},
                 /*resIndex*/ {n, w, f}}))
       return conv(Conv1DOpOrder::Nwc);
-    return IRRewriter(builder).notifyMatchFailure(op, "not a conv::Nwc layout");
+    return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
   }
 
   /// Entry point that transposes into the common form:
@@ -1834,7 +1848,7 @@ struct Conv1DGenerator
     AffineExpr n, w, f, kw, c;
     bindDims(ctx, n, f, w, c, kw);
     if (!iters({Par(), Par(), Par(), Red(), Red()}))
-      return IRRewriter(builder).notifyMatchFailure(
+      return rewriter.notifyMatchFailure(
           op, "failed to match conv::Ncw 3-par 2-red");
 
     if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
@@ -1842,7 +1856,7 @@ struct Conv1DGenerator
                 /*resIndex*/ {n, f, w}}))
       return conv(Conv1DOpOrder::Ncw);
 
-    return IRRewriter(builder).notifyMatchFailure(op, "not a conv::Ncw layout");
+    return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
   }
 
   /// Entry point that transposes into the common form:
@@ -1851,7 +1865,7 @@ struct Conv1DGenerator
     AffineExpr n, w, c, kw;
     bindDims(ctx, n, w, c, kw);
     if (!iters({Par(), Par(), Par(), Red()}))
-      return IRRewriter(builder).notifyMatchFailure(
+      return rewriter.notifyMatchFailure(
           op, "failed to match depthwise::Nwc conv 3-par 1-red");
 
     // No transposition needed.
@@ -1860,8 +1874,7 @@ struct Conv1DGenerator
                 /*resIndex*/ {n, w, c}}))
       return depthwiseConv();
 
-    return IRRewriter(builder).notifyMatchFailure(
-        op, "not a depthwise::Nwc layout");
+    return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
   }
 
 private:
@@ -1874,7 +1887,8 @@ struct Conv1DGenerator
 
 /// Helper function to vectorize a LinalgOp with convolution semantics.
 // TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b, LinalgOp op) {
+static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
+                                                   LinalgOp op) {
   // The ConvolutionOpInterface gives us guarantees of existence for
   // strides/dilations. However, we do not need to rely on those, we can simply
   // use them if present, otherwise use the default and let the generic conv.
@@ -1883,7 +1897,7 @@ static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b, LinalgOp op) {
   auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
   auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
-  Conv1DGenerator e(b, op, stride, dilation);
+  Conv1DGenerator e(rewriter, op, stride, dilation);
   auto res = e.generateNwcConv();
   if (succeeded(res))
     return res;

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index c41e66325c011..d44d46690b7eb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1522,15 +1522,14 @@ namespace {
 /// This unrolls outer-products along the reduction dimension.
 struct UnrolledOuterProductGenerator
     : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
-  UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
-      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(
-            builder, op),
+  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
+      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
         kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
         res(op.getAcc()), lhsType(op.getLhsType()) {}
 
   Value t(Value v) {
     static constexpr std::array<int64_t, 2> perm = {1, 0};
-    return builder.create<vector::TransposeOp>(loc, v, perm);
+    return rewriter.create<vector::TransposeOp>(loc, v, perm);
   }
 
   Value promote(Value v, Type dstElementType) {
@@ -1544,20 +1543,20 @@ struct UnrolledOuterProductGenerator
     if (vecType)
       promotedType = VectorType::get(vecType.getShape(), promotedType);
     if (dstElementType.isa<FloatType>())
-      return builder.create<arith::ExtFOp>(loc, promotedType, v);
-    return builder.create<arith::ExtSIOp>(loc, promotedType, v);
+      return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
+    return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
   }
 
   Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
     assert(reductionSize > 0);
     Type resElementType = res.getType().cast<VectorType>().getElementType();
     for (int64_t k = 0; k < reductionSize; ++k) {
-      Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
-      Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
-      a = promote(a, resElementType);
-      b = promote(b, resElementType);
-      res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
-                                                   res, kind);
+      Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
+      Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
+      extractA = promote(extractA, resElementType);
+      extractB = promote(extractB, resElementType);
+      res = rewriter.create<vector::OuterProductOp>(loc, res.getType(), extractA,
+                                             extractB, res, kind);
     }
     return res;
   }
@@ -1568,7 +1567,7 @@ struct UnrolledOuterProductGenerator
       return failure();
     // Set up the parallel/reduction structure in the right form.
     AffineExpr m, n, k;
-    bindDims(builder.getContext(), m, n, k);
+    bindDims(rewriter.getContext(), m, n, k);
     // Classical row-major matmul:  Just permute the lhs.
     if (layout({{m, k}, {k, n}, {m, n}}))
       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
@@ -1604,7 +1603,7 @@ struct UnrolledOuterProductGenerator
     if (!iters({Par(), Red()}))
       return failure();
     AffineExpr m, k;
-    bindDims(builder.getContext(), m, k);
+    bindDims(rewriter.getContext(), m, k);
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
@@ -1628,7 +1627,7 @@ struct UnrolledOuterProductGenerator
     if (!iters({Red(), Par()}))
       return failure();
     AffineExpr k, m;
-    bindDims(builder.getContext(), k, m);
+    bindDims(rewriter.getContext(), k, m);
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))


        


More information about the Mlir-commits mailing list