[Mlir-commits] [mlir] d206153 - [mlir][vector] Modify constraint and interface for warp reduce on f16 and i8

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 9 11:54:18 PST 2022


Author: stanley-nod
Date: 2022-11-09T11:52:17-08:00
New Revision: d2061530dc093daca93fbb268611e1a146e722de

URL: https://github.com/llvm/llvm-project/commit/d2061530dc093daca93fbb268611e1a146e722de
DIFF: https://github.com/llvm/llvm-project/commit/d2061530dc093daca93fbb268611e1a146e722de.diff

LOG: [mlir][vector] Modify constraint and interface for warp reduce on f16 and i8

Quantization methods are crucial and ubiquitous in accelerating machine
learning workloads. Most of these methods use f16 and i8 types.

This patch relaxes the type constraints on warp reduce distribution to
allow these types. Furthermore, this patch also changes the interface
and moves the initial reduction of data to a single thread into the
distributedReductionFn. This gives developers the flexibility to control
how they obtain the initial lane value, which might differ based on the
input types (e.g., to shuffle a 32-bit-wide type, we need to reduce f16
to 2xf16 rather than to a single element).

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D137691

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 6dfdf766a2f62..a2916a57350ba 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1135,12 +1135,13 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
     if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
       return rewriter.notifyMatchFailure(
           warpOp, "Reduction vector dimension must match was size.");
-    // Only f32 and i32 element types are supported.
+    // Only f32, i32, f16, i8 element types are supported.
     if (!reductionOp.getType().isF32() &&
-        !reductionOp.getType().isSignlessInteger(32))
+        !reductionOp.getType().isSignlessInteger(32) &&
+        !reductionOp.getType().isF16() && !reductionOp.getType().isInteger(8))
       return rewriter.notifyMatchFailure(
-          warpOp,
-          "Reduction distribution currently only supports 32bits types.");
+          warpOp, "Reduction distribution currently only supports 32bits, f16, "
+                  "and i8 types.");
 
     int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
     // Return vector that will be reduced from the WarpExecuteOnLane0Op.
@@ -1157,13 +1158,11 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
         rewriter, warpOp, yieldValues, retTypes, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
 
+    // Obtain data to reduce for a single lane.
     Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
-    // First reduce on a single thread.
-    Value perLaneReduction = rewriter.create<vector::ReductionOp>(
-        reductionOp.getLoc(), reductionOp.getKind(), laneValVec);
-    // Then distribute across threads.
+    // Distribute and reduce across threads.
     Value fullReduce =
-        distributedReductionFn(reductionOp.getLoc(), rewriter, perLaneReduction,
+        distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
                                reductionOp.getKind(), newWarpOp.getWarpSize());
     if (reductionOp.getAcc()) {
       fullReduce = vector::makeArithReduction(

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index b66b2fe9ef7f8..de29fc2a66423 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -686,7 +686,8 @@ static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
 
 static Value warpReduction(Location loc, OpBuilder &builder, Value input,
                            CombiningKind kind, uint32_t size) {
-  Value laneVal = input;
+  // First reduce on a single thread to get per lane reduction value.
+  Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
   // Parallel reduction using butterfly shuffles.
   for (uint64_t i = 1; i < size; i <<= 1) {
     Value shuffled = builder


        


More information about the Mlir-commits mailing list