[Mlir-commits] [mlir] [mlir][linalg] Fix isaConvolutionOpInterface logic (PR #102087)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 5 18:27:52 PDT 2024


https://github.com/yifeizh2 created https://github.com/llvm/llvm-project/pull/102087

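Tighten the convolution-op matching logic: isaConvolutionOpInterface now additionally requires the op to have at least one output image dimension and at least one filter loop dimension, and otherwise reports the new NoValidConvolvedDim result.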

From bbdce7f27cf0c03e9279ef54c9bd18687820c348 Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei" <yifei.zhang at intel.com>
Date: Mon, 5 Aug 2024 06:29:58 -0700
Subject: [PATCH] [mlir][linalg] Fix isaConvolutionOpInterface logic

---
 mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 6ee1810c2ff2b..41143e0a5e347 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -762,7 +762,8 @@ enum class MatchConvolutionResult {
   NotProjectedPermutations,
   NonConvolutionLoop,
   OutputDimsNotParallel,
-  NonOutputDimNotReduction
+  NonOutputDimNotReduction,
+  NoValidConvolvedDim
 };
 } // namespace mlir::linalg::detail
 
@@ -810,6 +811,8 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
   // - Depth multiplier : unconvolved in input, present in output, present in
   //   filter.
   llvm::SmallDenseSet<int64_t> allLoopDims;
+  bool hasOutputImageDim = false;
+  bool hasFilterLoopDim = false;
   for (auto outputExpr : indexingMaps.back().getResults()) {
     int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
     if (inputExprWalker.unConvolvedDims.count(outputDim) &&
@@ -825,6 +828,7 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
       // Output image Loop dimension.
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
+      hasOutputImageDim = true;
       allLoopDims.insert(outputDim);
       continue;
     }
@@ -862,6 +866,7 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
         return MatchConvolutionResult::NonOutputDimNotReduction;
       if (allLoopDims.count(filterDim))
         return MatchConvolutionResult::NonConvolutionLoop;
+      hasFilterLoopDim = true;
       allLoopDims.insert(filterDim);
       continue;
     }
@@ -886,6 +891,9 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
   if (allLoopDims.size() != linalgOp.getNumLoops())
     return MatchConvolutionResult::NonConvolutionLoop;
 
+  if (!hasOutputImageDim || !hasFilterLoopDim)
+    return MatchConvolutionResult::NoValidConvolvedDim;
+
   if (dimensions) {
     FailureOr<ConvolutionDimensions> res =
         inferConvolutionDimsImpl(linalgOp, inputExprWalker,



More information about the Mlir-commits mailing list