[Mlir-commits] [mlir] 436d17a - [mlir] Expose a function to get vector::CombiningKind from Operation*.
Alexander Belyaev
llvmlistbot at llvm.org
Thu Jan 13 23:28:39 PST 2022
Author: Alexander Belyaev
Date: 2022-01-14T08:28:18+01:00
New Revision: 436d17a8e9e9e77d50914be30541d1557032f030
URL: https://github.com/llvm/llvm-project/commit/436d17a8e9e9e77d50914be30541d1557032f030
DIFF: https://github.com/llvm/llvm-project/commit/436d17a8e9e9e77d50914be30541d1557032f030.diff
LOG: [mlir] Expose a function to get vector::CombiningKind from Operation*.
Differential Revision: https://reviews.llvm.org/D117283
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 739d1b6d69a98..814474405715d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -920,6 +920,9 @@ struct LinalgVectorizationPattern : public OpInterfaceRewritePattern<LinalgOp> {
LinalgTransformationFilter filter;
};
+/// Map `combinerOp` to its vector::CombiningKind; None if null/unsupported.
+llvm::Optional<vector::CombiningKind> getCombinerOpKind(Operation *combinerOp);
+
//===----------------------------------------------------------------------===//
// Transformation and lowering options exposed as auxiliary structs.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f78e179911882..86eaed9a136cb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -109,25 +109,25 @@ struct VectorizationResult {
Operation *newOp;
};
-static llvm::Optional<vector::CombiningKind>
-getKindForOp(Operation *reductionOp) {
- if (!reductionOp)
+llvm::Optional<vector::CombiningKind>
+mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
+ using ::mlir::vector::CombiningKind;
+
+ if (!combinerOp)
return llvm::None;
- return llvm::TypeSwitch<Operation *, llvm::Optional<vector::CombiningKind>>(
- reductionOp)
+ return llvm::TypeSwitch<Operation *, llvm::Optional<CombiningKind>>(
+ combinerOp)
.Case<arith::AddIOp, arith::AddFOp>(
- [&](auto op) { return vector::CombiningKind::ADD; })
- .Case<arith::AndIOp>([&](auto op) { return vector::CombiningKind::AND; })
- .Case<arith::MaxSIOp>(
- [&](auto op) { return vector::CombiningKind::MAXSI; })
- .Case<arith::MaxFOp>([&](auto op) { return vector::CombiningKind::MAXF; })
- .Case<arith::MinSIOp>(
- [&](auto op) { return vector::CombiningKind::MINSI; })
- .Case<arith::MinFOp>([&](auto op) { return vector::CombiningKind::MINF; })
+ [&](auto op) { return CombiningKind::ADD; })
+ .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
+ .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
+ .Case<arith::MaxFOp>([&](auto op) { return CombiningKind::MAXF; })
+ .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
+ .Case<arith::MinFOp>([&](auto op) { return CombiningKind::MINF; })
.Case<arith::MulIOp, arith::MulFOp>(
- [&](auto op) { return vector::CombiningKind::MUL; })
- .Case<arith::OrIOp>([&](auto op) { return vector::CombiningKind::OR; })
- .Case<arith::XOrIOp>([&](auto op) { return vector::CombiningKind::XOR; })
+ [&](auto op) { return CombiningKind::MUL; })
+ .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
+ .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
.Default([&](auto op) { return llvm::None; });
}
@@ -174,7 +174,7 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
static Value buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
Value valueToReduce,
const SmallVector<bool> &reductionMask) {
- auto maybeKind = getKindForOp(reduceOp);
+ auto maybeKind = getCombinerOpKind(reduceOp);
assert(maybeKind && "Failed precondition: could not get reduction kind");
return b.create<vector::MultiDimReductionOp>(
reduceOp->getLoc(), valueToReduce, reductionMask, *maybeKind);
@@ -589,7 +589,7 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
}
for (OpOperand *opOperand : op.getOutputOperands()) {
Operation *reduceOp = matchLinalgReduction(opOperand);
- if (!reduceOp || !getKindForOp(reduceOp)) {
+ if (!reduceOp || !getCombinerOpKind(reduceOp)) {
LDBG("reduction precondition failed: reduction detection failed");
return failure();
}
@@ -1458,10 +1458,10 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
if (!reduceOp)
return;
llvm::Optional<vector::CombiningKind> maybeKind;
- maybeKind = getKindForOp(reduceOp);
+ maybeKind = getCombinerOpKind(reduceOp);
if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
return;
- maybeKind = getKindForOp(&(linalgOp->getRegion(0).front().front()));
+ maybeKind = getCombinerOpKind(&(linalgOp->getRegion(0).front().front()));
if (!maybeKind || *maybeKind != vector::CombiningKind::MUL)
return;
More information about the Mlir-commits mailing list