[Mlir-commits] [mlir] a0f843f - [SCF] Add thread_dim_mapping attribute to scf.foreach_thread

Nicolas Vasilache llvmlistbot at llvm.org
Mon Jun 27 04:59:13 PDT 2022


Author: Nicolas Vasilache
Date: 2022-06-27T04:58:36-07:00
New Revision: a0f843fdafa71a8f12095afca12c8964954ffab6

URL: https://github.com/llvm/llvm-project/commit/a0f843fdafa71a8f12095afca12c8964954ffab6
DIFF: https://github.com/llvm/llvm-project/commit/a0f843fdafa71a8f12095afca12c8964954ffab6.diff

LOG: [SCF] Add thread_dim_mapping attribute to scf.foreach_thread

An optional thread_dim_mapping index array attribute specifies, for each
virtual thread dimension, how it remaps 1-1 to a set of concrete processing
element resources (e.g. a CUDA grid dimension or a level of concrete nested
async parallelism). At this time, the specification is backend-dependent and
is not verified by the op, beyond being an index array attribute.
It is the responsibility of the lowering to interpret the index array in the
context of the concrete target the op is lowered to, or to ignore it when
the specification is ill-formed or unsupported for a particular target.

Differential Revision: https://reviews.llvm.org/D128633

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
    mlir/test/Dialect/SCF/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index b36f6b7a1dba6..cde966592212d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -339,6 +339,15 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     application per thread. Further lowerings are responsible for specifying
     how this is materialized on concrete hardware resources.
 
+    An optional thread_dim_mapping index array attribute specifies, for each
+    virtual thread dimension, how it remaps 1-1 to a set of concrete processing
+    element resources (e.g. a CUDA grid dimension or a level of concrete nested
+    async parallelism). At this time, the specification is backend-dependent and
+    is not verified by the op, beyond being an index array attribute.
+    It is the responsibility of the lowering to interpret the index array in the
+    context of the concrete target the op is lowered to, or to ignore it when
+    the specification is ill-formed or unsupported for a particular target.
+
     The only allowed terminator is `scf.foreach_thread.perform_concurrently`,
     which dictates how the partial results of all parallel invocations should be
     reconciled into a full value.
@@ -398,8 +407,27 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     // Sequential context.
     //
     ```
+
+    Example with thread_dim_mapping attribute:
+    //
+    // Sequential context.
+    //
+    %matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in
+         (%num_threads_1, %num_threads_2) -> (tensor<?x?xT>, tensor<?xT>) {
+      //
+      // Parallel context, each thread with id = **(%thread_id_2, %thread_id_1)**
+      // runs its version of the code.
+      //
+      scf.foreach_thread.perform_concurrently {
+        ...
+      }
+    } { thread_dim_mapping = [1, 0] }
+    // Implicit synchronization point.
+    // Sequential context.
+    //
   }];
-  let arguments = (ins Variadic<Index>:$num_threads);
+  let arguments = (ins Variadic<Index>:$num_threads,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$thread_dim_mapping);
 
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
@@ -411,11 +439,13 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
   let skipDefaultBuilders = 1;
   let builders = [
     // Bodyless builder, result types must be specified.
-    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads)>,
+    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads,
+                   CArg<"ArrayRef<int64_t>", "{}">:$thread_dim_mapping)>,
     // Builder that takes a bodyBuilder lambda, result types are inferred from
     // the terminator.
     OpBuilder<(ins "ValueRange":$num_threads,
-              "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
+                   "ArrayRef<int64_t>":$thread_dim_mapping,
+                   "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
   ];
   let extraClassDeclaration = [{
     int64_t getRank() { return getNumThreads().size(); }

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 73eca29ec9270..bd0f16dbd0e07 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1135,8 +1135,12 @@ ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
 // Bodyless builder, result types must be specified.
 void ForeachThreadOp::build(mlir::OpBuilder &builder,
                             mlir::OperationState &result, TypeRange resultTypes,
-                            ValueRange numThreads) {
+                            ValueRange numThreads,
+                            ArrayRef<int64_t> threadDimMapping) {
   result.addOperands(numThreads);
+  result.addAttribute(
+      // TODO: getThreadDimMappingAttrName() but it is not a static member.
+      "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
 
   Region *bodyRegion = result.addRegion();
   OpBuilder::InsertionGuard g(builder);
@@ -1156,9 +1160,12 @@ void ForeachThreadOp::build(mlir::OpBuilder &builder,
 // the terminator.
 void ForeachThreadOp::build(
     mlir::OpBuilder &builder, mlir::OperationState &result,
-    ValueRange numThreads,
+    ValueRange numThreads, ArrayRef<int64_t> threadDimMapping,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
   result.addOperands(numThreads);
+  result.addAttribute(
+      // TODO: getThreadDimMappingAttrName() but it is not a static member.
+      "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
 
   OpBuilder::InsertionGuard g(builder);
   Region *bodyRegion = result.addRegion();

diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 16ef7b5cad130..36d0f0cefbae3 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -999,7 +1000,8 @@ struct ForeachThreadOpInterface
     TypeRange newResultTypes;
     auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
         foreachThreadOp.getLoc(), newResultTypes,
-        foreachThreadOp.getNumThreads());
+        foreachThreadOp.getNumThreads(),
+        extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping()));
     newForeachThreadOp.getBody()->getTerminator()->erase();
 
     // Move over block contents of the old op.

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
index 63d5d88ba0317..365195bc7896b 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
@@ -130,6 +130,7 @@ func.func @scf_foreach_thread_out_of_place(%in: tensor<100xf32>,
         scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
           tensor<1xf32> into tensor<100xf32>
       }
-  }
+  // CHECK: } {thread_dim_mapping = [5]}
+  } {thread_dim_mapping = [5]}
   return
 }

diff  --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index 3ba3c04deb152..294017aef6225 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -338,11 +338,11 @@ func.func @elide_terminator() -> () {
   %num_threads = arith.constant 100 : index
 
   //      CHECK:    scf.foreach_thread
-  // CHECK-NEXT:  }
+  // CHECK-NEXT:  } {thread_dim_mapping = [42]}
   // CHECK-NEXT:  return
   scf.foreach_thread (%thread_idx) in (%num_threads) -> () {
     scf.foreach_thread.perform_concurrently {
     }
-  }
+  } {thread_dim_mapping = [42]}
   return
 }


        


More information about the Mlir-commits mailing list