[Mlir-commits] [mlir] 6370f75 - [mlir][Transform] Add support for dynamically unpacking tile_sizes / num_threads in tile_to_foreach_thread

Nicolas Vasilache llvmlistbot at llvm.org
Mon Nov 14 04:40:09 PST 2022


Author: Nicolas Vasilache
Date: 2022-11-14T04:39:57-08:00
New Revision: 6370f75ad70c508b52a3d98250bad51e8d8138b6

URL: https://github.com/llvm/llvm-project/commit/6370f75ad70c508b52a3d98250bad51e8d8138b6
DIFF: https://github.com/llvm/llvm-project/commit/6370f75ad70c508b52a3d98250bad51e8d8138b6.diff

LOG: [mlir][Transform] Add support for dynamically unpacking tile_sizes / num_threads in tile_to_foreach_thread

This commit adds automatic unpacking of Values of type pdl::OperationType to the underlying single-result OpResults.
This allows mixing single-value, attribute, and multi-value pdl::Operation tile sizes and thread counts in TileToForeachThreadOp.

Differential Revision: https://reviews.llvm.org/D137896

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b92ed19863069..4cfec70b0c07e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -796,6 +796,10 @@ def TileToForeachThreadOp :
     If non-empty, the `mapping` is added as an attribute to the
     resulting `scf.foreach_thread`.
 
+    Note: `tile_sizes` and `num_threads` are variadic. Each tile size/number of
+    threads can be an index attribute or a transform handle whose mapped
+    payload ops each have exactly one index result.
+
     #### Return modes
 
     This operation ignores ops that do not implement the TilingInterface and
@@ -857,7 +861,7 @@ def TileToForeachThreadOp :
                    "ArrayRef<OpFoldResult>":$mixedNumThreads,
                    CArg<"::mlir::transform::NumThreadsSpec", 
                         "::mlir::transform::NumThreadsSpec()">,
-                   CArg<"ArrayAttr", "{}">:$mapping)>,
+                   CArg<"ArrayAttr", "{}">:$mapping)>
   ];
 
   let assemblyFormat = [{

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index cdd4e158f9b1c..960fabee05b59 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1327,9 +1327,12 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder,
                                              ArrayRef<int64_t> staticTileSizes,
                                              transform::TileSizesSpec,
                                              ArrayAttr mapping) {
-  return build(builder, result, target,
+  return build(builder, result,
+               /*target=*/target,
+               /*mixedTileSizes=*/
                getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
-               TileSizesSpec(), mapping);
+               /*_=*/TileSizesSpec(),
+               /*mapping=*/mapping);
 }
 
 void transform::TileToForeachThreadOp::build(
@@ -1346,9 +1349,14 @@ void transform::TileToForeachThreadOp::build(
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
   auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
-  build(builder, result, TypeRange{operationType, operationType}, target,
-        /*numThreads=*/ValueRange{}, dynamicTileSizes,
-        /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mapping);
+  build(builder, result,
+        /*resultTypes=*/TypeRange{operationType, operationType},
+        /*target=*/target,
+        /*numThreads=*/ValueRange{},
+        /*tileSizes=*/dynamicTileSizes,
+        /*staticNumThreads=*/builder.getI64ArrayAttr({}),
+        /*staticTileSizes=*/staticTileSizesAttr,
+        /*mapping=*/mapping);
 }
 
 void transform::TileToForeachThreadOp::build(OpBuilder &builder,
@@ -1376,11 +1384,48 @@ void transform::TileToForeachThreadOp::build(
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
   auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads);
-  build(builder, result, TypeRange{operationType, operationType}, target,
-        dynamicNumThreads, /*tileSizes=*/ValueRange{}, staticNumThreadsAttr,
-        /*staticTileSizes=*/ArrayAttr(), mapping);
+  build(builder, result,
+        /*resultTypes=*/TypeRange{operationType, operationType},
+        /*target=*/target,
+        /*numThreads=*/dynamicNumThreads,
+        /*tileSizes=*/ValueRange{},
+        /*staticNumThreads=*/staticNumThreadsAttr,
+        /*staticTileSizes=*/builder.getI64ArrayAttr({}),
+        /*mapping=*/mapping);
 }
 
+// Expand `ofrs` into `result`: attributes and non-PDL values are appended
+// unchanged; each pdl::OperationType handle is replaced, in order, by the
+// single index result of every payload op mapped to it (zero or more ops).
+// Emits a silenceable error if a mapped payload op does not have exactly
+// one result of index type; `result` may then be partially filled.
+static DiagnosedSilenceableFailure unpackPDLOperations(
+    transform::TransformState &state, TransformOpInterface transformOp,
+    SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
+  for (OpFoldResult ofr : ofrs) {
+    // Attributes and values that are not PDL operation handles pass through.
+    if (ofr.is<Attribute>() ||
+        !ofr.get<Value>().getType().isa<pdl::OperationType>()) {
+      result.push_back(ofr);
+      continue;
+    }
+    ArrayRef<Operation *> payloadOps = state.getPayloadOps(ofr.get<Value>());
+    for (Operation *op : payloadOps) {
+      if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+        DiagnosedSilenceableFailure diag =
+            transformOp.emitSilenceableError()
+            << "payload op must have exactly 1 index result";
+        diag.attachNote(op->getLoc())
+            << "has " << op->getNumResults() << " results";
+        return diag;
+      }
+      result.push_back(op->getResult(0));
+    }
+  }
+
+  return DiagnosedSilenceableFailure(success());
+};
+
 DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
     RewriterBase &rewriter, transform::TransformState &state,
     TransformOpInterface transformOp, ArrayRef<Operation *> targets,
@@ -1390,56 +1435,18 @@ DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
   if (targets.empty())
     return DiagnosedSilenceableFailure(success());
 
-  // Given a list of OpFoldResults that are either index attrs or op handles,
-  // return a list of OpFoldResults where all op handles are replaced with the
-  // first (and only) OpResult of that payload op. (There must be exactly one
-  // mapped payload op and it must have exactly one index result.)
-  auto getOpResultsOrIndexAttrs =
-      [&](SmallVector<OpFoldResult> &result,
-          ArrayRef<OpFoldResult> opHandlesOrIndexAttrs) {
-        for (OpFoldResult ofr : opHandlesOrIndexAttrs) {
-          if (ofr.is<Attribute>()) {
-            result.push_back(ofr);
-            continue;
-          }
-          ArrayRef<Operation *> dynamicNumThreads =
-              state.getPayloadOps(ofr.get<Value>());
-          if (dynamicNumThreads.size() != 1) {
-            DiagnosedSilenceableFailure diag =
-                transformOp.emitSilenceableError()
-                << "handle must be mapped to exactly 1 payload op";
-            diag.attachNote(ofr.get<Value>().getLoc())
-                << "mapped to " << dynamicNumThreads.size() << " ops";
-            return diag;
-          }
-          Operation *op = dynamicNumThreads[0];
-          if (op->getNumResults() != 1 ||
-              !op->getResult(0).getType().isIndex()) {
-            DiagnosedSilenceableFailure diag =
-                transformOp.emitSilenceableError()
-                << "payload op must have exactly 1 index result";
-            diag.attachNote(op->getLoc())
-                << "has " << op->getNumResults() << " results";
-            return diag;
-          }
-          result.push_back(op->getResult(0));
-        }
-
-        return DiagnosedSilenceableFailure(success());
-      };
-
   // getMixedNumThreads are OpFoldResults[index attributes or PDL operation].
   // Convert to OpFoldResults[index attributes or payload op].
   SmallVector<OpFoldResult> numThreads;
   DiagnosedSilenceableFailure status =
-      getOpResultsOrIndexAttrs(numThreads, mixedNumThreads);
+      unpackPDLOperations(state, transformOp, numThreads, mixedNumThreads);
   if (!status.succeeded())
     return status;
 
   // getMixedTileSizes are OpFoldResults[index attributes or PDL operation].
   // Convert to OpFoldResults[index attributes or payload op].
   SmallVector<OpFoldResult> tileSizes;
-  status = getOpResultsOrIndexAttrs(tileSizes, mixedTileSizes);
+  status = unpackPDLOperations(state, transformOp, tileSizes, mixedTileSizes);
   if (!status.succeeded())
     return status;
 
@@ -1488,8 +1495,11 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
       getMixedNumThreads(), getMixedTileSizes(), getMapping(), tileOps,
       tiledOps);
 
-  if (!diag.succeeded())
+  if (!diag.succeeded()) {
+    transformResults.set(getForeachThreadOp().cast<OpResult>(), {});
+    transformResults.set(getTiledOp().cast<OpResult>(), {});
     return diag;
+  }
 
   transformResults.set(getForeachThreadOp().cast<OpResult>(), tileOps);
   transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);

diff  --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
index f015cdff93fa4..4dda150ca5341 100644
--- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -41,6 +41,48 @@ module {
 
 // -----
 
+// In this test case, matmul dims and tile size are dynamic.
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic_dynamic(
+//  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_dynamic_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  //  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+  //  CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+  //  CHECK-DAG: %[[tile_size_1:.*]] = "test.dummy"()
+  //  CHECK-DAG: %[[tile_size_2:.*]] = "test.dummy"()
+  //  CHECK-DAG: %[[M:.+]] = tensor.dim %[[A]], %[[c0]] :
+  //  CHECK-DAG: %[[N:.+]] = tensor.dim %[[B]], %c1 :
+  //  CHECK-DAG: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]], %[[tile_size_1]]]
+  //  CHECK-DAG: %[[NT1:.+]] = affine.apply #[[$map0]]()[%[[N]], %[[tile_size_2]]]
+  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+  //      CHECK:   tensor.extract_slice %[[A]]
+  //      CHECK:   tensor.extract_slice %[[B]]
+  //      CHECK:   tensor.extract_slice %[[C_BLK]]
+  //      CHECK:   linalg.matmul
+  //      CHECK:   scf.foreach_thread.perform_concurrently
+  // CHECK-NEXT:    tensor.parallel_insert_slice
+  %tile_size_1 = "test.dummy"() : () -> (index)
+  %tile_size_2 = "test.dummy"() : () -> (index)
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                    outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+  %sz = transform.structured.match ops{["test.dummy"]} in %arg1
+  %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [%sz]
+}
+
+// -----
+
 // Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot.
 
 // CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 300, 15)>


        


More information about the Mlir-commits mailing list