[Mlir-commits] [mlir] 9b5a3d1 - [mlir][vector] Add helper that builds a scalar reduction according to CombiningKind
Matthias Springer
llvmlistbot at llvm.org
Thu Feb 10 05:35:52 PST 2022
Author: Matthias Springer
Date: 2022-02-10T22:35:43+09:00
New Revision: 9b5a3d14b2c3bed697f3f5c873bef82bead27818
URL: https://github.com/llvm/llvm-project/commit/9b5a3d14b2c3bed697f3f5c873bef82bead27818
DIFF: https://github.com/llvm/llvm-project/commit/9b5a3d14b2c3bed697f3f5c873bef82bead27818.diff
LOG: [mlir][vector] Add helper that builds a scalar reduction according to CombiningKind
Differential Revision: https://reviews.llvm.org/D119433
Added:
Modified:
mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index f06ef2bc38ad..f6b84f1e28cd 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
#define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"
@@ -30,12 +31,14 @@ class VectorType;
class VectorTransferOpInterface;
namespace vector {
-class TransferWriteOp;
-class TransferReadOp;
-
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
+
+/// Return the value obtained by combining the scalar or vector values `v1` and
+/// `v2` with the arith operation that corresponds to the combining `kind`.
+Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
+ Value v1, Value v2);
} // namespace vector
/// Return the number of elements of basis, `0` if empty.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
index 1d37a3345827..db5c667a4935 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
@@ -243,47 +243,8 @@ struct TwoDimMultiReductionToElementWise
for (int64_t i = 1; i < srcShape[0]; i++) {
auto operand =
rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
- switch (multiReductionOp.kind()) {
- case vector::CombiningKind::ADD:
- if (elementType.isIntOrIndex())
- result = rewriter.create<arith::AddIOp>(loc, operand, result);
- else
- result = rewriter.create<arith::AddFOp>(loc, operand, result);
- break;
- case vector::CombiningKind::MUL:
- if (elementType.isIntOrIndex())
- result = rewriter.create<arith::MulIOp>(loc, operand, result);
- else
- result = rewriter.create<arith::MulFOp>(loc, operand, result);
- break;
- case vector::CombiningKind::MINUI:
- result = rewriter.create<arith::MinUIOp>(loc, operand, result);
- break;
- case vector::CombiningKind::MINSI:
- result = rewriter.create<arith::MinSIOp>(loc, operand, result);
- break;
- case vector::CombiningKind::MINF:
- result = rewriter.create<arith::MinFOp>(loc, operand, result);
- break;
- case vector::CombiningKind::MAXUI:
- result = rewriter.create<arith::MaxUIOp>(loc, operand, result);
- break;
- case vector::CombiningKind::MAXSI:
- result = rewriter.create<arith::MaxSIOp>(loc, operand, result);
- break;
- case vector::CombiningKind::MAXF:
- result = rewriter.create<arith::MaxFOp>(loc, operand, result);
- break;
- case vector::CombiningKind::AND:
- result = rewriter.create<arith::AndIOp>(loc, operand, result);
- break;
- case vector::CombiningKind::OR:
- result = rewriter.create<arith::OrIOp>(loc, operand, result);
- break;
- case vector::CombiningKind::XOR:
- result = rewriter.create<arith::XOrIOp>(loc, operand, result);
- break;
- }
+ result = makeArithReduction(rewriter, loc, multiReductionOp.kind(),
+ operand, result);
}
rewriter.replaceOp(multiReductionOp, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 155beb8d91fd..226faccaba96 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -10,6 +10,8 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+
#include <type_traits>
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -18,8 +20,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
-
-#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
@@ -514,40 +515,11 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
if (!acc)
return Optional<Value>(mul);
- Value combinedResult;
- switch (kind) {
- case CombiningKind::ADD:
- combinedResult = rewriter.create<arith::AddIOp>(loc, mul, acc);
- break;
- case CombiningKind::MUL:
- combinedResult = rewriter.create<arith::MulIOp>(loc, mul, acc);
- break;
- case CombiningKind::MINUI:
- combinedResult = rewriter.create<arith::MinUIOp>(loc, mul, acc);
- break;
- case CombiningKind::MINSI:
- combinedResult = rewriter.create<arith::MinSIOp>(loc, mul, acc);
- break;
- case CombiningKind::MAXUI:
- combinedResult = rewriter.create<arith::MaxUIOp>(loc, mul, acc);
- break;
- case CombiningKind::MAXSI:
- combinedResult = rewriter.create<arith::MaxSIOp>(loc, mul, acc);
- break;
- case CombiningKind::AND:
- combinedResult = rewriter.create<arith::AndIOp>(loc, mul, acc);
- break;
- case CombiningKind::OR:
- combinedResult = rewriter.create<arith::OrIOp>(loc, mul, acc);
- break;
- case CombiningKind::XOR:
- combinedResult = rewriter.create<arith::XOrIOp>(loc, mul, acc);
- break;
- case CombiningKind::MINF: // Only valid for floating point types.
- case CombiningKind::MAXF: // Only valid for floating point types.
+ if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
+ // Only valid for floating point types.
return Optional<Value>();
- }
- return Optional<Value>(combinedResult);
+
+ return makeArithReduction(rewriter, loc, kind, mul, acc);
}
static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
@@ -565,28 +537,14 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
if (!acc)
return Optional<Value>(mul);
- Value combinedResult;
- switch (kind) {
- case CombiningKind::MUL:
- combinedResult = rewriter.create<arith::MulFOp>(loc, mul, acc);
- break;
- case CombiningKind::MINF:
- combinedResult = rewriter.create<arith::MinFOp>(loc, mul, acc);
- break;
- case CombiningKind::MAXF:
- combinedResult = rewriter.create<arith::MaxFOp>(loc, mul, acc);
- break;
- case CombiningKind::ADD: // Already handled this special case above.
- case CombiningKind::AND: // Only valid for integer types.
- case CombiningKind::MINUI: // Only valid for integer types.
- case CombiningKind::MINSI: // Only valid for integer types.
- case CombiningKind::MAXUI: // Only valid for integer types.
- case CombiningKind::MAXSI: // Only valid for integer types.
- case CombiningKind::OR: // Only valid for integer types.
- case CombiningKind::XOR: // Only valid for integer types.
+ if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
+ kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
+ kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
+ kind == CombiningKind::OR || kind == CombiningKind::XOR)
+ // Already handled or only valid for integer types.
return Optional<Value>();
- }
- return Optional<Value>(combinedResult);
+
+ return makeArithReduction(rewriter, loc, kind, mul, acc);
}
};
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 800e19ceda71..5388b7bd6ca1 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
#include <numeric>
@@ -42,6 +43,56 @@ Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
llvm_unreachable("Expected MemRefType or TensorType");
}
+Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
+ CombiningKind kind, Value v1, Value v2) {
+ Type t1 = getElementTypeOrSelf(v1.getType());
+ Type t2 = getElementTypeOrSelf(v2.getType());
+ switch (kind) {
+ case CombiningKind::ADD:
+ if (t1.isIntOrIndex() && t2.isIntOrIndex())
+ return b.createOrFold<arith::AddIOp>(loc, v1, v2);
+ else if (t1.isa<FloatType>() && t2.isa<FloatType>())
+ return b.createOrFold<arith::AddFOp>(loc, v1, v2);
+ llvm_unreachable("invalid value types for ADD reduction");
+ case CombiningKind::AND:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::AndIOp>(loc, v1, v2);
+ case CombiningKind::MAXF:
+ assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+ "expected float values");
+ return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
+ case CombiningKind::MINF:
+ assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+ "expected float values");
+ return b.createOrFold<arith::MinFOp>(loc, v1, v2);
+ case CombiningKind::MAXSI:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
+ case CombiningKind::MINSI:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
+ case CombiningKind::MAXUI:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
+ case CombiningKind::MINUI:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
+ case CombiningKind::MUL:
+ if (t1.isIntOrIndex() && t2.isIntOrIndex())
+ return b.createOrFold<arith::MulIOp>(loc, v1, v2);
+ else if (t1.isa<FloatType>() && t2.isa<FloatType>())
+ return b.createOrFold<arith::MulFOp>(loc, v1, v2);
+ llvm_unreachable("invalid value types for MUL reduction");
+ case CombiningKind::OR:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::OrIOp>(loc, v1, v2);
+ case CombiningKind::XOR:
+ assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+ return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
+ };
+ llvm_unreachable("unknown CombiningKind");
+}
+
/// Return the number of elements of basis, `0` if empty.
int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
if (basis.empty())
More information about the Mlir-commits
mailing list