[Mlir-commits] [mlir] aa2210a - [linalg] Expose `rewriteAsPaddedOp` function.

Alexander Belyaev llvmlistbot at llvm.org
Fri Aug 6 03:08:29 PDT 2021


Author: Alexander Belyaev
Date: 2021-08-06T12:08:12+02:00
New Revision: aa2210a830699fbf6e218789ac3da7abffee0b0c

URL: https://github.com/llvm/llvm-project/commit/aa2210a830699fbf6e218789ac3da7abffee0b0c
DIFF: https://github.com/llvm/llvm-project/commit/aa2210a830699fbf6e218789ac3da7abffee0b0c.diff

LOG: [linalg] Expose `rewriteAsPaddedOp` function.

Differential Revision: https://reviews.llvm.org/D107629

Added:
    (none)

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Removed:
    (none)


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a2ce99bc97b0e..87908ac88231f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -886,6 +886,13 @@ struct PadTensorOpTransformationPattern : public OpRewritePattern<PadTensorOp> {
                                 PatternRewriter &rewriter) const override;
 };
 
+/// Try to create a static bounding box around each operand of `opToPad`.
+/// If successful, `paddedOp` will be updated to the cloned static form.
+LogicalResult
+rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
+                  const PaddingValueComputationFunction &paddingFunc,
+                  LinalgOp &paddedOp);
+
 using OptimizeCopyFn =
     std::function<LogicalResult(PatternRewriter &, PadTensorOp, Value)>;
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8d3ee8fe5566d..7fc1c5f47043f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -126,7 +126,7 @@ mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
 /// Return failure if the operand cannot be padded to a static shape.
 static LogicalResult padOperandToSmallestStaticBoundingBox(
     PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
-    const LinalgTilingOptions &options, Value &result) {
+    const PaddingValueComputationFunction &paddingFunc, Value &result) {
   // Already static shape, no need to pad.
   if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic))
     return success();
@@ -148,7 +148,7 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
           opToPad, "No constant bounding box can be found for padding");
     staticSizes.push_back(indexAttr.getInt());
   }
-  Value pad = options.paddingValueComputationFunction(rewriter, *opOperand);
+  Value pad = paddingFunc(rewriter, *opOperand);
   auto staticTensorType = RankedTensorType::get(
       staticSizes, getElementTypeOrSelf(opOperand->get()));
   result = linalg::PadTensorOp::createPadHighOp(
@@ -156,13 +156,10 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
   return success();
 }
 
-// Try to create a static bounding box around each operand of `res.op`.
-// If successful, `res.op` is rewritten in static form with padded operands.
-// `res.op` is updated to the cloned static form of the op on success.
-static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
-                                       TiledLinalgOp &res,
-                                       const LinalgTilingOptions &options) {
-  LinalgOp opToPad = res.op;
+LogicalResult
+linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
+                          const PaddingValueComputationFunction &paddingFunc,
+                          LinalgOp &paddedOp) {
   Location loc = opToPad->getLoc();
 
   // If the op is fully static, it does not need padding.
@@ -183,7 +180,7 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
     // If padding was requested but the shape cannot be bounded statically then
     // the pattern fails to apply.
     if (failed(padOperandToSmallestStaticBoundingBox(
-            rewriter, opToPad, opOperand, options, paddedOperand)))
+            rewriter, opToPad, opOperand, paddingFunc, paddedOperand)))
       return failure();
     newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
   }
@@ -191,8 +188,7 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
   // Clone `opToPad` to operate on the statically padded shapes.
   auto resultTensorTypes =
       ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
-  linalg::LinalgOp paddedOp =
-      opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);
+  paddedOp = opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);
 
   // Recover the slice out of the new static results. This keeps the original
   // linalg op around because it uses the dims of the original results.
@@ -218,8 +214,6 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
   rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
     return !newUsersOfOpToPad.contains(opOp.getOwner());
   });
-
-  res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults};
   return success();
 }
 
@@ -265,15 +259,19 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
       !linalgOp.hasTensorSemantics())
     return success();
 
-  // Try to pad on the fly by rewriting res->op as a padded op.
-  if (failed(rewriteAsPaddedOp(rewriter, *res, options))) {
-    // Set so RAII guard does not propagate TiledLinalgOp to `result`.
-    return failure();
+  // Try to pad on the fly by rewriting res->op as a padded op. If successful,
+  // `res.op` is rewritten in static form with padded operands.
+  LinalgOp paddedOp;
+  if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
+                                  options.paddingValueComputationFunction,
+                                  paddedOp))) {
+    res->op = paddedOp;
+    // Do not perform replacement of `linalgOp`, let the derived patterns
+    // do this as they see fit, from the resulting TiledLinalgOp.
+    return success();
   }
-
-  // Do not perform replacement of `linalgOp`, let the derived patterns
-  // do this as they see fit, from the resulting TiledLinalgOp.
-  return success();
+  // Set so RAII guard does not propagate TiledLinalgOp to `result`.
+  return failure();
 }
 
 static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {


        


More information about the Mlir-commits mailing list