[Mlir-commits] [mlir] 0660f3c - [mlir][vector] Relax reduction distribution pattern

Thomas Raoux llvmlistbot at llvm.org
Sat Jul 9 11:47:16 PDT 2022

Author: Thomas Raoux
Date: 2022-07-09T18:36:39Z
New Revision: 0660f3c5a0a083cef78f39b3105a9a8e27cf5095

URL: https://github.com/llvm/llvm-project/commit/0660f3c5a0a083cef78f39b3105a9a8e27cf5095
DIFF: https://github.com/llvm/llvm-project/commit/0660f3c5a0a083cef78f39b3105a9a8e27cf5095.diff

LOG: [mlir][vector] Relax reduction distribution pattern

Support distributing reductions whose vector size is a multiple of the warp size.

Differential Revision: https://reviews.llvm.org/D129387




diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index bf6e2225e6f20..2b9635835d7b1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -798,7 +798,7 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
       return rewriter.notifyMatchFailure(
           warpOp, "Only rank 1 reductions can be distributed.");
     // Only warp_size-sized vectors supported.
-    if (static_cast<uint64_t>(vectorType.getShape()[0]) != warpOp.getWarpSize())
+    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
       return rewriter.notifyMatchFailure(
           warpOp, "Reduction vector dimension must match was size.");
     // Only f32 and i32 element types are supported.
@@ -808,24 +808,26 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
           "Reduction distribution currently only supports 32bits types.");
-    Location yieldLoc = yieldOperand->getOwner()->getLoc();
+    int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
     // Return vector that will be reduced from the WarpExecuteOnLane0Op.
     unsigned operandIndex = yieldOperand->getOperandNumber();
     SmallVector<Value> yieldValues = {reductionOp.getVector()};
-    SmallVector<Type> retTypes = {VectorType::get({1}, reductionOp.getType())};
+    SmallVector<Type> retTypes = {
+        VectorType::get({numElements}, reductionOp.getType())};
     unsigned numResults = warpOp.getNumResults();
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, yieldValues, retTypes);
-    // Every lane has one scalar value. These should be reduced.
     Value laneValVec = newWarpOp.getResult(numResults);
-    Value laneVal = rewriter.create<vector::ExtractOp>(yieldLoc, laneValVec, 0);
-    laneVal =
-        distributedReductionFn(reductionOp.getLoc(), rewriter, laneVal,
+    // First reduce on a single thread.
+    Value perLaneReduction = rewriter.create<vector::ReductionOp>(
+        reductionOp.getLoc(), reductionOp.getKind(), laneValVec);
+    // Then distribute across threads.
+    Value fullReduce =
+        distributedReductionFn(reductionOp.getLoc(), rewriter, perLaneReduction,
                                reductionOp.getKind(), newWarpOp.getWarpSize());
-    newWarpOp.getResult(operandIndex).replaceAllUsesWith(laneVal);
+    newWarpOp.getResult(operandIndex).replaceAllUsesWith(fullReduce);
     return success();

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 718a7bf1c6b08..82f6299634578 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -509,5 +509,39 @@ func.func @vector_reduction(%laneid: index, %m0: memref<4x2x32xf32>, %m1: memref
     %5 = vector.broadcast %4 : f32 to vector<f32>
     vector.transfer_write %5, %m1[] : vector<f32>, memref<f32>
-  return 
+  return
+// -----
+// CHECK-PROP-LABEL: func @vector_reduction_large(
+//  CHECK-PROP-SAME:     %[[laneid:.*]]: index)
+//   CHECK-PROP-DAG:   %[[c1:.*]] = arith.constant 1 : i32
+//   CHECK-PROP-DAG:   %[[c2:.*]] = arith.constant 2 : i32
+//   CHECK-PROP-DAG:   %[[c4:.*]] = arith.constant 4 : i32
+//   CHECK-PROP-DAG:   %[[c8:.*]] = arith.constant 8 : i32
+//   CHECK-PROP-DAG:   %[[c16:.*]] = arith.constant 16 : i32
+//   CHECK-PROP-DAG:   %[[c32:.*]] = arith.constant 32 : i32
+//       CHECK-PROP:   %[[warp_op:.*]] = vector.warp_execute_on_lane_0(%[[laneid]])[32] -> (vector<2xf32>) {
+//       CHECK-PROP:     vector.yield %{{.*}} : vector<64xf32>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[a:.*]] = vector.reduction <add>, %[[warp_op]] : vector<2xf32> into f32
+//       CHECK-PROP:   %[[r0:.*]], %{{.*}} = gpu.shuffle  xor %[[a]], %[[c1]], %[[c32]]
+//       CHECK-PROP:   %[[a0:.*]] = arith.addf %[[a]], %[[r0]]
+//       CHECK-PROP:   %[[r1:.*]], %{{.*}} = gpu.shuffle  xor %[[a0]], %[[c2]], %[[c32]]
+//       CHECK-PROP:   %[[a1:.*]] = arith.addf %[[a0]], %[[r1]]
+//       CHECK-PROP:   %[[r2:.*]], %{{.*}} = gpu.shuffle  xor %[[a1]], %[[c4]], %[[c32]]
+//       CHECK-PROP:   %[[a2:.*]] = arith.addf %[[a1]], %[[r2]]
+//       CHECK-PROP:   %[[r3:.*]], %{{.*}} = gpu.shuffle  xor %[[a2]], %[[c8]], %[[c32]]
+//       CHECK-PROP:   %[[a3:.*]] = arith.addf %[[a2]], %[[r3]]
+//       CHECK-PROP:   %[[r4:.*]], %{{.*}} = gpu.shuffle  xor %[[a3]], %[[c16]], %[[c32]]
+//       CHECK-PROP:   %[[a4:.*]] = arith.addf %[[a3]], %[[r4]]
+//       CHECK-PROP:   return %[[a4]] : f32
+func.func @vector_reduction_large(%laneid: index) -> (f32) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+    %0 = "some_def"() : () -> (vector<64xf32>)
+    %1 = vector.reduction <add>, %0 : vector<64xf32> into f32
+    vector.yield %1 : f32
+  }
+  return %r : f32


More information about the Mlir-commits mailing list