[Mlir-commits] [mlir] Add support of param type for transform.structured.tile_using_forall (PR #72097)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 15 15:18:51 PST 2023
https://github.com/jinchen62 updated https://github.com/llvm/llvm-project/pull/72097
>From a04b3bdfe51e5949e1ce602cfc646865a9262ba4 Mon Sep 17 00:00:00 2001
From: jinchen62 <jinchenye62 at gmail.com>
Date: Mon, 13 Nov 2023 02:19:39 -0800
Subject: [PATCH] Add support of param type for
transform.structured.tile_using_forall
---
.../Linalg/TransformOps/LinalgTransformOps.td | 10 +--
.../TransformOps/LinalgTransformOps.cpp | 63 +++++++++++++------
2 files changed, 50 insertions(+), 23 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f1c3d717f1fa951..a24f6ff8308ba34 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -23,7 +23,7 @@ include "mlir/IR/RegionKindInterface.td"
// value in the payload IR.
def TransformParamTypeOrAnyHandle : Type<
Or<[TransformHandleTypeInterface.predicate,
- Transform_ParamType.predicate]>,
+ TransformParamTypeInterface.predicate]>,
"transform 'param' type or any handle type">;
//===----------------------------------------------------------------------===//
@@ -1924,10 +1924,10 @@ def TileUsingForallOp :
}];
let arguments = (ins TransformHandleTypeInterface:$target,
- Variadic<TransformHandleTypeInterface>:$num_threads,
- Variadic<TransformHandleTypeInterface>:$tile_sizes,
- Optional<TransformHandleTypeInterface>:$packed_num_threads,
- Optional<TransformHandleTypeInterface>:$packed_tile_sizes,
+ Variadic<TransformParamTypeOrAnyHandle>:$num_threads,
+ Variadic<TransformParamTypeOrAnyHandle>:$tile_sizes,
+ Optional<TransformParamTypeOrAnyHandle>:$packed_num_threads,
+ Optional<TransformParamTypeOrAnyHandle>:$packed_tile_sizes,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_num_threads,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index de4965f937162ea..3615cd784027200 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -86,8 +86,9 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
return cast<LinalgOp>(result->getOperation());
}
-/// Assuming that `ofr` is an index attr or a transform dialect handle mapped
-/// to exactly one op with one index result, return that value.
+/// Assuming that `ofr` is an index attr or a param of index type
+/// or a transform dialect handle mapped to exactly one op
+/// with one index result, return that value.
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
transform::TransformState &state, TransformOpInterface transformOp,
SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
@@ -98,12 +99,22 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
result.push_back(ofr);
continue;
}
- auto payloadOps = state.getPayloadOps(ofr.get<Value>());
+
+ Value transformValue = ofr.get<Value>();
+ if (isa<ParamType>(transformValue.getType())) {
+ ArrayRef<Attribute> params = state.getParams(transformValue);
+ if (!isa<IntegerAttr>(params[0]))
+ return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+ result.push_back(params[0]);
+ continue;
+ }
+
+ auto payloadOps = state.getPayloadOps(transformValue);
if (!llvm::hasSingleElement(payloadOps)) {
DiagnosedSilenceableFailure diag =
transformOp.emitSilenceableError()
<< "handle must be mapped to exactly one payload op";
- diag.attachNote(ofr.get<Value>().getLoc())
+ diag.attachNote(transformValue.getLoc())
<< "mapped to " << llvm::range_size(payloadOps) << " payload ops";
return diag;
}
@@ -123,24 +134,40 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
return DiagnosedSilenceableFailure::success();
}
-// Given a list of OpFoldResults that are either index attrs or op
-// handles, return a list of OpFoldResults where all op handles are
-// replaced with the first (and only) OpResult of that payload op. (There
-// must be exactly one mapped payload op and it must have exactly one
-// index result.)
+// Given a packed handle, unpack it into a list of OpFoldResults. If the
+// handle is of AnyParamType, it must be associated with exactly one
+// parameter, which must be an ArrayAttr whose elements are all IntegerAttrs;
+// each element is appended to the result. Otherwise, every payload op mapped
+// to the handle must have exactly one result of index type, and that
+// OpResult is appended to the result.
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
transform::TransformState &state, TransformOpInterface transformOp,
SmallVector<OpFoldResult> &result, Value packedHandle) {
- for (Operation *op : state.getPayloadOps(packedHandle)) {
- if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
- DiagnosedSilenceableFailure diag =
- transformOp.emitSilenceableError()
- << "payload op must have exactly 1 index result";
- diag.attachNote(op->getLoc())
- << "has " << op->getNumResults() << " results";
- return diag;
+ if (isa<AnyParamType>(packedHandle.getType())) {
+ ArrayRef<Attribute> params = state.getParams(packedHandle);
+ if (params.size() != 1)
+ return transformOp.emitDefiniteFailure()
+ << "requires exactly one parameter associated";
+ ArrayAttr paramsArray = dyn_cast<ArrayAttr>(params[0]);
+ if (!paramsArray)
+ return transformOp.emitDefiniteFailure() << "param array is null";
+ for (Attribute param : paramsArray.getValue()) {
+ if (!isa<IntegerAttr>(param))
+ return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+ result.push_back(param);
+ }
+ } else {
+ for (Operation *op : state.getPayloadOps(packedHandle)) {
+ if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+ DiagnosedSilenceableFailure diag =
+ transformOp.emitSilenceableError()
+ << "payload op must have exactly 1 index result";
+ diag.attachNote(op->getLoc())
+ << "has " << op->getNumResults() << " results";
+ return diag;
+ }
+ result.push_back(op->getResult(0));
}
- result.push_back(op->getResult(0));
}
return DiagnosedSilenceableFailure::success();
More information about the Mlir-commits
mailing list