[Mlir-commits] [mlir] 1eae247 - [mlir][linalg] Use OpBuilder in rewriteAsPaddedOp (NFC).

Tobias Gysi llvmlistbot at llvm.org
Thu Oct 28 23:31:44 PDT 2021


Author: Tobias Gysi
Date: 2021-10-29T05:54:29Z
New Revision: 1eae247a2d206151810b3b266d2fabf1bcd56cb1

URL: https://github.com/llvm/llvm-project/commit/1eae247a2d206151810b3b266d2fabf1bcd56cb1
DIFF: https://github.com/llvm/llvm-project/commit/1eae247a2d206151810b3b266d2fabf1bcd56cb1.diff

LOG: [mlir][linalg] Use OpBuilder in rewriteAsPaddedOp (NFC).

Adapt the rewriteAsPaddedOp method to use an OpBuilder instead of a PatternRewriter.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D112003
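
In practice the change means rewriteAsPaddedOp no longer erases the original
operation itself: it now returns the tensor.extract_slice results of the padded
op and leaves the replacement to the caller. Below is a minimal sketch of an
adapted call site, mirroring the LinalgBaseTilingPattern update in the diff
further down; `rewriter`, `linalgOp`, `paddingFunc`, and `nofoldFunc` are
placeholders for whatever the caller has at hand:

  LinalgOp paddedOp;
  FailureOr<SmallVector<Value>> newResults = linalg::rewriteAsPaddedOp(
      rewriter, linalgOp, paddingFunc, nofoldFunc, paddedOp);
  if (failed(newResults))
    return failure();
  // The helper no longer calls replaceOp; the caller replaces the original op
  // with the extracted slices of the statically shaped results.
  rewriter.replaceOp(linalgOp, newResults.getValue());

Since PatternRewriter derives from OpBuilder, pattern-based callers can keep
passing their existing rewriter unchanged.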

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 27688b5492530..6f5ed28d9bf78 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1085,10 +1085,13 @@ struct PadTensorOpTransformationPattern : public OpRewritePattern<PadTensorOp> {
                                 PatternRewriter &rewriter) const override;
 };
 
-/// Try to create a static bounding box around each operand of `opToPad`.
-/// If successful, `paddedOp` will be updated to the cloned static form.
-LogicalResult
-rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
+/// Pad the operands of `opToPad` to a static bounding box. Use `paddingFunc`
+/// and `nofoldFunc` to set the padding value and the nofold attribute of the
+/// introduced PadTensorOps, respectively. Update `paddedOp` to the cloned
+/// statically shaped operation and return the extracted dynamically shaped
+/// results. If padding fails, return failure.
+FailureOr<SmallVector<Value>>
+rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                   const PaddingValueComputationFunction &paddingFunc,
                   const PaddingNoFoldComputationFunction &nofoldFunc,
                   LinalgOp &paddedOp);
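
The two callback parameters keep their existing shapes; the sketch below is
purely illustrative and only infers the signatures from the calls in
Transforms.cpp further down (the padding callback receives the OpBuilder, the
nofold callback only the OpOperand). Per the implementation, returning
failure() from the padding value callback leaves that operand unpadded:

  // Illustrative callbacks; not part of this patch.
  PaddingValueComputationFunction paddingFunc =
      [](OpBuilder &b, OpOperand &opOperand) -> FailureOr<Value> {
    // A real callback would materialize a neutral element (e.g. a zero
    // constant of the operand's element type) through `b`; returning
    // failure() simply skips padding for this operand.
    return failure();
  };
  PaddingNoFoldComputationFunction nofoldFunc = [](OpOperand &opOperand) {
    // Always request the nofold attribute so the generated pad is kept even
    // when the operand is already statically shaped (illustrative choice).
    return true;
  };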

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 3633d22dd987a..18c3620783f36 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -152,14 +152,14 @@ LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
 /// result of the created PadTensorOp or return failure if the operand cannot be
 /// padded to a static shape.
 static LogicalResult padOperandToSmallestStaticBoundingBox(
-    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
+    OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
     const PaddingValueComputationFunction &paddingFunc,
     const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
   // Can't pad scalars.
   if (opToPad.getShape(opOperand).empty())
     return success();
   // Can't pad if no padding value is known.
-  FailureOr<Value> paddingValue = paddingFunc(rewriter, *opOperand);
+  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
   if (failed(paddingValue))
     return success();
   auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
@@ -175,9 +175,10 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
                          : linalg::getSmallestBoundingIndex(size.get<Value>());
     // SmallestBoundingIndex must exist for all sizes.
     // For now return an error if we can't find it.
-    if (!indexAttr)
-      return rewriter.notifyMatchFailure(
-          opToPad, "No constant bounding box can be found for padding");
+    if (!indexAttr) {
+      LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
+      return failure();
+    }
     staticSizes.push_back(indexAttr.getInt());
   }
   auto staticTensorType = RankedTensorType::get(
@@ -185,12 +186,12 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
   bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
   result = linalg::PadTensorOp::createPadHighOp(
       staticTensorType, opOperand->get(), paddingValue.getValue(),
-      /*nofold=*/nofold, opToPad->getLoc(), rewriter);
+      /*nofold=*/nofold, opToPad->getLoc(), b);
   return success();
 }
 
-LogicalResult
-linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
+FailureOr<SmallVector<Value>>
+linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
                           const PaddingValueComputationFunction &paddingFunc,
                           const PaddingNoFoldComputationFunction &nofoldFunc,
                           LinalgOp &paddedOp) {
@@ -200,9 +201,9 @@ linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
   assert(opToPad.hasTensorSemantics() &&
          "expected operation to have tensor semantics");
 
-  OpBuilder::InsertionGuard g(rewriter);
+  OpBuilder::InsertionGuard g(b);
   // Set IP after op because we also take the dims of the original output.
-  rewriter.setInsertionPointAfter(opToPad);
+  b.setInsertionPointAfter(opToPad);
   // Make a copy of the shaped operands and update it.
   SmallVector<Value> newOperands;
   newOperands.reserve(opToPad.getNumInputsAndOutputs());
@@ -211,15 +212,14 @@ linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
     // If padding was requested but the shape cannot be bounded statically then
     // the pattern fails to apply.
     if (failed(padOperandToSmallestStaticBoundingBox(
-            rewriter, opToPad, opOperand, paddingFunc, nofoldFunc,
-            paddedOperand)))
+            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
       return failure();
     newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
   }
 
   SmallVector<SmallVector<Value>> reifiedResultShapes;
   if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
-                 .reifyResultShapes(rewriter, reifiedResultShapes)))
+                 .reifyResultShapes(b, reifiedResultShapes)))
     return failure();
   assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
          "expected same number of results");
@@ -227,7 +227,7 @@ linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
   // Clone `opToPad` to operate on the statically padded shapes.
   auto resultTensorTypes =
       ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
-  paddedOp = opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);
+  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);
 
   // Recover the slice out of the new static results. This keeps the original
   // linalg op around because it uses the dims of the original results.
@@ -237,16 +237,15 @@ linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
     Value paddedResult = en.value();
     int64_t resultNumber = en.index();
     int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
-    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
     SmallVector<OpFoldResult> sizes;
     for (Value v : reifiedResultShapes[resultNumber])
       sizes.push_back(v);
-    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
-    paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
+    SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
+    paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
         loc, paddedResult, offsets, sizes, strides));
   }
-  rewriter.replaceOp(opToPad, paddedSubviewResults);
-  return success();
+  return paddedSubviewResults;
 }
 
 /// Linalg base tiling pattern.
@@ -345,9 +344,11 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
   // Try to pad on the fly by rewriting res->op as a padded op. If successful,
   // `res.op` is rewritten in static form with padded operands.
   LinalgOp paddedOp;
-  if (succeeded(rewriteAsPaddedOp(
-          rewriter, res->op, options.paddingValueComputationFunction,
-          options.paddingNoFoldComputationFunction, paddedOp))) {
+  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
+      rewriter, res->op, options.paddingValueComputationFunction,
+      options.paddingNoFoldComputationFunction, paddedOp);
+  if (succeeded(newResults)) {
+    rewriter.replaceOp(res->op, newResults.getValue());
     filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
     res->op = paddedOp;
     result = *res;


        

