[Mlir-commits] [mlir] [MLIR] Move warp_execute_on_lane_0 from vector to gpu (PR #116994)

llvmlistbot at llvm.org
Wed Nov 20 07:45:46 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-vector

Author: Petr Kurapov (kurapov-peter)

<details>
<summary>Changes</summary>

Please see the related RFC here: https://discourse.llvm.org/t/rfc-move-execute-on-lane-0-from-vector-to-gpu-dialect/82989.

This patch does exactly one thing: it moves `warp_execute_on_lane_0` from the vector dialect to the gpu dialect.

---

Patch is 137.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116994.diff


15 Files Affected:

- (modified) mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (+138) 
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (-133) 
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h (+9-8) 
- (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+182) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (-182) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+50-48) 
- (modified) mlir/test/Conversion/GPUCommon/transfer_write.mlir (+1-1) 
- (modified) mlir/test/Dialect/GPU/invalid.mlir (+86) 
- (modified) mlir/test/Dialect/GPU/ops.mlir (+36) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (-86) 
- (modified) mlir/test/Dialect/Vector/ops.mlir (-35) 
- (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+228-228) 
- (modified) mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-reduction-distribute.mlir (+1-1) 
- (modified) mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir (+1-1) 
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+6-5) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 6098eb34d04d52..5b1d7bb87a219a 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1097,6 +1097,10 @@ def GPU_YieldOp : GPU_Op<"yield", [Pure, ReturnLike, Terminator]>,
     ```
   }];
 
+  let builders = [
+    OpBuilder<(ins), [{ /* nothing to do */ }]>
+  ];
+
   let assemblyFormat = "attr-dict ($values^ `:` type($values))?";
 }
 
@@ -2921,4 +2925,138 @@ def GPU_SetCsrPointersOp : GPU_Op<"set_csr_pointers", [GPU_AsyncOpInterface]> {
   }];
 }
 
+def GPU_WarpExecuteOnLane0Op : GPU_Op<"warp_execute_on_lane_0",
+      [DeclareOpInterfaceMethods<RegionBranchOpInterface, ["areTypesCompatible"]>,
+       SingleBlockImplicitTerminator<"gpu::YieldOp">,
+       RecursiveMemoryEffects]> {
+  let summary = "Executes operations in the associated region on thread #0 of a"
+                "SPMD program";
+  let description = [{
+    `warp_execute_on_lane_0` is an operation used to bridge the gap between
+    vector programming and SPMD programming model like GPU SIMT. It allows to
+    trivially convert a region of vector code meant to run on a multiple threads
+    into a valid SPMD region and then allows incremental transformation to
+    distribute vector operations on the threads.
+
+    Any code present in the region would only be executed on first thread/lane
+    based on the `laneid` operand. The `laneid` operand is an integer ID between
+    [0, `warp_size`). The `warp_size` attribute indicates the number of lanes in
+    a warp.
+
+    Operands are vector values distributed on all lanes that may be used by
+    the single lane execution. The matching region argument is a vector of all
+    the values of those lanes available to the single active lane. The
+    distributed dimension is implicit based on the shape of the operand and
+    argument. the properties of the distribution may be described by extra
+    attributes (e.g. affine map).
+
+    Return values are distributed on all lanes using laneId as index. The
+    vector is distributed based on the shape ratio between the vector type of
+    the yield and the result type.
+    If the shapes are the same this means the value is broadcasted to all lanes.
+    In the future the distribution can be made more explicit using affine_maps
+    and will support having multiple Ids.
+
+    Therefore the `warp_execute_on_lane_0` operations allow to implicitly copy
+    between lane0 and the lanes of the warp. When distributing a vector
+    from lane0 to all the lanes, the data are distributed in a block cyclic way.
+    For example `vector<64xf32>` gets distributed on 32 threads and map to
+    `vector<2xf32>` where thread 0 contains vector[0] and vector[1].
+
+    During lowering values passed as operands and return value need to be
+    visible to different lanes within the warp. This would usually be done by
+    going through memory.
+
+    The region is *not* isolated from above. For values coming from the parent
+    region not going through operands only the lane 0 value will be accesible so
+    it generally only make sense for uniform values.
+
+    Example:
+    ```
+    // Execute in parallel on all threads/lanes.
+    gpu.warp_execute_on_lane_0 (%laneid)[32] {
+      // Serial code running only on thread/lane 0.
+      ...
+    }
+    // Execute in parallel on all threads/lanes.
+    ```
+
+    This may be lowered to an scf.if region as below:
+    ```
+      // Execute in parallel on all threads/lanes.
+      %cnd = arith.cmpi eq, %laneid, %c0 : index
+      scf.if %cnd {
+        // Serial code running only on thread/lane 0.
+        ...
+      }
+      // Execute in parallel on all threads/lanes.
+    ```
+
+    When the region has operands and/or return values:
+    ```
+    // Execute in parallel on all threads/lanes.
+    %0 = gpu.warp_execute_on_lane_0(%laneid)[32]
+    args(%v0 : vector<4xi32>) -> (vector<1xf32>) {
+    ^bb0(%arg0 : vector<128xi32>) :
+      // Serial code running only on thread/lane 0.
+      ...
+      gpu.yield %1 : vector<32xf32>
+    }
+    // Execute in parallel on all threads/lanes.
+    ```
+
+    values at the region boundary would go through memory:
+    ```
+    // Execute in parallel on all threads/lanes.
+    ...
+    // Store the data from each thread into memory and Synchronization.
+    %tmp0 = memreg.alloc() : memref<128xf32>
+    %tmp1 = memreg.alloc() : memref<32xf32>
+    %cnd = arith.cmpi eq, %laneid, %c0 : index
+    vector.store %v0, %tmp0[%laneid] : memref<128xf32>, vector<4xf32>
+    some_synchronization_primitive
+    scf.if %cnd {
+      // Serialized code running only on thread 0.
+      // Load the data from all the threads into a register from thread 0. This
+      // allow threads 0 to access data from all the threads.
+      %arg0 = vector.load %tmp0[%c0] : memref<128xf32>, vector<128xf32>
+      ...
+      // Store the data from thread 0 into memory.
+      vector.store %1, %tmp1[%c0] : memref<32xf32>, vector<32xf32>
+    }
+    // Synchronization and load the data in a block cyclic way so that the
+    // vector is distributed on all threads.
+    some_synchronization_primitive
+    %0 = vector.load %tmp1[%laneid] : memref<32xf32>, vector<32xf32>
+    // Execute in parallel on all threads/lanes.
+    ```
+
+  }];
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+  let arguments = (ins Index:$laneid, I64Attr:$warp_size,
+                       Variadic<AnyType>:$args);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$warpRegion);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$laneid,
+                   "int64_t":$warpSize)>,
+    // `blockArgTypes` are different than `args` types as they are they
+    // represent all the `args` instances visibile to lane 0. Therefore we need
+    // to explicit pass the type.
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$laneid,
+                   "int64_t":$warpSize, "ValueRange":$args,
+                   "TypeRange":$blockArgTypes)>
+  ];
+
+  let extraClassDeclaration = [{
+    bool isDefinedOutsideOfRegion(Value value) {
+      return !getRegion().isAncestor(value.getParentRegion());
+    }
+  }];
+}
+
 #endif // GPU_OPS
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index c5b08d6aa022b1..d0f11acb448355 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2983,138 +2983,5 @@ def Vector_YieldOp : Vector_Op<"yield", [
   let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
 }
 
-def Vector_WarpExecuteOnLane0Op : Vector_Op<"warp_execute_on_lane_0",
-      [DeclareOpInterfaceMethods<RegionBranchOpInterface, ["areTypesCompatible"]>,
-       SingleBlockImplicitTerminator<"vector::YieldOp">,
-       RecursiveMemoryEffects]> {
-  let summary = "Executes operations in the associated region on thread #0 of a"
-                "SPMD program";
-  let description = [{
-    `warp_execute_on_lane_0` is an operation used to bridge the gap between
-    vector programming and SPMD programming model like GPU SIMT. It allows to
-    trivially convert a region of vector code meant to run on a multiple threads
-    into a valid SPMD region and then allows incremental transformation to
-    distribute vector operations on the threads.
-
-    Any code present in the region would only be executed on first thread/lane
-    based on the `laneid` operand. The `laneid` operand is an integer ID between
-    [0, `warp_size`). The `warp_size` attribute indicates the number of lanes in
-    a warp.
-
-    Operands are vector values distributed on all lanes that may be used by
-    the single lane execution. The matching region argument is a vector of all
-    the values of those lanes available to the single active lane. The
-    distributed dimension is implicit based on the shape of the operand and
-    argument. the properties of the distribution may be described by extra
-    attributes (e.g. affine map).
-
-    Return values are distributed on all lanes using laneId as index. The
-    vector is distributed based on the shape ratio between the vector type of
-    the yield and the result type.
-    If the shapes are the same this means the value is broadcasted to all lanes.
-    In the future the distribution can be made more explicit using affine_maps
-    and will support having multiple Ids.
-
-    Therefore the `warp_execute_on_lane_0` operations allow to implicitly copy
-    between lane0 and the lanes of the warp. When distributing a vector
-    from lane0 to all the lanes, the data are distributed in a block cyclic way.
-    For exemple `vector<64xf32>` gets distributed on 32 threads and map to
-    `vector<2xf32>` where thread 0 contains vector[0] and vector[1].
-
-    During lowering values passed as operands and return value need to be
-    visible to different lanes within the warp. This would usually be done by
-    going through memory.
-
-    The region is *not* isolated from above. For values coming from the parent
-    region not going through operands only the lane 0 value will be accesible so
-    it generally only make sense for uniform values.
-
-    Example:
-    ```
-    // Execute in parallel on all threads/lanes.
-    vector.warp_execute_on_lane_0 (%laneid)[32] {
-      // Serial code running only on thread/lane 0.
-      ...
-    }
-    // Execute in parallel on all threads/lanes.
-    ```
-
-    This may be lowered to an scf.if region as below:
-    ```
-      // Execute in parallel on all threads/lanes.
-      %cnd = arith.cmpi eq, %laneid, %c0 : index
-      scf.if %cnd {
-        // Serial code running only on thread/lane 0.
-        ...
-      }
-      // Execute in parallel on all threads/lanes.
-    ```
-
-    When the region has operands and/or return values:
-    ```
-    // Execute in parallel on all threads/lanes.
-    %0 = vector.warp_execute_on_lane_0(%laneid)[32]
-    args(%v0 : vector<4xi32>) -> (vector<1xf32>) {
-    ^bb0(%arg0 : vector<128xi32>) :
-      // Serial code running only on thread/lane 0.
-      ...
-      vector.yield %1 : vector<32xf32>
-    }
-    // Execute in parallel on all threads/lanes.
-    ```
-
-    values at the region boundary would go through memory:
-    ```
-    // Execute in parallel on all threads/lanes.
-    ...
-    // Store the data from each thread into memory and Synchronization.
-    %tmp0 = memreg.alloc() : memref<128xf32>
-    %tmp1 = memreg.alloc() : memref<32xf32>
-    %cnd = arith.cmpi eq, %laneid, %c0 : index
-    vector.store %v0, %tmp0[%laneid] : memref<128xf32>, vector<4xf32>
-    some_synchronization_primitive
-    scf.if %cnd {
-      // Serialized code running only on thread 0.
-      // Load the data from all the threads into a register from thread 0. This
-      // allow threads 0 to access data from all the threads.
-      %arg0 = vector.load %tmp0[%c0] : memref<128xf32>, vector<128xf32>
-      ...
-      // Store the data from thread 0 into memory.
-      vector.store %1, %tmp1[%c0] : memref<32xf32>, vector<32xf32>
-    }
-    // Synchronization and load the data in a block cyclic way so that the
-    // vector is distributed on all threads.
-    some_synchronization_primitive
-    %0 = vector.load %tmp1[%laneid] : memref<32xf32>, vector<32xf32>
-    // Execute in parallel on all threads/lanes.
-    ```
-
-  }];
-
-  let hasVerifier = 1;
-  let hasCustomAssemblyFormat = 1;
-  let arguments = (ins Index:$laneid, I64Attr:$warp_size,
-                       Variadic<AnyType>:$args);
-  let results = (outs Variadic<AnyType>:$results);
-  let regions = (region SizedRegion<1>:$warpRegion);
-
-  let skipDefaultBuilders = 1;
-  let builders = [
-    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$laneid,
-                   "int64_t":$warpSize)>,
-    // `blockArgTypes` are different than `args` types as they are they
-    // represent all the `args` instances visibile to lane 0. Therefore we need
-    // to explicit pass the type.
-    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$laneid,
-                   "int64_t":$warpSize, "ValueRange":$args,
-                   "TypeRange":$blockArgTypes)>
-  ];
-
-  let extraClassDeclaration = [{
-    bool isDefinedOutsideOfRegion(Value value) {
-      return !getRegion().isAncestor(value.getParentRegion());
-    }
-  }];
-}
 
 #endif // MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index 8907a2a583609a..dda45219b2acc2 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
 #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
 
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 
 namespace mlir {
@@ -23,15 +24,15 @@ struct WarpExecuteOnLane0LoweringOptions {
   /// type may be VectorType or a scalar) and be availble for the current warp.
   /// If there are several warps running in parallel the allocation needs to be
   /// split so that each warp has its own allocation.
-  using WarpAllocationFn =
-      std::function<Value(Location, OpBuilder &, WarpExecuteOnLane0Op, Type)>;
+  using WarpAllocationFn = std::function<Value(
+      Location, OpBuilder &, gpu::WarpExecuteOnLane0Op, Type)>;
   WarpAllocationFn warpAllocationFn = nullptr;
 
   /// Lamdba function to let user emit operation to syncronize all the thread
   /// within a warp. After this operation all the threads can see any memory
   /// written before the operation.
   using WarpSyncronizationFn =
-      std::function<void(Location, OpBuilder &, WarpExecuteOnLane0Op)>;
+      std::function<void(Location, OpBuilder &, gpu::WarpExecuteOnLane0Op)>;
   WarpSyncronizationFn warpSyncronizationFn = nullptr;
 };
 
@@ -48,17 +49,17 @@ using DistributionMapFn = std::function<AffineMap(Value)>;
 ///
 /// Example:
 /// ```
-/// %0 = vector.warp_execute_on_lane_0(%id){
+/// %0 = gpu.warp_execute_on_lane_0(%id){
 ///   ...
 ///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
-///   vector.yield
+///   gpu.yield
 /// }
 /// ```
 /// To
 /// ```
-/// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
+/// %r:3 = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
 ///   ...
-///   vector.yield %v : vector<32xf32>
+///   gpu.yield %v : vector<32xf32>
 /// }
 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
 ///
@@ -73,7 +74,7 @@ void populateDistributeTransferWriteOpPatterns(
 
 /// Move scalar operations with no dependency on the warp op outside of the
 /// region.
-void moveScalarUniformCode(WarpExecuteOnLane0Op op);
+void moveScalarUniformCode(gpu::WarpExecuteOnLane0Op op);
 
 /// Lambda signature to compute a warp shuffle of a given value of a given lane
 /// within a given warp size.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 956877497d9338..f019007faede8d 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -36,6 +36,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/StringSaver.h"
 #include <cassert>
+#include <numeric>
 
 using namespace mlir;
 using namespace mlir::gpu;
@@ -2188,6 +2189,187 @@ LogicalResult gpu::DynamicSharedMemoryOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GPU WarpExecuteOnLane0Op
+//===----------------------------------------------------------------------===//
+
+void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
+  p << "(" << getLaneid() << ")";
+
+  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
+  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
+  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
+
+  if (!getArgs().empty())
+    p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
+  if (!getResults().empty())
+    p << " -> (" << getResults().getTypes() << ')';
+  p << " ";
+  p.printRegion(getRegion(),
+                /*printEntryBlockArgs=*/true,
+                /*printBlockTerminators=*/!getResults().empty());
+  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
+}
+
+ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
+                                        OperationState &result) {
+  // Create the region.
+  result.regions.reserve(1);
+  Region *warpRegion = result.addRegion();
+
+  auto &builder = parser.getBuilder();
+  OpAsmParser::UnresolvedOperand laneId;
+
+  // Parse predicate operand.
+  if (parser.parseLParen() ||
+      parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
+      parser.parseRParen())
+    return failure();
+
+  int64_t warpSize;
+  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
+      parser.parseRSquare())
+    return failure();
+  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
+                                                        builder.getContext())),
+                      builder.getI64IntegerAttr(warpSize));
+
+  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
+    return failure();
+
+  llvm::SMLoc inputsOperandsLoc;
+  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
+  SmallVector<Type> inputTypes;
+  if (succeeded(parser.parseOptionalKeyword("args"))) {
+    if (parser.parseLParen())
+      return failure();
+
+    inputsOperandsLoc = parser.getCurrentLocation();
+    if (parser.parseOperandList(inputsOperands) ||
+        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
+      return failure();
+  }
+  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
+                             result.operands))
+    return failure();
+
+  // Parse optional results type list.
+  if (parser.parseOptionalArrowTypeList(result.types))
+    return failure();
+  // Parse the region.
+  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
+                         /*argTypes=*/{}))
+    return failure();
+  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  return success();
+}
+
+void WarpExecuteOnLane0Op::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  if (!point.isParent()) {
+    regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  // The warp region is always executed
+  regions.push_back(RegionSuccessor(&getWarpRegion()));
+}
+
+void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
+                                 TypeRange resultTypes, Value laneId,
+                                 int64_t warpSize) {
+  build(builder, result, resultTypes, laneId, warpSize,
+        /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
+}
+
+void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
+                                 TypeRange resultTypes, Value laneId,
+                                 int64_t warpSize, ValueRange args,
+                                 TypeRange blockArgTypes) {
+  result.addOperands(laneId);
+  result.addAttribute(getAttributeNames()[0],
+                      builder.getI64IntegerAttr(warpSize));
+  result.addTypes(resultTypes);
+  result.addOperands(args);
+  assert(args.size() == blockArgTypes.size());
+  OpBuilder::InsertionGuard guard(builder);
+  Region *warpRegion = result.addRegion();
+  Block *block = builder.createBlock(warpRegion);
+  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
+    block->addArgument(type, arg.getLoc());
+}
+
+/// Helper check if the distributed vector type is consistent with the expanded
+/// type and distributed size.
+static LogicalResult verifyDistributedType(Type expanded, Type distributed,
+                                           int64_t warpSize, Operation *op) {
+  // If the types matches there is no distribution.
+  if (exp...
[truncated]

``````````
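
The `warp_execute_on_lane_0` description in the diff says that a `vector<64xf32>` distributed over 32 lanes becomes a `vector<2xf32>` per lane, with lane 0 holding elements 0 and 1. The following standalone C++ sketch models that mapping for 1-D vectors only. The helpers `distributedSize` and `laneSlice` are hypothetical illustrations, not MLIR APIs, and the sketch covers only the case shown in the example, where each lane receives a single contiguous block; layouts with several blocks per lane are not modeled.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not an MLIR API): number of elements each lane holds
// when a 1-D vector of `size` elements is distributed over `warpSize` lanes.
int64_t distributedSize(int64_t size, int64_t warpSize) {
  if (warpSize <= 0 || size % warpSize != 0)
    throw std::invalid_argument("size must be a positive multiple of warpSize");
  return size / warpSize;
}

// Hypothetical helper: returns the slice of `full` owned by lane `laneId`,
// i.e. the contiguous block [laneId * k, (laneId + 1) * k) with
// k = full.size() / warpSize. Only the one-block-per-lane case is modeled.
template <typename T>
std::vector<T> laneSlice(const std::vector<T> &full, int64_t laneId,
                         int64_t warpSize) {
  int64_t k = distributedSize(static_cast<int64_t>(full.size()), warpSize);
  if (laneId < 0 || laneId >= warpSize)
    throw std::out_of_range("laneId must be in [0, warpSize)");
  auto first = full.begin() + laneId * k;
  return std::vector<T>(first, first + k);
}
```

For the 64-element case with a warp size of 32, `laneSlice(v, 0, 32)` returns elements 0 and 1, and lane 31 receives elements 62 and 63.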

</details>
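
The op description also states that a result is distributed according to the shape ratio between the yielded vector type and the result type, and that equal shapes mean the value is broadcast to all lanes. The body of `verifyDistributedType` is truncated in the excerpt above, so the following is not that function. It is a hypothetical, MLIR-free model of such a check, assuming shapes are passed as `std::vector<int64_t>` and that the per-dimension ratios must multiply to the warp size; element types are ignored.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical model (not the MLIR verifier): `expanded` is the shape seen by
// lane 0 inside the region, `distributed` is the per-lane shape outside it.
bool isValidDistribution(const std::vector<int64_t> &expanded,
                         const std::vector<int64_t> &distributed,
                         int64_t warpSize) {
  // Identical shapes mean the value is broadcast to all lanes.
  if (expanded == distributed)
    return true;
  // Ranks must agree for a per-dimension ratio to be defined.
  if (expanded.size() != distributed.size())
    return false;
  std::vector<int64_t> scales(expanded.size(), 1);
  for (size_t i = 0; i < expanded.size(); ++i) {
    // Each expanded dimension must be an exact multiple of the distributed one.
    if (distributed[i] <= 0 || expanded[i] % distributed[i] != 0)
      return false;
    scales[i] = expanded[i] / distributed[i];
  }
  // Assumption: the product of the ratios must account for every lane.
  int64_t total = std::accumulate(scales.begin(), scales.end(), int64_t{1},
                                  std::multiplies<int64_t>());
  return total == warpSize;
}
```

With a warp size of 32, the shape pairs from the description's example (`vector<128xi32>` to `vector<4xi32>`, and `vector<32xf32>` to `vector<1xf32>`) both pass, as does a 2-D case such as 4x64 to 1x8, because 4 * 8 = 32.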


https://github.com/llvm/llvm-project/pull/116994
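
The op description explains that, during lowering, values crossing the region boundary usually go through memory: each lane stores its slice, lane 0 runs the serial body on the full vector, and every lane then reads back its share of the result. The following sequential C++ sketch simulates those three phases on the host. `simulateWarpRoundTrip` is a hypothetical name, the buffers stand in for the `tmp0`/`tmp1` allocations in the pseudo-IR, and synchronization is implicit because each phase finishes before the next begins. In this model each lane's store offset is its lane id scaled by the number of elements per lane, so lane slices do not overlap.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical host-side simulation of the memory-based lowering: each "lane"
// is a loop iteration, and `tmp0`/`tmp1` model the shared allocations.
std::vector<std::vector<float>> simulateWarpRoundTrip(
    const std::vector<std::vector<float>> &laneInputs, int64_t warpSize,
    const std::function<std::vector<float>(const std::vector<float> &)>
        &lane0Body) {
  assert(static_cast<int64_t>(laneInputs.size()) == warpSize);
  int64_t inPerLane = static_cast<int64_t>(laneInputs[0].size());

  // Phase 1: every lane stores its slice at offset laneId * inPerLane.
  std::vector<float> tmp0(warpSize * inPerLane);
  for (int64_t lane = 0; lane < warpSize; ++lane) {
    assert(static_cast<int64_t>(laneInputs[lane].size()) == inPerLane);
    for (int64_t j = 0; j < inPerLane; ++j)
      tmp0[lane * inPerLane + j] = laneInputs[lane][j];
  }
  // (synchronization point)

  // Phase 2: only lane 0 runs the region body on the full vector.
  std::vector<float> tmp1 = lane0Body(tmp0);
  assert(static_cast<int64_t>(tmp1.size()) % warpSize == 0);
  int64_t outPerLane = static_cast<int64_t>(tmp1.size()) / warpSize;
  // (synchronization point)

  // Phase 3: every lane loads its own contiguous slice of the result.
  std::vector<std::vector<float>> laneOutputs(warpSize);
  for (int64_t lane = 0; lane < warpSize; ++lane)
    laneOutputs[lane].assign(tmp1.begin() + lane * outPerLane,
                             tmp1.begin() + (lane + 1) * outPerLane);
  return laneOutputs;
}
```

With 32 lanes holding 4 values each and a body that sums each group of four, the shapes mirror the description's `vector<128>` to `vector<32>` example: every lane gets back a single value, the sum of its own inputs.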

