[Mlir-commits] [mlir] Add support of param type for transform.structured.tile_using_forall (PR #72097)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 15 15:18:51 PST 2023


https://github.com/jinchen62 updated https://github.com/llvm/llvm-project/pull/72097

>From a04b3bdfe51e5949e1ce602cfc646865a9262ba4 Mon Sep 17 00:00:00 2001
From: jinchen62 <jinchenye62 at gmail.com>
Date: Mon, 13 Nov 2023 02:19:39 -0800
Subject: [PATCH] Add support of param type for
 transform.structured.tile_using_forall

---
 .../Linalg/TransformOps/LinalgTransformOps.td | 10 +--
 .../TransformOps/LinalgTransformOps.cpp       | 63 +++++++++++++------
 2 files changed, 50 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f1c3d717f1fa951..a24f6ff8308ba34 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -23,7 +23,7 @@ include "mlir/IR/RegionKindInterface.td"
 // value in the payload IR.
 def TransformParamTypeOrAnyHandle : Type<
     Or<[TransformHandleTypeInterface.predicate,
-        Transform_ParamType.predicate]>,
+        TransformParamTypeInterface.predicate]>,
     "transform 'param' type or any handle type">;
 
 //===----------------------------------------------------------------------===//
@@ -1924,10 +1924,10 @@ def TileUsingForallOp :
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   Variadic<TransformHandleTypeInterface>:$num_threads,
-                   Variadic<TransformHandleTypeInterface>:$tile_sizes,
-                   Optional<TransformHandleTypeInterface>:$packed_num_threads,
-                   Optional<TransformHandleTypeInterface>:$packed_tile_sizes,
+                   Variadic<TransformParamTypeOrAnyHandle>:$num_threads,
+                   Variadic<TransformParamTypeOrAnyHandle>:$tile_sizes,
+                   Optional<TransformParamTypeOrAnyHandle>:$packed_num_threads,
+                   Optional<TransformParamTypeOrAnyHandle>:$packed_tile_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_num_threads,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index de4965f937162ea..3615cd784027200 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -86,8 +86,9 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
   return cast<LinalgOp>(result->getOperation());
 }
 
-/// Assuming that `ofr` is an index attr or a transform dialect handle mapped
-/// to exactly one op with one index result, return that value.
+/// Assuming that `ofr` is an index attr, a param holding an integer
+/// attribute, or a transform dialect handle mapped to exactly one op
+/// with one index result, return that value.
 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
     transform::TransformState &state, TransformOpInterface transformOp,
     SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
@@ -98,12 +99,22 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
       result.push_back(ofr);
       continue;
     }
-    auto payloadOps = state.getPayloadOps(ofr.get<Value>());
+
+    Value transformValue = ofr.get<Value>();
+    if (isa<ParamType>(transformValue.getType())) {
+      ArrayRef<Attribute> params = state.getParams(transformValue);
+      if (!isa<IntegerAttr>(params[0]))
+        return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+      result.push_back(params[0]);
+      continue;
+    }
+
+    auto payloadOps = state.getPayloadOps(transformValue);
     if (!llvm::hasSingleElement(payloadOps)) {
       DiagnosedSilenceableFailure diag =
           transformOp.emitSilenceableError()
           << "handle must be mapped to exactly one payload op";
-      diag.attachNote(ofr.get<Value>().getLoc())
+      diag.attachNote(transformValue.getLoc())
           << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
       return diag;
     }
@@ -123,24 +134,40 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
   return DiagnosedSilenceableFailure::success();
 }
 
-// Given a list of OpFoldResults that are either index attrs or op
-// handles, return a list of OpFoldResults where all op handles are
-// replaced with the first (and only) OpResult of that payload op. (There
-// must be exactly one mapped payload op and it must have exactly one
-// index result.)
+// Given a packed transform value, append its unpacked contents to `result`.
+// If the value is of AnyParamType, it must be associated with exactly one
+// parameter, an ArrayAttr whose elements are all IntegerAttrs; each element
+// is appended as an attribute. Otherwise the value is treated as an op
+// handle, and each mapped payload op must have exactly one index result,
+// which is appended in place of the op.
 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
     transform::TransformState &state, TransformOpInterface transformOp,
     SmallVector<OpFoldResult> &result, Value packedHandle) {
-  for (Operation *op : state.getPayloadOps(packedHandle)) {
-    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
-      DiagnosedSilenceableFailure diag =
-          transformOp.emitSilenceableError()
-          << "payload op must have exactly 1 index result";
-      diag.attachNote(op->getLoc())
-          << "has " << op->getNumResults() << " results";
-      return diag;
+  if (isa<AnyParamType>(packedHandle.getType())) {
+    ArrayRef<Attribute> params = state.getParams(packedHandle);
+    if (params.size() != 1)
+      return transformOp.emitDefiniteFailure()
+             << "requires exactly one parameter associated";
+    ArrayAttr paramsArray = dyn_cast<ArrayAttr>(params[0]);
+    if (!paramsArray)
+      return transformOp.emitDefiniteFailure() << "param array is null";
+    for (Attribute param : paramsArray.getValue()) {
+      if (!isa<IntegerAttr>(param))
+        return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+      result.push_back(param);
+    }
+  } else {
+    for (Operation *op : state.getPayloadOps(packedHandle)) {
+      if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+        DiagnosedSilenceableFailure diag =
+            transformOp.emitSilenceableError()
+            << "payload op must have exactly 1 index result";
+        diag.attachNote(op->getLoc())
+            << "has " << op->getNumResults() << " results";
+        return diag;
+      }
+      result.push_back(op->getResult(0));
     }
-    result.push_back(op->getResult(0));
   }
 
   return DiagnosedSilenceableFailure::success();



More information about the Mlir-commits mailing list