[Mlir-commits] [mlir] 436d17a - [mlir] Expose a function to get vector::CombiningKind from Operation*.

Alexander Belyaev llvmlistbot at llvm.org
Thu Jan 13 23:28:39 PST 2022


Author: Alexander Belyaev
Date: 2022-01-14T08:28:18+01:00
New Revision: 436d17a8e9e9e77d50914be30541d1557032f030

URL: https://github.com/llvm/llvm-project/commit/436d17a8e9e9e77d50914be30541d1557032f030
DIFF: https://github.com/llvm/llvm-project/commit/436d17a8e9e9e77d50914be30541d1557032f030.diff

LOG: [mlir] Expose a function to get vector::CombiningKind from Operation*.

Differential Revision: https://reviews.llvm.org/D117283

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 739d1b6d69a98..814474405715d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -920,6 +920,9 @@ struct LinalgVectorizationPattern : public OpInterfaceRewritePattern<LinalgOp> {
   LinalgTransformationFilter filter;
 };
 
+/// Return the vector::CombiningKind matching `combinerOp`, or llvm::None if the op is null or unsupported.
+llvm::Optional<vector::CombiningKind> getCombinerOpKind(Operation *combinerOp);
+
 //===----------------------------------------------------------------------===//
 // Transformation and lowering options exposed as auxiliary structs.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f78e179911882..86eaed9a136cb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -109,25 +109,25 @@ struct VectorizationResult {
   Operation *newOp;
 };
 
-static llvm::Optional<vector::CombiningKind>
-getKindForOp(Operation *reductionOp) {
-  if (!reductionOp)
+llvm::Optional<vector::CombiningKind>
+mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
+  using ::mlir::vector::CombiningKind;
+
+  if (!combinerOp)
     return llvm::None;
-  return llvm::TypeSwitch<Operation *, llvm::Optional<vector::CombiningKind>>(
-             reductionOp)
+  return llvm::TypeSwitch<Operation *, llvm::Optional<CombiningKind>>(
+             combinerOp)
       .Case<arith::AddIOp, arith::AddFOp>(
-          [&](auto op) { return vector::CombiningKind::ADD; })
-      .Case<arith::AndIOp>([&](auto op) { return vector::CombiningKind::AND; })
-      .Case<arith::MaxSIOp>(
-          [&](auto op) { return vector::CombiningKind::MAXSI; })
-      .Case<arith::MaxFOp>([&](auto op) { return vector::CombiningKind::MAXF; })
-      .Case<arith::MinSIOp>(
-          [&](auto op) { return vector::CombiningKind::MINSI; })
-      .Case<arith::MinFOp>([&](auto op) { return vector::CombiningKind::MINF; })
+          [&](auto op) { return CombiningKind::ADD; })
+      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
+      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
+      .Case<arith::MaxFOp>([&](auto op) { return CombiningKind::MAXF; })
+      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
+      .Case<arith::MinFOp>([&](auto op) { return CombiningKind::MINF; })
       .Case<arith::MulIOp, arith::MulFOp>(
-          [&](auto op) { return vector::CombiningKind::MUL; })
-      .Case<arith::OrIOp>([&](auto op) { return vector::CombiningKind::OR; })
-      .Case<arith::XOrIOp>([&](auto op) { return vector::CombiningKind::XOR; })
+          [&](auto op) { return CombiningKind::MUL; })
+      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
+      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
       .Default([&](auto op) { return llvm::None; });
 }
 
@@ -174,7 +174,7 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
 static Value buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                  Value valueToReduce,
                                  const SmallVector<bool> &reductionMask) {
-  auto maybeKind = getKindForOp(reduceOp);
+  auto maybeKind = getCombinerOpKind(reduceOp);
   assert(maybeKind && "Failed precondition: could not get reduction kind");
   return b.create<vector::MultiDimReductionOp>(
       reduceOp->getLoc(), valueToReduce, reductionMask, *maybeKind);
@@ -589,7 +589,7 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
   }
   for (OpOperand *opOperand : op.getOutputOperands()) {
     Operation *reduceOp = matchLinalgReduction(opOperand);
-    if (!reduceOp || !getKindForOp(reduceOp)) {
+    if (!reduceOp || !getCombinerOpKind(reduceOp)) {
       LDBG("reduction precondition failed: reduction detection failed");
       return failure();
     }
@@ -1458,10 +1458,10 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     if (!reduceOp)
       return;
     llvm::Optional<vector::CombiningKind> maybeKind;
-    maybeKind = getKindForOp(reduceOp);
+    maybeKind = getCombinerOpKind(reduceOp);
     if (!maybeKind || *maybeKind != vector::CombiningKind::ADD)
       return;
-    maybeKind = getKindForOp(&(linalgOp->getRegion(0).front().front()));
+    maybeKind = getCombinerOpKind(&(linalgOp->getRegion(0).front().front()));
     if (!maybeKind || *maybeKind != vector::CombiningKind::MUL)
       return;
 


        


More information about the Mlir-commits mailing list