[Mlir-commits] [mlir] 2b7ded2 - [mlir][linalg] Add option to pad Linalg ops to a specified multiple

Matthias Springer llvmlistbot at llvm.org
Wed Jun 7 00:00:13 PDT 2023


Author: Matthias Springer
Date: 2023-06-07T08:54:36+02:00
New Revision: 2b7ded215dcd9ecb0ecc606715e41f9e96a8a9c6

URL: https://github.com/llvm/llvm-project/commit/2b7ded215dcd9ecb0ecc606715e41f9e96a8a9c6
DIFF: https://github.com/llvm/llvm-project/commit/2b7ded215dcd9ecb0ecc606715e41f9e96a8a9c6.diff

LOG: [mlir][linalg] Add option to pad Linalg ops to a specified multiple

A multiple (int64_t) can optionally be specified for every padding dimension.

Differential Revision: https://reviews.llvm.org/D152262

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/transform-op-pad.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index a7356c0e57e20..6ba9f9aff42d1 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -823,6 +823,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
     (ins TransformHandleTypeInterface:$target,
          DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
          DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
+         OptionalAttr<I64ArrayAttr>:$pad_to_multiple_of,
          DefaultValuedAttr<I64ArrayAttr, "{}">:$pack_paddings,
          DefaultValuedAttr<
           TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 693b89a932160..1d7f448ff180a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -148,6 +148,12 @@ struct LinalgPaddingOptions {
     paddingDimensions.assign(pd.begin(), pd.end());
     return *this;
   }
+  /// A list of multiples to which each padding dimension should be padded.
+  std::optional<SmallVector<int64_t>> padToMultipleOf;
+  LinalgPaddingOptions &setPadToMultipleOf(ArrayRef<int64_t> m) {
+    padToMultipleOf.emplace(m.begin(), m.end());
+    return *this;
+  }
   /// A flag for every operand to mark the PadOp as nofold which enables
   /// packing for statically shaped operands.
   SmallVector<bool> packPaddings;
@@ -350,14 +356,17 @@ SmallVector<Value> peelLoop(RewriterBase &rewriter, Operation *op);
 void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);
 
 /// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands
-/// to a static bounding box. Use `paddingValues` and `packPaddings` to set
-/// padding value and nofold attribute of the created tensor::PadOps,
-/// respectively. Update `paddedOp` to the cloned operation with statically
-/// shaped `paddingDimensions` and return the extracted dynamically shaped
-/// results. If padding fails, return failure.
+/// to a static bounding box. `padToMultipleOf` indicates that each padding
+/// dimension should be padded to the specified multiple. If the derived padding
+/// sizes should not be rounded up to any multiple, use "1". Use `paddingValues`
+/// and `packPaddings` to set padding value and nofold attribute of the created
+/// tensor::PadOps, respectively. Update `paddedOp` to the cloned operation with
+/// statically shaped `paddingDimensions` and return the extracted dynamically
+/// shaped results. If padding fails, return failure.
 FailureOr<SmallVector<Value>>
 rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                   ArrayRef<int64_t> paddingDimensions,
+                  ArrayRef<int64_t> padToMultipleOf,
                   ArrayRef<Attribute> paddingValues,
                   ArrayRef<bool> packPaddings, LinalgOp &paddedOp);
 

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4934a60f578b5..f7fa2f107754b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1619,9 +1619,14 @@ transform::PadOp::applyToOne(LinalgOp target,
   TrackingListener listener(state, *this);
   IRRewriter rewriter(getContext(), &listener);
   LinalgOp paddedOp;
-  FailureOr<SmallVector<Value>> result = rewriteAsPaddedOp(
-      rewriter, target, extractFromI64ArrayAttr(getPaddingDimensions()),
-      paddingValues, packPaddings, paddedOp);
+  SmallVector<int64_t> paddingDimensions =
+      extractFromI64ArrayAttr(getPaddingDimensions());
+  SmallVector<int64_t> padToMultipleOf(paddingDimensions.size(), 1);
+  if (getPadToMultipleOf().has_value())
+    padToMultipleOf = extractFromI64ArrayAttr(*getPadToMultipleOf());
+  FailureOr<SmallVector<Value>> result =
+      rewriteAsPaddedOp(rewriter, target, paddingDimensions, padToMultipleOf,
+                        paddingValues, packPaddings, paddedOp);
   if (succeeded(result)) {
     // We need to perform our own replacement here because this API is still
     // used in patterns that "pad and hoist", for which the replacement values
@@ -1655,7 +1660,11 @@ LogicalResult transform::PadOp::verify() {
                             "integers, found "
                          << getPaddingDimensions();
   }
-
+  if (getPadToMultipleOf().has_value()) {
+    if (getPadToMultipleOf()->size() != paddingDimensions.size()) {
+      return emitOpError() << "expects as many multiples as padding_dimensions";
+    }
+  }
   ArrayAttr transposes = getTransposePaddings();
   for (Attribute attr : transposes) {
     SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 5beb2ffee545b..165c350ecf88a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -48,34 +48,50 @@ using namespace mlir::linalg;
 
 /// Pad the `opOperand` in the `paddingDimensions` using the padding value and
 /// the nofold flag found in `paddingValues` and `packPaddings`, respectively.
-/// Exit early and return the `opOperand` value if the shape dimensions that
-/// match `paddingDimensions` have a static size and the nofold flag is not set.
+///
+/// Exit early and return the `opOperand` value if it already has the requested
+/// shape. I.e.:
+/// - static shape
+/// - nofold is not set
+/// - dim sizes are multiples of `padToMultipleOf`
+///
 /// Otherwise, try to pad the shape dimensions that match the iterator
 /// dimensions `paddingDimensions` and return the tensor::PadOp result if
 /// padding succeeds or failure otherwise.
 static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
     RewriterBase &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
-    ArrayRef<int64_t> paddingDimensions, ArrayRef<Attribute> paddingValues,
-    ArrayRef<bool> packPaddings) {
+    ArrayRef<int64_t> paddingDimensions, ArrayRef<int64_t> padToMultipleOf,
+    ArrayRef<Attribute> paddingValues, ArrayRef<bool> packPaddings) {
+  assert(padToMultipleOf.size() == paddingDimensions.size() &&
+         "invalid number of elements in padToMultipleOf");
+
   AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand);
   ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
 
-  // Collect the shape dimension that are a function of the `paddingDimensions`.
-  llvm::SmallDenseSet<int64_t> shapeDimsToPad;
-  for (int64_t dim : paddingDimensions)
-    for (const auto &en : enumerate(indexingMap.getResults()))
-      if (en.value().isFunctionOfDim(dim))
-        shapeDimsToPad.insert(en.index());
+  // Collect the shape dimensions that are a function of `paddingDimensions`,
+  // along with the multiple that they should be padded to ("1" if none).
+  bool alreadyHasRequestedShape = true;
+  DenseMap<int64_t, int64_t> shapeDimToMultiple;
+  for (const auto &dimEn : enumerate(paddingDimensions)) {
+    for (const auto &en : enumerate(indexingMap.getResults())) {
+      if (en.value().isFunctionOfDim(dimEn.value())) {
+        int64_t dimSize = shape[en.index()];
+        shapeDimToMultiple[en.index()] = padToMultipleOf[dimEn.index()];
+        if (ShapedType::isDynamic(dimSize)) {
+          alreadyHasRequestedShape = false;
+        } else if (dimSize % shapeDimToMultiple[en.index()] != 0) {
+          alreadyHasRequestedShape = false;
+        }
+      }
+    }
+  }
 
   // Return the unpadded operand if padding to a static shape is not needed and
   // if the nofold flag is not set.
   bool nofold = opOperand->getOperandNumber() < packPaddings.size()
                     ? packPaddings[opOperand->getOperandNumber()]
                     : false;
-  bool hasStaticShape = llvm::none_of(shapeDimsToPad, [&](int64_t dim) {
-    return ShapedType::isDynamic(shape[dim]);
-  });
-  if (!nofold && hasStaticShape)
+  if (!nofold && alreadyHasRequestedShape)
     return opOperand->get();
 
   // Fail if `paddingValues` specifies no padding value.
@@ -86,12 +102,17 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
   Value paddingValue = rewriter.create<arith::ConstantOp>(
       opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
 
+  // Helper function to round a number up to a given multiple.
+  auto ceil = [](int64_t val, int64_t multiple) {
+    return ((val + multiple - 1) / multiple) * multiple;
+  };
+
   // Upper bound the sizes to obtain a static bounding box.
   SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
   for (int64_t i = 0, e = shape.size(); i < e; ++i) {
     LLVM_DEBUG(DBGS() << "--compute padded size for dim " << i << "\n");
     // Skip dimensions that do not require padding.
-    if (!shapeDimsToPad.contains(i)) {
+    if (!shapeDimToMultiple.contains(i)) {
       LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
       continue;
     }
@@ -105,7 +126,7 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
       return rewriter.notifyMatchFailure(
          opToPad, "could not compute a bounding box for padding");
     }
-    paddedShape[i] = *upperBound;
+    paddedShape[i] = ceil(*upperBound, shapeDimToMultiple[i]);
     LLVM_DEBUG(DBGS() << "----new dim size: " << paddedShape[i] << "\n");
   }
 
@@ -131,9 +152,11 @@ getNParallelLoopsAttrs(unsigned nParallelLoops) {
 //===----------------------------------------------------------------------===//
 // rewriteAsPaddedOp transformation.
 //===----------------------------------------------------------------------===//
+
 FailureOr<SmallVector<Value>>
 linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                           ArrayRef<int64_t> paddingDimensions,
+                          ArrayRef<int64_t> padToMultipleOf,
                           ArrayRef<Attribute> paddingValues,
                           ArrayRef<bool> packPaddings, LinalgOp &paddedOp) {
   LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
@@ -153,8 +176,8 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
   newOperands.reserve(opToPad->getNumOperands());
   for (OpOperand &opOperand : opToPad->getOpOperands()) {
     FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
-        rewriter, opToPad, &opOperand, paddingDimensions, paddingValues,
-        packPaddings);
+        rewriter, opToPad, &opOperand, paddingDimensions, padToMultipleOf,
+        paddingValues, packPaddings);
     // Exit if `paddingDimensions` cannot be bounded statically.
     if (failed(paddedOperand)) {
       LLVM_DEBUG(DBGS() << "--operand cannot be bound statically : "
@@ -241,9 +264,13 @@ mlir::linalg::padAndHoistLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp,
 
   // Pad the operation.
   LinalgOp paddedOp;
-  FailureOr<SmallVector<Value>> newResults =
-      rewriteAsPaddedOp(rewriter, linalgOp, options.paddingDimensions,
-                        options.paddingValues, options.packPaddings, paddedOp);
+  SmallVector<int64_t> padToMultipleOf(options.paddingDimensions.size(), 1);
+  if (options.padToMultipleOf.has_value())
+    padToMultipleOf.assign(options.padToMultipleOf->begin(),
+                           options.padToMultipleOf->end());
+  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
+      rewriter, linalgOp, options.paddingDimensions, padToMultipleOf,
+      options.paddingValues, options.packPaddings, paddedOp);
   if (failed(newResults))
     return rewriter.notifyMatchFailure(linalgOp,
                                        "failed to rewrite as a padded op");

diff  --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index 78197aa7098d0..c14a1e1fbbc6c 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -45,6 +45,39 @@ transform.sequence failures(propagate) {
 
 #map = affine_map<()[s0] -> (-s0 + 12, 7)>
 
+// CHECK-LABEL: @pad_to_multiple
+func.func @pad_to_multiple(%arg0: tensor<24x12xf32>,
+                           %arg1: tensor<12x25xf32>,
+                           %arg2: tensor<24x25xf32>,
+                           %iv0 : index, %iv1 : index, %iv2 : index) -> tensor<24x25xf32> {
+  %0 = affine.min #map()[%iv2]
+  %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
+  %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor<?x5xf32>
+  %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32>
+
+  //      CHECK: linalg.matmul
+  // CHECK-SAME:     ins(%{{.*}}, %{{.*}} : tensor<4x7xf32>, tensor<7x6xf32>)
+  // CHECK-SAME:     outs(%{{.*}} : tensor<4x6xf32>)
+  %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
+  %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
+  func.return %5 : tensor<24x25xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.pad %0 {
+    padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
+    padding_dimensions=[0, 1, 2],
+    pad_to_multiple_of=[2, 2, 1],
+    pack_paddings=[1, 1, 0]
+  } : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+#map = affine_map<()[s0] -> (-s0 + 12, 7)>
+
 // CHECK-LABEL: @static_sizes_output_divisible_on_empty_op
 func.func @static_sizes_output_divisible_on_empty_op(%arg0: tensor<24x12xf32>,
     %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>, %iv0: index,


        


More information about the Mlir-commits mailing list