[Mlir-commits] [mlir] d93be48 - [mlir][transform] Make `tile_to_foreach_thread_op` builder to use ArrayAttr

Guray Ozen llvmlistbot at llvm.org
Sat Nov 12 10:27:32 PST 2022


Author: Guray Ozen
Date: 2022-11-12T19:27:25+01:00
New Revision: d93be483eaf5e22f4192325f9357821cbd2e934e

URL: https://github.com/llvm/llvm-project/commit/d93be483eaf5e22f4192325f9357821cbd2e934e
DIFF: https://github.com/llvm/llvm-project/commit/d93be483eaf5e22f4192325f9357821cbd2e934e.diff

LOG: [mlir][transform] Make `tile_to_foreach_thread_op` builders use ArrayAttr

D137413 clarified the thread mapping of `scf_foreach_thread`. `tile_to_foreach_thread_op` is one of the ops that generate `scf_foreach_thread`; however, its builders still take the mapping as an integer array.

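This fixes a potential bug. Previously, the builders converted the integer array into an `I64ArrayAttr` (`builder.getI64ArrayAttr(mapping)`). They now accept the mapping as an `ArrayAttr` and forward it unchanged, where a null `ArrayAttr` still means "no mapping".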

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D137891

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b8638f10de98b..b92ed19863069 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -842,22 +842,22 @@ def TileToForeachThreadOp :
                    "ArrayRef<int64_t>":$staticTileSizes,
                    CArg<"::mlir::transform::TileSizesSpec", 
                         "::mlir::transform::TileSizesSpec()">,
-                   CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
+                   CArg<"ArrayAttr", "{}">:$mapping)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedTileSizes,
                    CArg<"::mlir::transform::TileSizesSpec", 
                         "::mlir::transform::TileSizesSpec()">,
-                   CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
+                   CArg<"ArrayAttr", "{}">:$mapping)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<int64_t>":$staticNumThreads,
                    CArg<"::mlir::transform::NumThreadsSpec", 
                         "::mlir::transform::NumThreadsSpec()">,
-                   CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
+                   CArg<"ArrayAttr", "{}">:$mapping)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedNumThreads,
                    CArg<"::mlir::transform::NumThreadsSpec", 
                         "::mlir::transform::NumThreadsSpec()">,
-                   CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
+                   CArg<"ArrayAttr", "{}">:$mapping)>,
   ];
 
   let assemblyFormat = [{

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7b720a7452a5a..cdd4e158f9b1c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1326,7 +1326,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder,
                                              Value target,
                                              ArrayRef<int64_t> staticTileSizes,
                                              transform::TileSizesSpec,
-                                             ArrayRef<int64_t> mapping) {
+                                             ArrayAttr mapping) {
   return build(builder, result, target,
                getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
                TileSizesSpec(), mapping);
@@ -1335,7 +1335,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder,
 void transform::TileToForeachThreadOp::build(
     OpBuilder &builder, OperationState &result, Value target,
     ArrayRef<OpFoldResult> mixedTileSizes, transform::TileSizesSpec,
-    ArrayRef<int64_t> mapping) {
+    ArrayAttr mapping) {
   SmallVector<int64_t> staticTileSizes;
   SmallVector<Value> dynamicTileSizes;
   dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes,
@@ -1346,12 +1346,9 @@ void transform::TileToForeachThreadOp::build(
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
   auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
-  ArrayAttr mappingAttr;
-  if (!mapping.empty())
-    mappingAttr = builder.getI64ArrayAttr(mapping);
   build(builder, result, TypeRange{operationType, operationType}, target,
         /*numThreads=*/ValueRange{}, dynamicTileSizes,
-        /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mappingAttr);
+        /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mapping);
 }
 
 void transform::TileToForeachThreadOp::build(OpBuilder &builder,
@@ -1359,7 +1356,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder,
                                              Value target,
                                              ArrayRef<int64_t> staticNumThreads,
                                              transform::NumThreadsSpec,
-                                             ArrayRef<int64_t> mapping) {
+                                             ArrayAttr mapping) {
   return build(builder, result, target,
                getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
                NumThreadsSpec(), mapping);
@@ -1368,7 +1365,7 @@ void transform::TileToForeachThreadOp::build(OpBuilder &builder,
 void transform::TileToForeachThreadOp::build(
     OpBuilder &builder, OperationState &result, Value target,
     ArrayRef<OpFoldResult> mixedNumThreads, transform::NumThreadsSpec,
-    ArrayRef<int64_t> mapping) {
+    ArrayAttr mapping) {
   SmallVector<int64_t> staticNumThreads;
   SmallVector<Value> dynamicNumThreads;
   dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
@@ -1379,12 +1376,9 @@ void transform::TileToForeachThreadOp::build(
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
   auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads);
-  ArrayAttr mappingAttr;
-  if (!mapping.empty())
-    mappingAttr = builder.getI64ArrayAttr(mapping);
   build(builder, result, TypeRange{operationType, operationType}, target,
         dynamicNumThreads, /*tileSizes=*/ValueRange{}, staticNumThreadsAttr,
-        /*staticTileSizes=*/ArrayAttr(), mappingAttr);
+        /*staticTileSizes=*/ArrayAttr(), mapping);
 }
 
 DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(


        


More information about the Mlir-commits mailing list