[Mlir-commits] [mlir] Add support of param type for transform.structured.tile_using_forall (PR #72097)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 13 02:30:30 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: jinchen62 (jinchen62)

<details>
<summary>Changes</summary>

Allow transform.structured.tile_using_forall to accept tile sizes given as transform param values.

Examples:
```
%tile_size1 = transform.param.constant 16 : i64 -> !transform.param<i64>
transform.structured.tile_using_forall %matmul tile_sizes [%tile_size1 : !transform.param<i64>, 32] ( mapping = [#gpu.block<x>, #gpu.block<y>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
```
```
%tile_sizes = transform.param.constant [16 : i64, 32 : i64] -> !transform.any_param
transform.structured.tile_using_forall %matmul tile_sizes [%tile_sizes : !transform.any_param] ( mapping = [#gpu.block<x>, #gpu.block<y>] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
```

---
Full diff: https://github.com/llvm/llvm-project/pull/72097.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+5-5) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+24-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f1c3d717f1fa951..a24f6ff8308ba34 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -23,7 +23,7 @@ include "mlir/IR/RegionKindInterface.td"
 // value in the payload IR.
 def TransformParamTypeOrAnyHandle : Type<
     Or<[TransformHandleTypeInterface.predicate,
-        Transform_ParamType.predicate]>,
+        TransformParamTypeInterface.predicate]>,
     "transform 'param' type or any handle type">;
 
 //===----------------------------------------------------------------------===//
@@ -1924,10 +1924,10 @@ def TileUsingForallOp :
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   Variadic<TransformHandleTypeInterface>:$num_threads,
-                   Variadic<TransformHandleTypeInterface>:$tile_sizes,
-                   Optional<TransformHandleTypeInterface>:$packed_num_threads,
-                   Optional<TransformHandleTypeInterface>:$packed_tile_sizes,
+                   Variadic<TransformParamTypeOrAnyHandle>:$num_threads,
+                   Variadic<TransformParamTypeOrAnyHandle>:$tile_sizes,
+                   Optional<TransformParamTypeOrAnyHandle>:$packed_num_threads,
+                   Optional<TransformParamTypeOrAnyHandle>:$packed_tile_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_num_threads,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index de4965f937162ea..4bf4db3381fab79 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -98,12 +98,34 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
       result.push_back(ofr);
       continue;
     }
-    auto payloadOps = state.getPayloadOps(ofr.get<Value>());
+
+    Value transformValue = ofr.get<Value>();
+    if (isa<ParamType>(transformValue.getType())) {
+      ArrayRef<Attribute> params = state.getParams(transformValue);
+      if (!isa<IntegerAttr>(params[0]))
+        return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+      result.push_back(params[0]);
+      continue;
+    }
+    if (isa<AnyParamType>(transformValue.getType())) {
+      ArrayRef<Attribute> params = state.getParams(transformValue);
+      if (!isa<ArrayAttr>(params[0]))
+        return transformOp.emitDefiniteFailure() << "expected ArrayAttr";
+      ArrayAttr paramsArray = cast<ArrayAttr>(params[0]);
+      for (Attribute param : paramsArray.getValue()) {
+        if (!isa<IntegerAttr>(param))
+          return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+        result.push_back(param);
+      }
+      continue;
+    }
+
+    auto payloadOps = state.getPayloadOps(transformValue);
     if (!llvm::hasSingleElement(payloadOps)) {
       DiagnosedSilenceableFailure diag =
           transformOp.emitSilenceableError()
           << "handle must be mapped to exactly one payload op";
-      diag.attachNote(ofr.get<Value>().getLoc())
+      diag.attachNote(transformValue.getLoc())
           << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
       return diag;
     }

``````````

</details>


https://github.com/llvm/llvm-project/pull/72097


More information about the Mlir-commits mailing list