[Mlir-commits] [mlir] 9f74e6e - [mlir][vector][gpu] Use `makeArithReduction` in lowering patterns. NFC. (#75952)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 19 16:04:31 PST 2023


Author: Jakub Kuderski
Date: 2023-12-19T19:04:27-05:00
New Revision: 9f74e6e6157bc4d63a28385c7c0a50506bb8a737

URL: https://github.com/llvm/llvm-project/commit/9f74e6e6157bc4d63a28385c7c0a50506bb8a737
DIFF: https://github.com/llvm/llvm-project/commit/9f74e6e6157bc4d63a28385c7c0a50506bb8a737.diff

LOG: [mlir][vector][gpu] Use `makeArithReduction` in lowering patterns. NFC. (#75952)

Use the `vector::makeArithReduction` helper as the single source of truth
for lowering reductions to arith ops.

Added: 
    

Modified: 
    mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index ecee9a7b45e32b..a9f903e696dfb1 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -16,15 +16,44 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
 
 using namespace mlir;
 
 namespace {
 
+static vector::CombiningKind
+convertReductionKind(gpu::AllReduceOperation mode) {
+  switch (mode) {
+#define MAP_CASE(X)                                                            \
+  case gpu::AllReduceOperation::X:                                             \
+    return vector::CombiningKind::X
+
+    MAP_CASE(ADD);
+    MAP_CASE(MUL);
+    MAP_CASE(MINUI);
+    MAP_CASE(MINSI);
+    MAP_CASE(MINF);
+    MAP_CASE(MAXSI);
+    MAP_CASE(MAXUI);
+    MAP_CASE(MAXF);
+    MAP_CASE(AND);
+    MAP_CASE(OR);
+    MAP_CASE(XOR);
+    MAP_CASE(MINIMUMF);
+    MAP_CASE(MAXIMUMF);
+
+#undef MAP_CASE
+  }
+
+  llvm_unreachable("Vector and GPU reduction kinds should match 1:1");
+}
+
 struct GpuAllReduceRewriter {
   using AccumulatorFactory = std::function<Value(Value, Value)>;
 
@@ -181,7 +210,7 @@ struct GpuAllReduceRewriter {
   /// block is expected to have 2 arguments. The gpu.yield return the
   /// accumulated value of the same type.
   AccumulatorFactory getFactory(Region &body) {
-    return AccumulatorFactory([&](Value lhs, Value rhs) {
+    return [&body, this](Value lhs, Value rhs) -> Value {
       Block *block = rewriter.getInsertionBlock();
       Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
 
@@ -209,51 +238,14 @@ struct GpuAllReduceRewriter {
       // Return accumulator result.
       rewriter.setInsertionPointToStart(split);
       return split->addArgument(lhs.getType(), lhs.getLoc());
-    });
+    };
   }
 
   /// Returns an accumulator factory that creates an op specified by opName.
   AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
-    using Kind = gpu::AllReduceOperation;
-    bool isFloatingPoint = isa<FloatType>(valueType);
-    switch (opName) {
-    case Kind::ADD:
-      return isFloatingPoint ? getFactory<arith::AddFOp>()
-                             : getFactory<arith::AddIOp>();
-    case Kind::MUL:
-      return isFloatingPoint ? getFactory<arith::MulFOp>()
-                             : getFactory<arith::MulIOp>();
-    case Kind::MINSI:
-      return getFactory<arith::MinSIOp>();
-    case Kind::MINUI:
-      return getFactory<arith::MinUIOp>();
-    case Kind::MINF:
-      return getFactory<arith::MinNumFOp>();
-    case Kind::MAXSI:
-      return getFactory<arith::MaxSIOp>();
-    case Kind::MAXUI:
-      return getFactory<arith::MaxUIOp>();
-    case Kind::MAXF:
-      return getFactory<arith::MaxNumFOp>();
-    case Kind::AND:
-      return getFactory<arith::AndIOp>();
-    case Kind::OR:
-      return getFactory<arith::OrIOp>();
-    case Kind::XOR:
-      return getFactory<arith::XOrIOp>();
-    case Kind::MINIMUMF:
-      return getFactory<arith::MinimumFOp>();
-    case Kind::MAXIMUMF:
-      return getFactory<arith::MaximumFOp>();
-    }
-    llvm_unreachable("unknown GPU AllReduceOperation");
-  }
-
-  /// Returns an accumulator factory that creates an op of type T.
-  template <typename T>
-  AccumulatorFactory getFactory() {
-    return [this](Value lhs, Value rhs) {
-      return create<T>(lhs.getType(), lhs, rhs);
+    return [opName, this](Value lhs, Value rhs) {
+      return vector::makeArithReduction(rewriter, loc,
+                                        convertReductionKind(opName), lhs, rhs);
     };
   }
 

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
index ef6e6f5264a221..c3ae7e74693cdd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
@@ -38,66 +38,6 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-/// This function constructs the appropriate integer or float
-/// operation given the vector combining kind and operands. The
-/// supported int operations are : add, mul, min (signed/unsigned),
-/// max(signed/unsigned), and, or, xor. The supported float
-/// operations are : add, mul, min and max.
-static Value genOperator(Location loc, Value x, Value y,
-                         vector::CombiningKind kind,
-                         PatternRewriter &rewriter) {
-  using vector::CombiningKind;
-
-  auto elType = cast<VectorType>(x.getType()).getElementType();
-  bool isInt = elType.isIntOrIndex();
-
-  Value combinedResult{nullptr};
-  switch (kind) {
-  case CombiningKind::ADD:
-    if (isInt)
-      combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
-    else
-      combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
-    break;
-  case CombiningKind::MUL:
-    if (isInt)
-      combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
-    else
-      combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
-    break;
-  case CombiningKind::MINUI:
-    combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
-    break;
-  case CombiningKind::MINSI:
-    combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
-    break;
-  case CombiningKind::MAXUI:
-    combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
-    break;
-  case CombiningKind::MAXSI:
-    combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
-    break;
-  case CombiningKind::AND:
-    combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
-    break;
-  case CombiningKind::OR:
-    combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
-    break;
-  case CombiningKind::XOR:
-    combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
-    break;
-  case CombiningKind::MINF:
-  case CombiningKind::MINIMUMF:
-    combinedResult = rewriter.create<arith::MinimumFOp>(loc, x, y);
-    break;
-  case CombiningKind::MAXF:
-  case CombiningKind::MAXIMUMF:
-    combinedResult = rewriter.create<arith::MaximumFOp>(loc, x, y);
-    break;
-  }
-  return combinedResult;
-}
-
 /// This function checks to see if the vector combining kind
 /// is consistent with the integer or float element type.
 static bool isValidKind(bool isInt, vector::CombiningKind kind) {
@@ -224,8 +164,8 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
         }
       } else {
         Value y = inclusive ? input : lastInput;
-        output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
-        assert(output != nullptr);
+        output = vector::makeArithReduction(rewriter, loc, scanOp.getKind(),
+                                            lastOutput, y);
       }
       result = rewriter.create<vector::InsertStridedSliceOp>(
           loc, output, result, offsets, strides);


        


More information about the Mlir-commits mailing list