[Mlir-commits] [mlir] 416ba22 - [mlir][linalg][transform] Support dynamic tile sizes in TileToForeachThreadOp

Matthias Springer llvmlistbot at llvm.org
Mon Aug 22 07:52:28 PDT 2022


Author: Matthias Springer
Date: 2022-08-22T16:48:45+02:00
New Revision: 416ba2256d2310abbfc185d1bce53715cf29092f

URL: https://github.com/llvm/llvm-project/commit/416ba2256d2310abbfc185d1bce53715cf29092f
DIFF: https://github.com/llvm/llvm-project/commit/416ba2256d2310abbfc185d1bce53715cf29092f.diff

LOG: [mlir][linalg][transform] Support dynamic tile sizes in TileToForeachThreadOp

TileToForeachThreadOp now accepts a mix of SSA value operands and index attributes for tile_sizes and num_threads (reusing OperandsOrIntegersSizesList). Each operand must be a PDL_Operation handle that is mapped to exactly one payload op; that op must have a single index result, which provides the tile size or number of threads.

Differential Revision: https://reviews.llvm.org/D131949

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 11c9b898ef58d..571d01b83d15f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -660,25 +660,31 @@ def TileOp : Op<Transform_Dialect, "structured.tile",
 
 def TileToForeachThreadOp :
     Op<Transform_Dialect, "structured.tile_to_foreach_thread_op",
-      [FunctionalStyleTransformOpTrait,
-       MemoryEffectsOpInterface,
-       TransformEachOpTrait,
+      [AttrSizedOperandSegments,
+       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
        TransformOpInterface]> {
   let description = [{
-    Tile a TilingInterface op to a tiled `scf.foreach_thread`. Tiling is
-    applied by either specifying `num_threads` or `tile_size`. If `num_threads`
-    is specified, then the tile size for each dimension `i` is calculated
-    dynamically via `ceilDiv(dimSize[i], num_threads[i])`.
-    If non-empty, the `thread_dim_mapping` is added as an attribute to the
-    resulting `scf.foreach_thread`.
-    Zero tile sizes indicate that the dimension is not tiled and can be
+    Tile a TilingInterface op to a tiled `scf.foreach_thread`.
+
+    Tiling is applied by either specifying `num_threads` or `tile_sizes`. If
+    `num_threads` is specified, then the tile size for each dimension `i` is
+    calculated dynamically via `ceilDiv(dimSize[i], num_threads[i])`.
+    `num_threads` and `tile_sizes` can be either static index attributes or SSA
+    values of PDL operation handle type (or a mix thereof). Operation handles
+    must be mapped to exactly one op that has exactly one result of index type.
+
+    Static zero tile sizes indicate that the dimension is not tiled and can be
     thought of as tiling by the full size of data.
+
     It is the user's responsibility to ensure that `num_threads/tile_sizes` is
     a valid tiling specification (i.e. that only tiles parallel dimensions, 
     e.g. in the Linalg case).
-    
+
+    If non-empty, the `thread_dim_mapping` is added as an attribute to the
+    resulting `scf.foreach_thread`.
+
     #### Return modes
-    
+
-    This operation ignores ops that do not implement the TilingInterface and
-    drops them in the return.
+    This operation produces a silenceable failure if any of the targeted ops
+    does not implement the TilingInterface.
 
@@ -696,36 +702,46 @@ def TileToForeachThreadOp :
     ### Example using `num_threads`
 
     ```
-    %0 = pdl_match @match_matmul in %arg1    
+    %0 = pdl_match @match_matmul in %arg1
     %3:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20]
     ```
 
     ### Example using `tile_sizes`
-    
+
     ```
-    %0 = pdl_match @match_matmul in %arg1    
-    %3:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [10, 20, 0]
+    %0 = pdl_match @match_matmul in %arg1
+    %sz = pdl_match @match_size_op in %arg1
+    %3:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [%sz, 20, 0]
     ```
   }];
 
   let arguments = (ins PDL_Operation:$target,
-                   // TODO: dynamic number of threads.
-                   OptionalAttr<DefaultValuedAttr<I64ArrayAttr, "{}">>:$num_threads,
-                   OptionalAttr<DefaultValuedAttr<I64ArrayAttr, "{}">>:$tile_sizes,
+                   Variadic<PDL_Operation>:$num_threads,
+                   Variadic<PDL_Operation>:$tile_sizes,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$static_num_threads,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$static_tile_sizes,
                    OptionalAttr<I64ArrayAttr>:$thread_dim_mapping);
   let results = (outs PDL_Operation:$foreach_thread_op,
                       PDL_Operation:$tiled_op);
-
   let assemblyFormat = [{
-    $target  (`num_threads` $num_threads^) : (`tile_sizes` $tile_sizes)?
-      (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? attr-dict
+    $target oilist(
+        `num_threads` custom<DynamicIndexList>($num_threads,
+                                               $static_num_threads,
+                                               "ShapedType::kDynamicSize") |
+         `tile_sizes` custom<DynamicIndexList>($tile_sizes,
+                                               $static_tile_sizes,
+                                               "ShapedType::kDynamicSize"))
+    (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? attr-dict
   }];
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::TilingInterface target, 
-        ::llvm::SmallVectorImpl<::mlir::Operation *> &results, 
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformResults &transformResults,
         ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedNumThreads();
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedTileSizes();
   }];
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index be4f202cc0b39..819f34596f6a0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1020,33 +1020,155 @@ void transform::TileOp::getEffects(
 // TileToForeachThreadOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne(
-    TilingInterface target, SmallVectorImpl<Operation *> &results,
+DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
+    transform::TransformResults &transformResults,
     transform::TransformState &state) {
   IRRewriter rewriter(getContext());
-  rewriter.setInsertionPoint(target);
-  auto maybeThreadDimMappingAttr = getThreadDimMapping();
-  auto dimMapping =
-      llvm::to_vector(maybeThreadDimMappingAttr
-                          ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
-                          : ArrayRef<int64_t>{});
-
-  FailureOr<ForeachThreadTilingResult> tilingResult = failure();
-  if (Optional<ArrayAttr> numThreads = getNumThreads())
-    tilingResult = linalg::tileToForeachThreadOp(
-        rewriter, target, getAsOpFoldResult(*numThreads), dimMapping);
-
-  if (Optional<ArrayAttr> tileSizes = getTileSizes())
-    tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
-        rewriter, target, getAsOpFoldResult(*tileSizes), dimMapping);
-
-  if (failed(tilingResult))
-    return emitDefaultSilenceableFailure(target);
-  rewriter.replaceOp(target, tilingResult->tileOp->getResults());
-  results.assign({tilingResult->tileOp, tilingResult->tiledOp});
+  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
+
+  // If the target handle is mapped to no payload ops, there is nothing to do.
+  if (targets.empty()) {
+    transformResults.set(getForeachThreadOp().cast<OpResult>(), {});
+    transformResults.set(getTiledOp().cast<OpResult>(), {});
+    return DiagnosedSilenceableFailure(success());
+  }
+
+  // Given a list of OpFoldResults that are either index attrs or op handles,
+  // append to `result` the same list with every op handle replaced by the
+  // first (and only) OpResult of its payload op. The handle must be mapped to
+  // exactly one payload op, and that op must have exactly one index result.
+  auto getOpResultsOrIndexAttrs =
+      [&](SmallVector<OpFoldResult> &result,
+          ArrayRef<OpFoldResult> opHandlesOrIndexAttrs) {
+        for (OpFoldResult ofr : opHandlesOrIndexAttrs) {
+          if (ofr.is<Attribute>()) {
+            result.push_back(ofr);
+            continue;
+          }
+          ArrayRef<Operation *> payloadOps =
+              state.getPayloadOps(ofr.get<Value>());
+          if (payloadOps.size() != 1) {
+            DiagnosedSilenceableFailure diag =
+                emitSilenceableError()
+                << "handle must be mapped to exactly 1 payload op";
+            diag.attachNote(ofr.get<Value>().getLoc())
+                << "mapped to " << payloadOps.size() << " ops";
+            return diag;
+          }
+          Operation *op = payloadOps[0];
+          if (op->getNumResults() != 1 ||
+              !op->getResult(0).getType().isIndex()) {
+            DiagnosedSilenceableFailure diag =
+                emitSilenceableError()
+                << "payload op must have exactly 1 index result";
+            // Report whichever of the two conditions actually failed.
+            if (op->getNumResults() != 1)
+              diag.attachNote(op->getLoc())
+                  << "has " << op->getNumResults() << " results";
+            else
+              diag.attachNote(op->getLoc())
+                  << "has result of type " << op->getResult(0).getType();
+            return diag;
+          }
+          result.push_back(op->getResult(0));
+        }
+
+        return DiagnosedSilenceableFailure(success());
+      };
+
+  // getMixedNumThreads are OpFoldResults[index attributes or PDL operation].
+  // Convert to OpFoldResults[index attributes or payload op result].
+  SmallVector<OpFoldResult> numThreads;
+  DiagnosedSilenceableFailure status =
+      getOpResultsOrIndexAttrs(numThreads, getMixedNumThreads());
+  if (!status.succeeded())
+    return status;
+
+  // getMixedTileSizes are OpFoldResults[index attributes or PDL operation].
+  // Convert to OpFoldResults[index attributes or payload op result].
+  SmallVector<OpFoldResult> tileSizes;
+  status = getOpResultsOrIndexAttrs(tileSizes, getMixedTileSizes());
+  if (!status.succeeded())
+    return status;
+
+  // Check every target before rewriting any of them, so that an unsupported
+  // target does not leave the payload IR partially tiled.
+  SmallVector<TilingInterface> tilableOps;
+  tilableOps.reserve(targets.size());
+  for (Operation *target : targets) {
+    auto tilableOp = dyn_cast<TilingInterface>(target);
+    if (!tilableOp) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError() << "only TilingInterface ops are supported";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+    tilableOps.push_back(tilableOp);
+  }
+
+  // The thread dimension mapping is the same for all targets.
+  auto maybeThreadDimMappingAttr = getThreadDimMapping();
+  auto dimMapping = llvm::to_vector(
+      maybeThreadDimMappingAttr
+          ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
+          : ArrayRef<int64_t>{});
+
+  // Result payload ops.
+  SmallVector<Operation *> tileOps;
+  SmallVector<Operation *> tiledOps;
+
+  // Transform all targets one by one. `numThreads` is empty iff num_threads
+  // was not specified (it has one entry per getMixedNumThreads() entry).
+  for (TilingInterface tilableOp : tilableOps) {
+    rewriter.setInsertionPoint(tilableOp);
+    FailureOr<ForeachThreadTilingResult> tilingResult = failure();
+    if (!numThreads.empty()) {
+      tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp,
+                                                   numThreads, dimMapping);
+    } else {
+      tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
+          rewriter, tilableOp, tileSizes, dimMapping);
+    }
+
+    // Note: targets handled in earlier iterations have already been rewritten
+    // when tiling fails here.
+    if (failed(tilingResult))
+      return emitDefaultSilenceableFailure(tilableOp);
+    rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults());
+
+    tileOps.push_back(tilingResult->tileOp);
+    tiledOps.push_back(tilingResult->tiledOp);
+  }
+
+  transformResults.set(getForeachThreadOp().cast<OpResult>(), tileOps);
+  transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);
+
   return DiagnosedSilenceableFailure(success());
 }
 
+void transform::TileToForeachThreadOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  onlyReadsHandle(getTileSizes(), effects);
+  onlyReadsHandle(getNumThreads(), effects);
+  producesHandle(getResults(), effects);
+}
+
+SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedNumThreads() {
+  return getMixedSizes(getStaticNumThreads(), getNumThreads());
+}
+
+SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedTileSizes() {
+  return getMixedSizes(getStaticTileSizes(), getTileSizes());
+}
+
+LogicalResult TileToForeachThreadOp::verify() {
+  // Exactly one of num_threads / tile_sizes must be non-empty.
+  if (getMixedNumThreads().empty() == getMixedTileSizes().empty())
+    return emitOpError("either num_threads or tile_sizes must be specified");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // VectorizeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
index dd062908c0bf1..d519adb4dfd2c 100644
--- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -203,3 +203,49 @@ module {
 //       CHECK:    %[[OFF:.*]] = affine.apply #[[$map0]](%[[ARG]])
 //       CHECK:    scf.foreach_thread.perform_concurrently {
 //       CHECK:      tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%[[OFF]]] [2] [1] : tensor<2xf32> into tensor<4xf32>
+
+// -----
+
+// In this test case, matmul dims and tile size are dynamic.
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic_dynamic(
+//  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_dynamic_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  //  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+  //  CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+  //  CHECK-DAG: %[[tile_size:.*]] = "test.dummy"()
+  //  CHECK-DAG: %[[M:.+]] = tensor.dim %[[A]], %[[c0]] :
+  //  CHECK-DAG: %[[N:.+]] = tensor.dim %[[B]], %[[c1]] :
+  //  CHECK-DAG: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]], %[[tile_size]]]
+  //  CHECK-DAG: %[[NT1:.+]] = affine.apply #[[$map1]]()[%[[N]]]
+  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]])
+  //      CHECK:   tensor.extract_slice %[[A]]
+  //      CHECK:   tensor.extract_slice %[[B]]
+  //      CHECK:   tensor.extract_slice %[[C]]
+  //      CHECK:   linalg.matmul
+  //      CHECK:   scf.foreach_thread.perform_concurrently
+  // CHECK-NEXT:    tensor.parallel_insert_slice
+  %tile_size = "test.dummy"() : () -> (index)
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                    outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    %sz = transform.structured.match ops{["test.dummy"]} in %arg1
+    %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [%sz, 20]
+  }
+}


        


More information about the Mlir-commits mailing list