[Mlir-commits] [mlir] 3963b4d - [mlir] Transform op for multitile size generation

Alex Zinenko llvmlistbot at llvm.org
Tue Jul 12 05:36:38 PDT 2022


Author: Alex Zinenko
Date: 2022-07-12T12:36:28Z
New Revision: 3963b4d0dc5bf2bb92eedbab91e2c11653cd8f4e

URL: https://github.com/llvm/llvm-project/commit/3963b4d0dc5bf2bb92eedbab91e2c11653cd8f4e
DIFF: https://github.com/llvm/llvm-project/commit/3963b4d0dc5bf2bb92eedbab91e2c11653cd8f4e.diff

LOG: [mlir] Transform op for multitile size generation

Introduce a structured transform op that emits IR computing the multi-tile
sizes with requested parameters (target size and divisor) for the given
structured op. The sizes may fold to arithmetic constant operations when the
shape is constant. These operations may then be used to call the existing
tiling transformation with a single non-zero dynamic size (i.e. perform
strip-mining) for each of the dimensions separately, thus achieving multi-size
tiling with optional loop interchange. A separate test exercises the entire
script.

Depends On D129217

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D129287

Added: 
    mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
    mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/python/mlir/dialects/_structured_transform_ops_ext.py
    mlir/test/python/dialects/transform_structured_ext.py
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 021158f873b04..39ba9983c9052 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -127,6 +127,71 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
   }];
 }
 
+def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Emits the IR computing the tile sizes `s1` and `s2` such that:
+
+      - there exists a combination of `n` tiles of size `s1` and `m` tiles of
+        size `s2` that covers the entirety of the iteration space `dimension` of
+        the target structured op;
+      - `s1` and `s2` are less than or equal to `target_size`;
+      - `s1` and `s2` are divisible by `divisor`.
+
+    For example, for a dimension of size 54 with target size 12 and divisor 2,
+    this can emit the IR computing the tile size 10, used for 3 tiles, and 12,
+    used for 2 tiles, for a total of 10*3 + 12*2 = 54. Note that when the
+    divisor does not divide the original dimension size, it is impossible to
+    compute such tile sizes. An assertion is emitted to guard against this in
+    the dynamic case.
+
+    Expects the target size and the divisor to be strictly positive. Folds the
+    IR as much as possible, normally obtaining constant sizes and numbers of
+    tiles for a statically known dimension.
+
+    This does *not* consume the target handle and produces three handles each
+    pointing to single-result index-typed operations (which may be arithmetic
+    constant operations) defining the two respective tile sizes and the product
+    of the first tile size with the number of tiles of that size (useful for
+    splitting the iteration space).
+
+    This operation composes with the regular tiling when applied per-dimension:
+
+    ```mlir
+    %sz1, %sz2, %split = structured.multitile_sizes %target
+                         { target_size = 10, dimension = 1 }
+    %low, %high = structured.split %target after %split { dimension = 1 }
+    %tiled_low = structured.tile %low [0, %sz1]
+    %tiled_high = structured.tile %high [0, %sz2]
+    %common = merge_handles %tiled_low, %tiled_high
+
+    %sz3, %sz4, %split = structured.multitile_sizes %target
+                         { target_size = 42, dimension = 0 }
+    %sz3r, %sz4r, %splitr = replicate num(%common) %sz3, %sz4, %split
+    structured.split %common after %splitr { dimension = 0 }
+    // ...
+    ```
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                       I64Attr:$dimension,
+                       I64Attr:$target_size,
+                       DefaultValuedAttr<I64Attr, "1">:$divisor);
+  let results = (outs PDL_Operation:$low_size,
+                      PDL_Operation:$high_size,
+                      PDL_Operation:$split_point);
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::linalg::LinalgOp target, 
+        ::llvm::SmallVector<::mlir::Operation *> &results,
+        TransformState &state);
+  }];
+}
+
+
 def PadOp : Op<Transform_Dialect, "structured.pad",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7b45112677183..47cd647d7dbab 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -479,6 +479,48 @@ std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
 makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
                     ValueRange allShapeSizes, ValueRange allTileSizes);
 
+/// A description of a multi-size tiling comprising tile sizes and numbers of
+/// tiles, expressed as Values which may or may not be constant. Multi-size
+/// currently means two-size.
+struct MultiSizeSpecification {
+  /// Tile sizes.
+  Value lowTileSize, highTileSize;
+  /// Number of tiles associated with each size.
+  Value lowTripCount, highTripCount;
+};
+
+/// Emits the IR computing the multi-sized tiling specification with two tile
+/// sizes not exceeding `targetSize`, each divisible by `divisor`, such that
+/// there exist numbers of tiles with these sizes that fully cover the given
+/// iteration space `dimension` of the structured `op`.
+///
+/// The computation is as follows:
+///
+///   b = originalTripCount floordiv divisor
+///   t = (targetSize + divisor - 1) floordiv divisor
+///   d = (b + t - 1) floordiv t
+///   s = (b floordiv d) * divisor
+///   v = b % d
+///   u = d - v
+///
+/// where the tile sizes are `s` and `s` + `divisor`, and the numbers of
+/// the corresponding tiles are `u` and `v`, respectively.  Alternatively,
+///
+///   s * u + (s + divisor) * v == original size,
+///   where s mod divisor = 0.
+///
+/// Expects all values to be positive. In some cases with the target tile size
+/// sufficiently close to the dimension shape and non-unit divisor, it is
+/// impossible to compute such sizes. If `emitAssertions` is set, also emit the
+/// assertion that size computation succeeded.
+///
+/// Returns the specification consisting of both tile values and the number of
+/// tiles of each size.
+FailureOr<MultiSizeSpecification>
+computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
+                      OpFoldResult targetSize, OpFoldResult divisor,
+                      bool emitAssertions = true);
+
 /// All indices returned by IndexOp should be invariant with respect to tiling.
 /// Therefore, if an operation is tiled, we have to transform the indices
 /// accordingly, i.e. offset them by the values of the corresponding induction

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
index 819c42f83eb62..7d10433a80b2b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -8,11 +8,14 @@ add_mlir_dialect_library(MLIRLinalgTransformOps
   MLIRLinalgTransformOpsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRAffineDialect
+  MLIRArithmeticDialect
   MLIRIR
   MLIRLinalgDialect
   MLIRLinalgTransforms
   MLIRParser
   MLIRPDLDialect
+  MLIRSCFDialect
   MLIRSideEffectInterfaces
   MLIRTransformDialect
   MLIRVectorDialect

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ab35b06157b51..f1a9dcd7f0cd0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -8,12 +8,14 @@
 
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -276,6 +278,55 @@ LogicalResult transform::InterchangeOp::verify() {
   return success();
 }
 
+//===---------------------------------------------------------------------===//
+// MultiTileSizesOp
+//===---------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
+    LinalgOp target, SmallVector<Operation *> &results, TransformState &state) {
+  OpBuilder builder(target.getContext());
+  builder.setInsertionPoint(target);
+  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
+  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
+  FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
+      builder, target, getDimension(), targetSize, divisor);
+  if (failed(spec)) {
+    return emitSilenceableError() << "could not generate tile size computation";
+  }
+
+  Operation *splitPoint =
+      builder
+          .createOrFold<arith::MulIOp>(target.getLoc(), spec->lowTileSize,
+                                       spec->lowTripCount)
+          .getDefiningOp();
+  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
+  Operation *highTileSize = spec->highTileSize.getDefiningOp();
+  assert(lowTileSize && highTileSize && splitPoint &&
+         "tile sizes are not produced by operations");
+  results.reserve(results.size() + 3);
+  results.push_back(lowTileSize);
+  results.push_back(highTileSize);
+  results.push_back(splitPoint);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::MultiTileSizesOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
+                       transform::TransformMappingResource::get());
+  for (Value result : getResults()) {
+    effects.emplace_back(MemoryEffects::Allocate::get(), result,
+                         transform::TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Write::get(), result,
+                         transform::TransformMappingResource::get());
+  }
+
+  effects.emplace_back(MemoryEffects::Read::get(),
+                       transform::PayloadIRResource::get());
+  effects.emplace_back(MemoryEffects::Write::get(),
+                       transform::PayloadIRResource::get());
+}
+
 //===---------------------------------------------------------------------===//
 // PadOp
 //===---------------------------------------------------------------------===//
@@ -782,6 +833,7 @@ class LinalgTransformDialectExtension
           LinalgTransformDialectExtension> {
 public:
   LinalgTransformDialectExtension() {
+    declareDependentDialect<AffineDialect>();
     declareDependentDialect<arith::ArithmeticDialect>();
     declareDependentDialect<pdl::PDLDialect>();
     declareDependentDialect<scf::SCFDialect>();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index a7524b7ebfc80..d55876f46b446 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -13,6 +13,7 @@
 #include <utility>
 
 #include "PassDetail.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -82,6 +83,92 @@ void mlir::linalg::transformIndexOps(
   addTileLoopIvsToIndexOpResults(b, op, allIvs);
 }
 
+/// Asserts that the given index-typed value is strictly positive. If the value
+/// is an attribute, asserts at compile time, otherwise emits an assertion
+/// checked at runtime.
+static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
+                                         OpFoldResult value) {
+  if (auto attr = value.dyn_cast<Attribute>()) {
+    assert(attr.cast<IntegerAttr>().getValue().isStrictlyPositive() &&
+           "expected strictly positive tile size and divisor");
+    return;
+  }
+
+  Value zero = b.create<arith::ConstantIndexOp>(0);
+  Value condition = b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt,
+                                            value.get<Value>(), zero);
+  b.create<cf::AssertOp>(
+      condition,
+      b.getStringAttr("expected strictly positive tile size and divisor"));
+}
+
+FailureOr<MultiSizeSpecification>
+mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
+                                    unsigned dimension, OpFoldResult targetSize,
+                                    OpFoldResult divisor, bool emitAssertions) {
+  // Bail out on dimension overflow.
+  if (dimension >= op.getNumLoops())
+    return failure();
+
+  // The code below works only on values.
+  ImplicitLocOpBuilder b(op.getLoc(), builder);
+  if (emitAssertions) {
+    emitIsPositiveIndexAssertion(b, targetSize);
+    emitIsPositiveIndexAssertion(b, divisor);
+  }
+  Value targetSizeValue = materializeOpFoldResult(b, targetSize);
+  Value divisorValue = materializeOpFoldResult(b, divisor);
+
+  // Find the trip count of the iteration space dimension for which the tile
+  // sizes are computed.
+  // TODO: update createFlatListOfOperandDims to return OpFoldResults and avoid
+  // littering by useless constant materialization.
+  SmallVector<Value, 4> allShapes =
+      op.createFlatListOfOperandDims(b, b.getLoc());
+  AffineMap shapesToLoops = op.getShapesToLoopsMap();
+  SmallVector<Value, 4> loopRanges =
+      applyMapToValues(b, op.getLoc(), shapesToLoops, allShapes);
+  Value tripCount = loopRanges[dimension];
+
+  // Compute the tile sizes and the respective numbers of tiles.
+  AffineExpr s0 = b.getAffineSymbolExpr(0);
+  AffineExpr s1 = b.getAffineSymbolExpr(1);
+  AffineExpr s2 = b.getAffineSymbolExpr(2);
+  auto apply = [&](AffineExpr expr, ValueRange values) -> Value {
+    return makeComposedAffineApply(b, b.getLoc(), expr, values);
+  };
+  Value a = apply(s0.floorDiv(s1), {tripCount, divisorValue});
+  Value t = apply((s0 + s1 - 1).floorDiv(s1), {targetSizeValue, divisorValue});
+  Value d = apply((s0 + s1 - 1).floorDiv(s1), {a, t});
+  Value s = apply(s0.floorDiv(s1) * s2, {a, d, divisorValue});
+  Value v = apply(s0 % s1, {a, d});
+  Value u = apply(s0 - s1, {d, v});
+
+  MultiSizeSpecification spec;
+  spec.lowTileSize = s;
+  spec.highTileSize = apply(s0 + s1, {s, divisorValue});
+  spec.lowTripCount = u;
+  spec.highTripCount = v;
+
+  // If requested, emit the check that the tile sizes are computed correctly.
+  // For example, for iteration dimension size of 15 and the divisor 8 it is
+  // impossible to find two tile sizes both divisible by 8 that fully cover the
+  // original space dimension.
+  if (emitAssertions) {
+    AffineExpr s3 = builder.getAffineSymbolExpr(3);
+    Value coveredSize =
+        apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
+                                  spec.highTileSize, spec.highTripCount});
+    Value equals = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+                                           coveredSize, tripCount);
+    b.create<cf::AssertOp>(
+        equals, builder.getStringAttr(
+                    "could not compute dynamic multi-size tile shapes"));
+  }
+
+  return spec;
+}
+
 // Insert a tile `source` into the destination tensor `dest`. The position at
 // which the tile is inserted (as well as size of tile) is taken from a given
 // ExtractSliceOp `sliceOp`.

diff  --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index b6e078fc78b3b..95bf2cc992954 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -110,6 +110,29 @@ def __init__(self,
         ip=ip)
 
 
+class MultiTileSizesOp:
+  """Specialization for MultiTileSizesOp class."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               *,
+               dimension: Union[int, IntegerAttr],
+               target_size: Union[int, IntegerAttr],
+               divisor: Optional[Union[int, IntegerAttr]] = None,
+               loc=None,
+               ip=None):
+    super().__init__(
+        pdl.OperationType.get(),
+        pdl.OperationType.get(),
+        pdl.OperationType.get(),
+        _get_op_result_or_value(target),
+        dimension=_get_int64_attr(dimension),
+        target_size=_get_int64_attr(target_size),
+        divisor=_get_int64_attr(divisor if divisor else 1),
+        loc=loc,
+        ip=ip)
+
+
 class PadOp:
   """Specialization for PadOp class."""
 

diff  --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
new file mode 100644
index 0000000000000..e30a140535fcd
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
@@ -0,0 +1,114 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter --canonicalize %s | FileCheck %s
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  // This implements a 2D multisize tiling with target sizes [3, 10].
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @linalg_generic in %arg1
+    %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3}
+    %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10}
+    %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 }
+    %3:2 = transform.structured.tile %2#0 [%1#0]
+    %4:2 = transform.structured.tile %2#1 [%1#1]
+    %5 = merge_handles %3#0, %4#0
+    %tt:3 = replicate num(%5) %t#0, %t#1, %t#2
+    %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 }
+    transform.structured.tile %6#0 [0, %tt#0]
+    transform.structured.tile %6#1 [0, %tt#1]
+  }
+}
+
+func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
+
+// CHECK-DAG: #[[$MAP_MIN_4_2:.+]] = affine_map<(d0) -> (-d0 + 4, 2)>
+// CHECK-DAG: #[[$MAP_MIN_16_8:.+]] = affine_map<(d0) -> (-d0 + 16, 8)>
+
+// CHECK-LABEL: @two_d
+// CHECK-SAME: %[[IN:.+]]: tensor<10x34xf32>, %[[OUT:.+]]: tensor<10x34xf32>
+func.func @two_d(%arg0: tensor<10x34xf32>,
+                 %arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]
+  }
+  ins(%arg0: tensor<10x34xf32>)
+  outs(%arg1: tensor<10x34xf32>) {
+  ^bb0(%0: f32, %1: f32):
+    %i = linalg.index 0 : index
+    %j = linalg.index 1 : index
+    %call_res = func.call @elem(%0, %i, %j) : (f32, index, index) -> f32
+    linalg.yield %call_res : f32
+  } -> tensor<10x34xf32>
+
+  // 2D multi-size tiling should produce four quadrants with sizes
+  //   (2, 8), (2, 9), (3, 8), (3, 9)
+  // respectively, and in this order.
+  // Check the full code for the first quadrant, the data flow for the second
+  // quadrant and only the overall code structure for the remaining quadrants.
+  //
+  // TODO: unfortunately, the canonicalization is insufficiently powerful to
+  // remove the affine min for sizes, leading to dynamic sizes even when tiling
+  // statically-shaped operation with constant tile sizes.
+
+  // CHECK:      %[[SLICE_1:.+]] = tensor.extract_slice %[[OUT]][0, 0] [4, 34] [1, 1]
+  // CHECK:      scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_1:.+]] = %[[SLICE_1]])
+  // CHECK:        %[[SZ1:.+]] = affine.min #[[$MAP_MIN_4_2]](%[[I1]])
+  // CHECK:        %[[INSLICE_1:.+]] = tensor.extract_slice %[[IN]][%[[I1]], 0] [%[[SZ1]], 34] [1, 1]
+  // CHECK:        %[[SZ2:.+]] = affine.min #[[$MAP_MIN_4_2]](%[[I1]])
+  // CHECK:        %[[OUTSLICE_1:.+]] = tensor.extract_slice %[[ITERARG_1]][%[[I1]], 0] [%[[SZ2]], 34] [1, 1]
+
+  // CHECK:        %[[SLICE_2:.+]] = tensor.extract_slice %[[OUTSLICE_1]][0, 0] [%[[SZ1]], 16] [1, 1]
+  // CHECK:        %[[LOOPRES:.+]] = scf.for %[[I2:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_2:.+]] = %[[SLICE_2]])
+  // CHECK:          %[[SZ3:.+]] = affine.min #[[$MAP_MIN_16_8]](%[[I2]])
+  // CHECK:          %[[INSLICE_2:.+]] = tensor.extract_slice %[[INSLICE_1]][0, %[[I2]]] [%[[SZ1]], %[[SZ3]]] [1, 1]
+  // CHECK:          %[[SZ4:.+]] = tensor.dim %[[ITERARG_2]]
+  // CHECK:          %[[SZ5:.+]] = affine.min #[[$MAP_MIN_16_8]](%[[I2]])
+  // CHECK:          %[[OUTSLICE_2:.+]] = tensor.extract_slice %[[ITERARG_2]][0, %[[I2]]] [%[[SZ4]], %[[SZ5]]] [1, 1]
+
+  // CHECK:          %[[RESSLICE_1:.+]] = linalg.generic {{.*}} ins(%[[INSLICE_2]] : tensor<?x?xf32>) outs(%[[OUTSLICE_2]] : tensor<?x?xf32>)
+  // CHECK:          %[[RESPARTIAL:.+]] = tensor.insert_slice %[[RESSLICE_1]] into %[[ITERARG_2]]
+  // CHECK:          scf.yield %[[RESPARTIAL]]
+
+  // CHECK:        %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [%[[SZ1]], 16] [1, 1]
+  // CHECK:        %[[OUTSLICE_3:.+]] = tensor.extract_slice %[[INSERTED]][0, 16] [%[[SZ1]], 18] [1, 1]
+  // CHECK:        scf.for %{{.*}} iter_args(%{{.*}} = %[[OUTSLICE_3]])
+  // CHECK-COUNT-2:  tensor.extract_slice
+  // CHECK:          linalg.generic
+  // CHECK:          tensor.insert_slice
+  // CHECK:          scf.yield
+  // CHECK:        %[[INSERTED_2:.+]] = tensor.insert_slice %{{.*}} into %[[INSERTED]]
+  // CHECK:        %[[INSERTED_3:.+]] = tensor.insert_slice %[[INSERTED_2]] into %[[ITERARG_1]]
+  // CHECK:        scf.yield %[[INSERTED_3]]
+
+  // CHECK:        tensor.insert_slice
+  // CHECK:        tensor.extract_slice
+  // CHECK:        scf.for
+  // CHECK-COUNT-3:  tensor.extract_slice
+  // CHECK:          scf.for
+  // CHECK-COUNT-2:    tensor.extract_slice
+  // CHECK:            linalg.generic
+  // CHECK:            tensor.insert_slice
+  // CHECK:            scf.yield
+  // CHECK:          tensor.insert_slice
+  // CHECK:          tensor.extract_slice
+  // CHECK:          scf.for
+  // CHECK-COUNT-2:    tensor.extract_slice
+  // CHECK:            linalg.generic
+  // CHECK:            tensor.insert_slice
+  // CHECK:            scf.yield
+  // CHECK-COUNT-2:  tensor.insert_slice
+  // CHECK:          scf.yield
+  // CHECK:        %[[RESULT:.+]] = tensor.insert_slice
+  // CHECK:        return %[[RESULT]]
+
+  return %0 : tensor<10x34xf32>
+}

diff  --git a/mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir b/mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir
new file mode 100644
index 0000000000000..08fa9348b8ee8
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[$MAP13:.+]] = affine_map<() -> (13)>
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  sequence %arg0 {
+    ^bb0(%arg1: !pdl.operation):
+      %0 = pdl_match @pdl_target in %arg1
+      transform.structured.multitile_sizes %0 { target_size = 3, dimension = 0 }
+  }
+
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %0 with "transform.dialect"
+  }
+}
+
+// CHECK-LABEL: @multitile_sizes_static
+func.func @multitile_sizes_static(
+  %arg0: tensor<13x34xf32>, %arg1: tensor<34x42xf32>, %arg2: tensor<13x42xf32>)
+    -> tensor<13x42xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<13x34xf32>, tensor<34x42xf32>)
+                     outs(%arg2: tensor<13x42xf32>)
+    -> tensor<13x42xf32>
+  // The first application computes the total size.
+  // CHECK: %{{.*}} = affine.apply #[[$MAP13]]()
+  // CHECK: %[[SIZE:.+]] = affine.apply #[[$MAP13]]()
+  // CHECK: %[[COND:.+]] = arith.cmpi eq, %[[SIZE]], %{{.*}}
+  // CHECK: cf.assert %[[COND]], "could not compute dynamic multi-size tile shapes"
+
+  return %0 : tensor<13x42xf32>
+}
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  sequence %arg0 {
+    ^bb0(%arg1: !pdl.operation):
+      %0 = pdl_match @pdl_target in %arg1
+      transform.structured.multitile_sizes %0 { target_size = 3, divisor = 2, dimension = 0 }
+  }
+
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %0 with "transform.dialect"
+  }
+}
+
+// CHECK: #[[$MAP_A:.+]] = affine_map<()[s0] -> ([[A_IMPL:s0 floordiv 2]])>
+// CHECK: #[[$MAP_T:.+]] = affine_map<() -> (2)>
+// CHECK: #[[$MAP_D:.+]] = affine_map<()[s0] -> ([[D_IMPL:\(s0 floordiv 2 \+ 1\) floordiv 2]])>
+// CHECK: #[[$MAP_S:.+]] = affine_map<()[s0] -> ((([[A_IMPL]]) floordiv ([[D_IMPL]])) * 2)>
+// CHECK: #[[$MAP_V:.+]] = affine_map<()[s0] -> (([[A_IMPL]]) mod ([[D_IMPL]]))>
+// CHECK: #[[$MAP_U:.+]] = affine_map<()[s0] -> ([[D_IMPL]] - ([[A_IMPL]]) mod ([[D_IMPL]]))>
+
+// CHECK-LABEL: @multitile_sizes_dynamic
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>, %{{.*}}: tensor<?x?xf32>, %{{.*}}: tensor<?x?xf32>)
+func.func @multitile_sizes_dynamic(
+  // For matmul, the extent of the first iteration space dimension is equal to
+  // the size of the first dimension of the first tensor. The indexing map was
+  // folded so there is no map application happening.
+  //
+  // CHECK: %[[C0:.+]] = arith.constant 0
+  // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+  //
+  // The following are the maps as emitted by computeMultiTileSizes.
+  // CHECK: affine.apply #[[$MAP_A]]()[%[[DIM]]]
+  // CHECK: affine.apply #[[$MAP_T]]()
+  // CHECK: affine.apply #[[$MAP_D]]()[%[[DIM]]]
+  // CHECK: affine.apply #[[$MAP_S]]()[%[[DIM]]]
+  // CHECK: affine.apply #[[$MAP_V]]()[%[[DIM]]]
+  // CHECK: affine.apply #[[$MAP_U]]()[%[[DIM]]]
+  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+  
+  return %0 : tensor<?x?xf32>
+}

diff  --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index cd4412f92f19e..9d2641c6e498f 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -54,6 +54,20 @@ def testInterchange():
   # CHECK: iterator_interchange = [1, 0]
 
 
+@run
+def testMultitileSizes():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.MultiTileSizesOp(
+        sequence.bodyTarget, dimension=1, target_size=42)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testMultitileSizes
+  # CHECK: transform.sequence
+  # CHECK: transform.structured.multitile_sizes
+  # CHECK-DAG: dimension = 1
+  # CHECK-DAG: target_size = 42
+
+
 @run
 def testPad():
   sequence = transform.SequenceOp()

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index e86b41c05bd23..9cb214743eea1 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7461,7 +7461,9 @@ cc_library(
     ],
     includes = ["include"],
     deps = [
+        ":AffineDialect",
         ":ArithmeticDialect",
+        ":ControlFlowDialect",
         ":IR",
         ":LinalgDialect",
         ":LinalgTransformOpsIncGen",


        


More information about the Mlir-commits mailing list