[Mlir-commits] [mlir] f41abcd - [mlir][vector] Relax restriction on reduction distribution

Thomas Raoux llvmlistbot at llvm.org
Fri Jan 6 08:20:31 PST 2023


Author: Thomas Raoux
Date: 2023-01-06T16:20:17Z
New Revision: f41abcda5ee0cf9d6a99bae5db08c60cbbafa760

URL: https://github.com/llvm/llvm-project/commit/f41abcda5ee0cf9d6a99bae5db08c60cbbafa760
DIFF: https://github.com/llvm/llvm-project/commit/f41abcda5ee0cf9d6a99bae5db08c60cbbafa760.diff

LOG: [mlir][vector] Relax restriction on reduction distribution

Relax an unnecessary restriction when distributing a vector.reduction op.
All float and integer element types can be supported by the user-provided lambda.

Differential Revision: https://reviews.llvm.org/D141094
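The change replaces an allow-list of four element types with a single integer-or-float test. The sketch below models the old and new predicates in plain C++17, as we read them from the diff. `ElemType`, `Kind`, `Signedness`, `oldSupported` and `newSupported` are hypothetical stand-ins, not the MLIR API. In MLIR, `Type::isIntOrFloat()` accepts integer types of any width and signedness as well as float types, but not `index`.

```cpp
// Hypothetical stand-ins for MLIR element types; this is NOT the MLIR API.
enum class Kind { Integer, F16, BF16, F32, F64, Index };
enum class Signedness { Signless, Signed, Unsigned };

struct ElemType {
  Kind kind;
  unsigned width = 0;                      // bit width, used only for Kind::Integer
  Signedness sign = Signedness::Signless;  // used only for Kind::Integer
};

// Models the removed check: f32, signless i32, f16, or an 8-bit integer.
// (isInteger(8) in MLIR matches an 8-bit integer of any signedness.)
constexpr bool oldSupported(ElemType t) {
  return t.kind == Kind::F32 ||
         (t.kind == Kind::Integer && t.width == 32 &&
          t.sign == Signedness::Signless) ||
         t.kind == Kind::F16 ||
         (t.kind == Kind::Integer && t.width == 8);
}

// Models the new check: any integer or float type; index is excluded,
// matching the behavior of isIntOrFloat().
constexpr bool newSupported(ElemType t) {
  return t.kind == Kind::Integer || t.kind == Kind::F16 ||
         t.kind == Kind::BF16 || t.kind == Kind::F32 || t.kind == Kind::F64;
}
```

Under this model, types such as f64, bf16, i64 and unsigned i32 were rejected before the change and are accepted after it. The pattern now relies on the user's lambda to handle whichever of these types it is given.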

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 16b6000758102..08841e38eecb7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1179,13 +1179,10 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
     if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
       return rewriter.notifyMatchFailure(
           warpOp, "Reduction vector dimension must match was size.");
-    // Only f32, i32, f16, i8 element types are supported.
-    if (!reductionOp.getType().isF32() &&
-        !reductionOp.getType().isSignlessInteger(32) &&
-        !reductionOp.getType().isF16() && !reductionOp.getType().isInteger(8))
+    if (!reductionOp.getType().isIntOrFloat())
       return rewriter.notifyMatchFailure(
-          warpOp, "Reduction distribution currently only supports 32bits, f16, "
-                  "and i8 types.");
+          warpOp, "Reduction distribution currently only supports floats and "
+                  "integer types.");
 
     int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
     // Return vector that will be reduced from the WarpExecuteOnLane0Op.
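The context lines above show two steps. First, the vector length must be a multiple of the warp size. Second, each lane then owns `numElements = length / warpSize` elements. The following is a plain C++17 model of that flow, not MLIR code, and it rests on several assumptions:

- Each lane owns a contiguous chunk. This is for illustration only; with an associative, commutative combiner the layout does not change the result.
- Each lane first reduces its own chunk.
- As the commit message suggests, a user-supplied callback performs the cross-lane combination.

The names `distributeReduction` and `butterflyReduce` are hypothetical. `butterflyReduce` simulates an XOR-shuffle exchange across lanes, one common way such a callback could be written.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of warp reduction distribution (not MLIR code).
// combine:   associative, commutative binary op on T (e.g. add, max).
// crossLane: stand-in for the user-provided lambda; it receives one partial
//            value per lane and returns the warp-wide result.
template <typename T, typename Combine, typename CrossLane>
T distributeReduction(const std::vector<T> &input, std::size_t warpSize,
                      Combine combine, CrossLane crossLane) {
  // Mirrors the pattern's precondition: length must be a multiple of warpSize.
  assert(warpSize > 0 && !input.empty() && input.size() % warpSize == 0);
  std::size_t numElements = input.size() / warpSize;  // elements per lane
  std::vector<T> partials(warpSize);
  for (std::size_t lane = 0; lane < warpSize; ++lane) {
    // Each lane reduces its own (assumed contiguous) chunk.
    T acc = input[lane * numElements];
    for (std::size_t i = 1; i < numElements; ++i)
      acc = combine(acc, input[lane * numElements + i]);
    partials[lane] = acc;
  }
  return crossLane(partials);
}

// Butterfly (XOR) reduction over a power-of-two lane count. After log2(n)
// rounds every simulated lane holds the combination of all lanes.
template <typename T, typename Combine>
T butterflyReduce(std::vector<T> vals, Combine combine) {
  std::size_t n = vals.size();
  assert(n > 0 && (n & (n - 1)) == 0);  // power of two
  for (std::size_t offset = n / 2; offset > 0; offset /= 2) {
    std::vector<T> next(n);
    for (std::size_t lane = 0; lane < n; ++lane)
      next[lane] = combine(vals[lane], vals[lane ^ offset]);  // "shuffle xor"
    vals.swap(next);
  }
  return vals[0];
}
```

Because both helpers are templated on the element type, the same code path serves `std::int8_t`, `std::int64_t`, `float` or `double`. This mirrors, loosely, why an element-type allow-list in the pattern itself is unnecessary when the callback handles the type.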
