[Mlir-commits] [mlir] Add support of param type for transform.structured.tile_using_forall (PR #72097)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 28 18:30:58 PST 2024


https://github.com/jinchen62 updated https://github.com/llvm/llvm-project/pull/72097

>From df347ad207bd3978266ef69ef0d6430bf672f9ca Mon Sep 17 00:00:00 2001
From: jinchen62 <jinchenye62 at gmail.com>
Date: Mon, 13 Nov 2023 02:19:39 -0800
Subject: [PATCH] Add support of param type for
 transform.structured.tile_using_forall

---
 .../Linalg/TransformOps/LinalgTransformOps.td |  24 +--
 .../TransformOps/LinalgTransformOps.cpp       |  43 ++++--
 mlir/test/Dialect/Linalg/tile-to-forall.mlir  | 137 +++++++++++++++++-
 3 files changed, 182 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b139f1ef58b3a99..309573a562872fb 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -21,10 +21,10 @@ include "mlir/IR/RegionKindInterface.td"
 
 // This is roughly similar to OpFoldResult assuming the handle produces a single
 // value in the payload IR.
-def TransformParamTypeOrAnyHandle : Type<
+def TransformAnyParamTypeOrAnyHandle : Type<
     Or<[TransformHandleTypeInterface.predicate,
-        Transform_ParamType.predicate]>,
-    "transform 'param' type or any handle type">;
+        TransformParamTypeInterface.predicate]>,
+    "transform any param type or any handle type">;
 
 //===----------------------------------------------------------------------===//
 // Apply...PatternsOp
@@ -691,9 +691,9 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
                        I64Attr:$dimension,
                        I64Attr:$target_size,
                        DefaultValuedAttr<I64Attr, "1">:$divisor);
-  let results = (outs TransformParamTypeOrAnyHandle:$low_size,
-                      TransformParamTypeOrAnyHandle:$high_size,
-                      TransformParamTypeOrAnyHandle:$split_point);
+  let results = (outs TransformAnyParamTypeOrAnyHandle:$low_size,
+                      TransformAnyParamTypeOrAnyHandle:$high_size,
+                      TransformAnyParamTypeOrAnyHandle:$split_point);
   let hasVerifier = 1;
   let assemblyFormat =
     "$target attr-dict `:` custom<MultitileSizesTypes>("
@@ -1408,7 +1408,7 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
 
   let arguments = (ins TransformHandleTypeInterface:$target,
                        I64Attr:$dimension,
-                       Optional<TransformParamTypeOrAnyHandle>:$dynamic_split_point,
+                       Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_split_point,
                        I64Attr:$static_split_point);
   let results = (outs TransformHandleTypeInterface:$first,
                       TransformHandleTypeInterface:$second);
@@ -1857,7 +1857,7 @@ def TileUsingForOp : Op<Transform_Dialect, "structured.tile_using_for",
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   Variadic<TransformParamTypeOrAnyHandle>:$dynamic_sizes,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$dynamic_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange,
                    DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
@@ -1968,10 +1968,10 @@ def TileUsingForallOp :
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   Variadic<TransformHandleTypeInterface>:$num_threads,
-                   Variadic<TransformHandleTypeInterface>:$tile_sizes,
-                   Optional<TransformHandleTypeInterface>:$packed_num_threads,
-                   Optional<TransformHandleTypeInterface>:$packed_tile_sizes,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$num_threads,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$tile_sizes,
+                   Optional<TransformAnyParamTypeOrAnyHandle>:$packed_num_threads,
+                   Optional<TransformAnyParamTypeOrAnyHandle>:$packed_tile_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_num_threads,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index df9e613e04aed35..6431bbd25396a52 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -86,8 +86,9 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
   return cast<LinalgOp>(result->getOperation());
 }
 
-/// Assuming that `ofr` is an index attr or a transform dialect handle mapped
-/// to exactly one op with one index result, return that value.
+/// Assuming that each `ofr` is an index attr, a param associated with exactly
+/// one integer attribute, or a transform dialect handle mapped to exactly one
+/// op with one index result, append that value to `result`.
 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
     transform::TransformState &state, TransformOpInterface transformOp,
     SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
@@ -98,12 +99,23 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
       result.push_back(ofr);
       continue;
     }
-    auto payloadOps = state.getPayloadOps(ofr.get<Value>());
+    Value transformValue = ofr.get<Value>();
+    if (isa<TransformParamTypeInterface>(transformValue.getType())) {
+      ArrayRef<Attribute> params = state.getParams(transformValue);
+      if (params.size() != 1 || !isa<IntegerAttr>(params[0]))
+        return transformOp.emitDefiniteFailure()
+               << "requires exactly one parameter associated with an "
+                  "integer attribute";
+      result.push_back(params[0]);
+      continue;
+    }
+
+    auto payloadOps = state.getPayloadOps(transformValue);
     if (!llvm::hasSingleElement(payloadOps)) {
       DiagnosedSilenceableFailure diag =
           transformOp.emitSilenceableError()
           << "handle must be mapped to exactly one payload op";
-      diag.attachNote(ofr.get<Value>().getLoc())
+      diag.attachNote(transformValue.getLoc())
           << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
       return diag;
     }
@@ -123,14 +135,27 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
   return DiagnosedSilenceableFailure::success();
 }
 
-// Given a list of OpFoldResults that are either index attrs or op
-// handles, return a list of OpFoldResults where all op handles are
-// replaced with the first (and only) OpResult of that payload op. (There
-// must be exactly one mapped payload op and it must have exactly one
-// index result.)
+// Given a `packedHandle` that is either a param whose associated attributes
+// must all be integer attrs, or an op handle whose payload ops must each have
+// exactly one index-typed result, append one OpFoldResult per element to
+// `result`: the integer attr itself for params, or the single OpResult of
+// each payload op for handles. Unlike the per-OpFoldResult overload, a packed
+// handle may expand to any number of values; offending elements are diagnosed.
 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
     transform::TransformState &state, TransformOpInterface transformOp,
     SmallVector<OpFoldResult> &result, Value packedHandle) {
+  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
+    ArrayRef<Attribute> params = state.getParams(packedHandle);
+    for (auto param : params) {
+      if (!isa<IntegerAttr>(param))
+        return transformOp.emitDefiniteFailure()
+               << "expected the parameter to be associated with an integer "
+                  "attribute";
+      result.push_back(param);
+    }
+    return DiagnosedSilenceableFailure::success();
+  }
+
   for (Operation *op : state.getPayloadOps(packedHandle)) {
     if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
       DiagnosedSilenceableFailure diag =
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index 2192d160b1150f7..abd807b3e4d3e19 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --transform-interpreter -canonicalize -cse -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s
 
 // Offset per thread:
 // CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
@@ -451,3 +451,138 @@ module attributes {transform.with_named_sequence} {
     }
   }
 
+// -----
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic(
+//  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  //      CHECK: %[[c1:.*]] = arith.constant 1 : index
+  //      CHECK: %[[c0:.*]] = arith.constant 0 : index
+  //      CHECK: %[[M:.+]] = tensor.dim %[[A]], %[[c0]] :
+  //      CHECK: %[[N:.+]] = tensor.dim %[[B]], %[[c1]] :
+  //      CHECK: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]]]
+  //      CHECK: %[[NT1:.+]] = affine.apply #[[$map1]]()[%[[N]]]
+  //      CHECK: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
+  //      CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+  //      CHECK:   %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+  //      CHECK:   %[[TS1:.+]] = affine.min #[[$map3]](%[[IV1]])[%[[N]]]
+  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map4]](%[[IV0]])
+  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map5]](%[[IV1]])
+  //      CHECK:   tensor.extract_slice %[[A]][%[[LB0]], 0] [%[[TS0]], %[[K]]] [1, 1] :
+  //      CHECK:   tensor.extract_slice %[[B]][0, %[[LB1]]] [%[[K]], %[[TS1]]] [1, 1] :
+  //      CHECK:   tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [%[[TS0]], %[[TS1]]] [1, 1] :
+  //      CHECK:   linalg.matmul
+  //      CHECK:   scf.forall.in_parallel
+  // CHECK-NEXT:    tensor.parallel_insert_slice
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                    outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %sz = transform.param.constant 10 : i64 -> !transform.param<i64>
+    %1:2 = transform.structured.tile_using_forall %0 tile_sizes [%sz : !transform.param<i64>, 20]
+           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                    outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %c10 = transform.param.constant 10 : i64 -> !transform.param<i64>
+    %c20 = transform.param.constant 20 : i64 -> !transform.param<i64>
+    %sz = transform.merge_handles %c10, %c20 : !transform.param<i64>
+    // expected-error @below {{requires exactly one parameter associated}}
+    %1:2 = transform.structured.tile_using_forall %0 tile_sizes [%sz : !transform.param<i64>, 20]
+           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic(
+//  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  //      CHECK: %[[c1:.*]] = arith.constant 1 : index
+  //      CHECK: %[[c0:.*]] = arith.constant 0 : index
+  //      CHECK: %[[M:.+]] = tensor.dim %[[A]], %[[c0]] :
+  //      CHECK: %[[N:.+]] = tensor.dim %[[B]], %[[c1]] :
+  //      CHECK: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]]]
+  //      CHECK: %[[NT1:.+]] = affine.apply #[[$map1]]()[%[[N]]]
+  //      CHECK: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
+  //      CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+  //      CHECK:   %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+  //      CHECK:   %[[TS1:.+]] = affine.min #[[$map3]](%[[IV1]])[%[[N]]]
+  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map4]](%[[IV0]])
+  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map5]](%[[IV1]])
+  //      CHECK:   tensor.extract_slice %[[A]][%[[LB0]], 0] [%[[TS0]], %[[K]]] [1, 1] :
+  //      CHECK:   tensor.extract_slice %[[B]][0, %[[LB1]]] [%[[K]], %[[TS1]]] [1, 1] :
+  //      CHECK:   tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [%[[TS0]], %[[TS1]]] [1, 1] :
+  //      CHECK:   linalg.matmul
+  //      CHECK:   scf.forall.in_parallel
+  // CHECK-NEXT:    tensor.parallel_insert_slice
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                    outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %c10 = transform.param.constant 10 : i64 -> !transform.any_param
+    %c20 = transform.param.constant 20 : i64 -> !transform.any_param
+    %sz = transform.merge_handles %c10, %c20 : !transform.any_param
+    %1:2 = transform.structured.tile_using_forall %0 tile_sizes *(%sz : !transform.any_param)
+           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                    outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %sz = transform.param.constant "[10 : i64, 20 : i64]" -> !transform.any_param
+    // expected-error @below {{expected the parameter to be associated with an integer attribute}}
+    %1:2 = transform.structured.tile_using_forall %0 tile_sizes *(%sz : !transform.any_param)
+           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list