[Mlir-commits] [mlir] 9f74e6e - [mlir][vector][gpu] Use `makeArithReduction` in lowering patterns. NFC. (#75952)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Dec 19 16:04:31 PST 2023
Author: Jakub Kuderski
Date: 2023-12-19T19:04:27-05:00
New Revision: 9f74e6e6157bc4d63a28385c7c0a50506bb8a737
URL: https://github.com/llvm/llvm-project/commit/9f74e6e6157bc4d63a28385c7c0a50506bb8a737
DIFF: https://github.com/llvm/llvm-project/commit/9f74e6e6157bc4d63a28385c7c0a50506bb8a737.diff
LOG: [mlir][vector][gpu] Use `makeArithReduction` in lowering patterns. NFC. (#75952)
Use the `vector::makeArithReduction` helper as the single source of truth
for lowering reductions to arith ops.
Added:
Modified:
mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index ecee9a7b45e32b..a9f903e696dfb1 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -16,15 +16,44 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
namespace {
+static vector::CombiningKind
+convertReductionKind(gpu::AllReduceOperation mode) {
+ switch (mode) {
+#define MAP_CASE(X) \
+ case gpu::AllReduceOperation::X: \
+ return vector::CombiningKind::X
+
+ MAP_CASE(ADD);
+ MAP_CASE(MUL);
+ MAP_CASE(MINUI);
+ MAP_CASE(MINSI);
+ MAP_CASE(MINF);
+ MAP_CASE(MAXSI);
+ MAP_CASE(MAXUI);
+ MAP_CASE(MAXF);
+ MAP_CASE(AND);
+ MAP_CASE(OR);
+ MAP_CASE(XOR);
+ MAP_CASE(MINIMUMF);
+ MAP_CASE(MAXIMUMF);
+
+#undef MAP_CASE
+ }
+
+ llvm_unreachable("Vector and GPU reduction kinds should match 1:1");
+}
+
struct GpuAllReduceRewriter {
using AccumulatorFactory = std::function<Value(Value, Value)>;
@@ -181,7 +210,7 @@ struct GpuAllReduceRewriter {
/// block is expected to have 2 arguments. The gpu.yield return the
/// accumulated value of the same type.
AccumulatorFactory getFactory(Region &body) {
- return AccumulatorFactory([&](Value lhs, Value rhs) {
+ return [&body, this](Value lhs, Value rhs) -> Value {
Block *block = rewriter.getInsertionBlock();
Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
@@ -209,51 +238,14 @@ struct GpuAllReduceRewriter {
// Return accumulator result.
rewriter.setInsertionPointToStart(split);
return split->addArgument(lhs.getType(), lhs.getLoc());
- });
+ };
}
/// Returns an accumulator factory that creates an op specified by opName.
AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
- using Kind = gpu::AllReduceOperation;
- bool isFloatingPoint = isa<FloatType>(valueType);
- switch (opName) {
- case Kind::ADD:
- return isFloatingPoint ? getFactory<arith::AddFOp>()
- : getFactory<arith::AddIOp>();
- case Kind::MUL:
- return isFloatingPoint ? getFactory<arith::MulFOp>()
- : getFactory<arith::MulIOp>();
- case Kind::MINSI:
- return getFactory<arith::MinSIOp>();
- case Kind::MINUI:
- return getFactory<arith::MinUIOp>();
- case Kind::MINF:
- return getFactory<arith::MinNumFOp>();
- case Kind::MAXSI:
- return getFactory<arith::MaxSIOp>();
- case Kind::MAXUI:
- return getFactory<arith::MaxUIOp>();
- case Kind::MAXF:
- return getFactory<arith::MaxNumFOp>();
- case Kind::AND:
- return getFactory<arith::AndIOp>();
- case Kind::OR:
- return getFactory<arith::OrIOp>();
- case Kind::XOR:
- return getFactory<arith::XOrIOp>();
- case Kind::MINIMUMF:
- return getFactory<arith::MinimumFOp>();
- case Kind::MAXIMUMF:
- return getFactory<arith::MaximumFOp>();
- }
- llvm_unreachable("unknown GPU AllReduceOperation");
- }
-
- /// Returns an accumulator factory that creates an op of type T.
- template <typename T>
- AccumulatorFactory getFactory() {
- return [this](Value lhs, Value rhs) {
- return create<T>(lhs.getType(), lhs, rhs);
+ return [opName, this](Value lhs, Value rhs) {
+ return vector::makeArithReduction(rewriter, loc,
+ convertReductionKind(opName), lhs, rhs);
};
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
index ef6e6f5264a221..c3ae7e74693cdd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
@@ -38,66 +38,6 @@
using namespace mlir;
using namespace mlir::vector;
-/// This function constructs the appropriate integer or float
-/// operation given the vector combining kind and operands. The
-/// supported int operations are : add, mul, min (signed/unsigned),
-/// max(signed/unsigned), and, or, xor. The supported float
-/// operations are : add, mul, min and max.
-static Value genOperator(Location loc, Value x, Value y,
- vector::CombiningKind kind,
- PatternRewriter &rewriter) {
- using vector::CombiningKind;
-
- auto elType = cast<VectorType>(x.getType()).getElementType();
- bool isInt = elType.isIntOrIndex();
-
- Value combinedResult{nullptr};
- switch (kind) {
- case CombiningKind::ADD:
- if (isInt)
- combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
- else
- combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
- break;
- case CombiningKind::MUL:
- if (isInt)
- combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
- else
- combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
- break;
- case CombiningKind::MINUI:
- combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
- break;
- case CombiningKind::MINSI:
- combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
- break;
- case CombiningKind::MAXUI:
- combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
- break;
- case CombiningKind::MAXSI:
- combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
- break;
- case CombiningKind::AND:
- combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
- break;
- case CombiningKind::OR:
- combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
- break;
- case CombiningKind::XOR:
- combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
- break;
- case CombiningKind::MINF:
- case CombiningKind::MINIMUMF:
- combinedResult = rewriter.create<arith::MinimumFOp>(loc, x, y);
- break;
- case CombiningKind::MAXF:
- case CombiningKind::MAXIMUMF:
- combinedResult = rewriter.create<arith::MaximumFOp>(loc, x, y);
- break;
- }
- return combinedResult;
-}
-
/// This function checks to see if the vector combining kind
/// is consistent with the integer or float element type.
static bool isValidKind(bool isInt, vector::CombiningKind kind) {
@@ -224,8 +164,8 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
}
} else {
Value y = inclusive ? input : lastInput;
- output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
- assert(output != nullptr);
+ output = vector::makeArithReduction(rewriter, loc, scanOp.getKind(),
+ lastOutput, y);
}
result = rewriter.create<vector::InsertStridedSliceOp>(
loc, output, result, offsets, strides);
More information about the Mlir-commits
mailing list