[Mlir-commits] [mlir] Add support of param type for transform.structured.tile_using_forall (PR #72097)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri Nov 17 12:38:16 PST 2023


================
@@ -123,24 +134,40 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
   return DiagnosedSilenceableFailure::success();
 }
 
-// Given a list of OpFoldResults that are either index attrs or op
-// handles, return a list of OpFoldResults where all op handles are
-// replaced with the first (and only) OpResult of that payload op. (There
-// must be exactly one mapped payload op and it must have exactly one
-// index result.)
+// Given a list of params that are index attrs or a list of OpFoldResults
+// that are either index attrs or op handles, return a list of OpFoldResults
+// of index attrs or a list of OpFoldResults where all op handles are
+// replaced with the first (and only) OpResult of that payload op.
+// (There must be exactly one parameter associated with the AnyParamType or
+// one mapped payload op which must have exactly one index result.)
 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
     transform::TransformState &state, TransformOpInterface transformOp,
     SmallVector<OpFoldResult> &result, Value packedHandle) {
-  for (Operation *op : state.getPayloadOps(packedHandle)) {
-    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
-      DiagnosedSilenceableFailure diag =
-          transformOp.emitSilenceableError()
-          << "payload op must have exactly 1 index result";
-      diag.attachNote(op->getLoc())
-          << "has " << op->getNumResults() << " results";
-      return diag;
+  if (isa<AnyParamType>(packedHandle.getType())) {
+    ArrayRef<Attribute> params = state.getParams(packedHandle);
+    if (params.size() != 1)
+      return transformOp.emitDefiniteFailure()
+             << "requires exactly one parameter associated";
+    ArrayAttr paramsArray = dyn_cast<ArrayAttr>(params[0]);
+    if (!paramsArray)
+      return transformOp.emitDefiniteFailure() << "param array is null";
+    for (Attribute param : paramsArray.getValue()) {
+      if (!isa<IntegerAttr>(param))
+        return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+      result.push_back(param);
+    }
+  } else {
----------------
ftynse wrote:

Nit: add `return DiagnosedSilenceableFailure::success();` at the end of the branch above so the `else` can be dropped.

https://github.com/llvm/llvm-project/pull/72097


More information about the Mlir-commits mailing list