[Mlir-commits] [mlir] Add support for param types in transform.structured.tile_using_forall (PR #72097)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 16 21:12:11 PST 2023


https://github.com/jinchen62 updated https://github.com/llvm/llvm-project/pull/72097

>From 48c045ee1accdf66814930fd56fb5b8ef10bd318 Mon Sep 17 00:00:00 2001
From: jinchen62 <jinchenye62 at gmail.com>
Date: Mon, 13 Nov 2023 02:19:39 -0800
Subject: [PATCH] Add support for param types in
 transform.structured.tile_using_forall

---
 .../Linalg/TransformOps/LinalgTransformOps.td | 10 +--
 .../TransformOps/LinalgTransformOps.cpp       | 63 ++++++++++----
 mlir/test/Dialect/Linalg/tile-to-forall.mlir  | 87 +++++++++++++++++++
 3 files changed, 137 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f1c3d717f1fa951..a24f6ff8308ba34 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -23,7 +23,16 @@ include "mlir/IR/RegionKindInterface.td"
 // value in the payload IR.
 def TransformParamTypeOrAnyHandle : Type<
     Or<[TransformHandleTypeInterface.predicate,
         Transform_ParamType.predicate]>,
     "transform 'param' type or any handle type">;
 
+// Type constraint accepting either a transform dialect handle or a value of
+// any type implementing TransformParamTypeInterface (e.g.
+// `!transform.param<i64>` or `!transform.any_param`). This is a separate
+// constraint so existing users of TransformParamTypeOrAnyHandle are not widened.
+def TransformAnyParamTypeOrAnyHandle : Type<
+    Or<[TransformHandleTypeInterface.predicate,
+        TransformParamTypeInterface.predicate]>,
+    "transform any param type or any handle type">;
+
 //===----------------------------------------------------------------------===//
@@ -1924,10 +1933,10 @@ def TileUsingForallOp :
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   Variadic<TransformHandleTypeInterface>:$num_threads,
-                   Variadic<TransformHandleTypeInterface>:$tile_sizes,
-                   Optional<TransformHandleTypeInterface>:$packed_num_threads,
-                   Optional<TransformHandleTypeInterface>:$packed_tile_sizes,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$num_threads,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$tile_sizes,
+                   Optional<TransformAnyParamTypeOrAnyHandle>:$packed_num_threads,
+                   Optional<TransformAnyParamTypeOrAnyHandle>:$packed_tile_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_num_threads,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index de4965f937162ea..3615cd784027200 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -86,8 +86,9 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
   return cast<LinalgOp>(result->getOperation());
 }
 
-/// Assuming that `ofr` is an index attr or a transform dialect handle mapped
-/// to exactly one op with one index result, return that value.
+/// Appends to `result` the value of each `ofr`, which must be an index attr,
+/// a transform dialect param associated with exactly one integer attr, or a
+/// transform dialect handle mapped to exactly one op with one index result.
 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
     transform::TransformState &state, TransformOpInterface transformOp,
     SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
@@ -98,12 +99,27 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
       result.push_back(ofr);
       continue;
     }
-    auto payloadOps = state.getPayloadOps(ofr.get<Value>());
+
+    Value transformValue = ofr.get<Value>();
+    // Params carry the value directly as an attribute. Check the interface
+    // rather than `ParamType` so that `!transform.any_param` is handled too.
+    if (isa<TransformParamTypeInterface>(transformValue.getType())) {
+      ArrayRef<Attribute> params = state.getParams(transformValue);
+      if (params.size() != 1)
+        return transformOp.emitDefiniteFailure()
+               << "requires exactly one parameter associated";
+      if (!isa<IntegerAttr>(params[0]))
+        return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+      result.push_back(params[0]);
+      continue;
+    }
+
+    auto payloadOps = state.getPayloadOps(transformValue);
     if (!llvm::hasSingleElement(payloadOps)) {
       DiagnosedSilenceableFailure diag =
           transformOp.emitSilenceableError()
           << "handle must be mapped to exactly one payload op";
-      diag.attachNote(ofr.get<Value>().getLoc())
+      diag.attachNote(transformValue.getLoc())
           << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
       return diag;
     }
@@ -123,24 +139,40 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
   return DiagnosedSilenceableFailure::success();
 }
 
-// Given a list of OpFoldResults that are either index attrs or op
-// handles, return a list of OpFoldResults where all op handles are
-// replaced with the first (and only) OpResult of that payload op. (There
-// must be exactly one mapped payload op and it must have exactly one
-// index result.)
+/// Appends to `result` the values carried by `packedHandle`. If it is a
+/// transform dialect param, each associated attribute must be an integer
+/// attr; a single associated array attr is unpacked element-wise first.
+/// Otherwise it must be a handle whose mapped payload ops each have exactly
+/// one index result, which is appended.
 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
     transform::TransformState &state, TransformOpInterface transformOp,
     SmallVector<OpFoldResult> &result, Value packedHandle) {
+  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
+    ArrayRef<Attribute> params = state.getParams(packedHandle);
+    // `transform.param.constant [a, b]` yields a single ArrayAttr param;
+    // accept that form as well as one integer param per value.
+    if (params.size() == 1) {
+      if (auto paramsArray = dyn_cast<ArrayAttr>(params[0]))
+        params = paramsArray.getValue();
+    }
+    for (Attribute param : params) {
+      if (!isa<IntegerAttr>(param))
+        return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+      result.push_back(param);
+    }
+    return DiagnosedSilenceableFailure::success();
+  }
+
   for (Operation *op : state.getPayloadOps(packedHandle)) {
     if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
       DiagnosedSilenceableFailure diag =
           transformOp.emitSilenceableError()
           << "payload op must have exactly 1 index result";
       diag.attachNote(op->getLoc())
           << "has " << op->getNumResults() << " results";
       return diag;
     }
     result.push_back(op->getResult(0));
   }
 
   return DiagnosedSilenceableFailure::success();
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index 38742028e481012..550f8f88719988a 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -451,3 +451,90 @@ module attributes {transform.with_named_sequence} {
     }
   }
 
+// -----
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic_param(
+//  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_dynamic_param(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  //      CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 :
+  //      CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
+  //      CHECK: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]]]
+  //      CHECK: %[[NT1:.+]] = affine.apply #[[$map1]]()[%[[N]]]
+  //      CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+  //      CHECK:   %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+  //      CHECK:   %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
+  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
+  //      CHECK:   tensor.extract_slice %[[A]]
+  //      CHECK:   tensor.extract_slice %[[B]]
+  //      CHECK:   tensor.extract_slice %[[C_BLK]]
+  //      CHECK:   linalg.matmul
+  //      CHECK:   scf.forall.in_parallel
+  // CHECK-NEXT:    tensor.parallel_insert_slice
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                    outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %sz = transform.param.constant 10 : i64 -> !transform.param<i64>
+    %1:2 = transform.structured.tile_using_forall %0 tile_sizes [%sz : !transform.param<i64>, 20]
+           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
+
+// CHECK-LABEL: matmul_tile_size_packed_param(
+//  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_packed_param(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  //      CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 :
+  //      CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
+  //      CHECK: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]]]
+  //      CHECK: %[[NT1:.+]] = affine.apply #[[$map1]]()[%[[N]]]
+  //      CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+  //      CHECK:   %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+  //      CHECK:   %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
+  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
+  //      CHECK:   tensor.extract_slice %[[A]]
+  //      CHECK:   tensor.extract_slice %[[B]]
+  //      CHECK:   tensor.extract_slice %[[C_BLK]]
+  //      CHECK:   linalg.matmul
+  //      CHECK:   scf.forall.in_parallel
+  // CHECK-NEXT:    tensor.parallel_insert_slice
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                    outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %sz = transform.param.constant [10 : i64, 20 : i64] -> !transform.any_param
+    %1:2 = transform.structured.tile_using_forall %0 tile_sizes *(%sz : !transform.any_param)
+           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list