[Mlir-commits] [mlir] [mlir][linalg] Fix isaConvolutionOpInterface logic (PR #102087)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 5 22:48:59 PDT 2024


https://github.com/yifeizh2 updated https://github.com/llvm/llvm-project/pull/102087

>From bbdce7f27cf0c03e9279ef54c9bd18687820c348 Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei" <yifei.zhang at intel.com>
Date: Mon, 5 Aug 2024 06:29:58 -0700
Subject: [PATCH 1/4] [mlir][linalg] Fix isaConvolutionOpInterface logic

---
 mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 6ee1810c2ff2b..41143e0a5e347 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -762,7 +762,8 @@ enum class MatchConvolutionResult {
   NotProjectedPermutations,
   NonConvolutionLoop,
   OutputDimsNotParallel,
-  NonOutputDimNotReduction
+  NonOutputDimNotReduction,
+  NoValidConvolvedDim
 };
 } // namespace mlir::linalg::detail
 
@@ -810,6 +811,8 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
   // - Depth multiplier : unconvolved in input, present in output, present in
   //   filter.
   llvm::SmallDenseSet<int64_t> allLoopDims;
+  bool hasOutputImageDim = false;
+  bool hasFilterLoopDim = false;
   for (auto outputExpr : indexingMaps.back().getResults()) {
     int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
     if (inputExprWalker.unConvolvedDims.count(outputDim) &&
@@ -825,6 +828,7 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
       // Output image Loop dimension.
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
+      hasOutputImageDim = true;
       allLoopDims.insert(outputDim);
       continue;
     }
@@ -862,6 +866,7 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
         return MatchConvolutionResult::NonOutputDimNotReduction;
       if (allLoopDims.count(filterDim))
         return MatchConvolutionResult::NonConvolutionLoop;
+      hasFilterLoopDim = true;
       allLoopDims.insert(filterDim);
       continue;
     }
@@ -886,6 +891,9 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
   if (allLoopDims.size() != linalgOp.getNumLoops())
     return MatchConvolutionResult::NonConvolutionLoop;
 
+  if (!hasOutputImageDim || !hasFilterLoopDim)
+    return MatchConvolutionResult::NoValidConvolvedDim;
+
   if (dimensions) {
     FailureOr<ConvolutionDimensions> res =
         inferConvolutionDimsImpl(linalgOp, inputExprWalker,

>From effa82a9b4f6975da9f403bb7b8a7039e33fcc33 Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei" <yifei.zhang at intel.com>
Date: Mon, 5 Aug 2024 18:53:44 -0700
Subject: [PATCH 2/4] Report empty convolved dims via inferConvolutionDimsImpl

---
 mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 41143e0a5e347..7e7aa87e93825 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -763,7 +763,7 @@ enum class MatchConvolutionResult {
   NonConvolutionLoop,
   OutputDimsNotParallel,
   NonOutputDimNotReduction,
-  NoValidConvolvedDim
+  EmptyConvolvedDims
 };
 } // namespace mlir::linalg::detail
 
@@ -811,8 +811,6 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
   // - Depth multiplier : unconvolved in input, present in output, present in
   //   filter.
   llvm::SmallDenseSet<int64_t> allLoopDims;
-  bool hasOutputImageDim = false;
-  bool hasFilterLoopDim = false;
   for (auto outputExpr : indexingMaps.back().getResults()) {
     int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
     if (inputExprWalker.unConvolvedDims.count(outputDim) &&
@@ -828,7 +826,6 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
       // Output image Loop dimension.
       if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
         return MatchConvolutionResult::OutputDimsNotParallel;
-      hasOutputImageDim = true;
       allLoopDims.insert(outputDim);
       continue;
     }
@@ -866,7 +863,6 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
         return MatchConvolutionResult::NonOutputDimNotReduction;
       if (allLoopDims.count(filterDim))
         return MatchConvolutionResult::NonConvolutionLoop;
-      hasFilterLoopDim = true;
       allLoopDims.insert(filterDim);
       continue;
     }
@@ -891,14 +887,12 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
   if (allLoopDims.size() != linalgOp.getNumLoops())
     return MatchConvolutionResult::NonConvolutionLoop;
 
-  if (!hasOutputImageDim || !hasFilterLoopDim)
-    return MatchConvolutionResult::NoValidConvolvedDim;
-
   if (dimensions) {
     FailureOr<ConvolutionDimensions> res =
         inferConvolutionDimsImpl(linalgOp, inputExprWalker,
-                                 /*allowEmptyConvolvedDims=*/true);
-    assert(succeeded(res) && "unexpected failure to infer convolution dims");
+                                 /*allowEmptyConvolvedDims=*/false);
+    if (failed(res))
+      return MatchConvolutionResult::EmptyConvolvedDims;
     *dimensions = *res;
   }
 

>From 5170861cfd5ec7428d67a48b172edd84f43b97d1 Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei" <yifei.zhang at intel.com>
Date: Mon, 5 Aug 2024 21:42:59 -0700
Subject: [PATCH 3/4] Expose allowEmptyConvolvedDims parameter

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h  |  6 ++++--
 .../lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 18 +++++++++++-------
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 08afdf373f014..3cc90559b7d81 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -111,7 +111,8 @@ FailureOr<ConvolutionDimensions> inferConvolutionDims(LinalgOp linalgOp);
 
 /// Checks whether `linalgOp` conforms to ConvolutionOpInterface.
 // TODO: embed within `isa<ConvolutionOpInterface>` if possible / natural.
-bool isaConvolutionOpInterface(LinalgOp linalgOp);
+bool isaConvolutionOpInterface(LinalgOp linalgOp,
+                               bool allowEmptyConvolvedDims = false);
 
 /// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
 bool isaCopyOpInterface(LinalgOp linalgOp);
@@ -177,7 +178,8 @@ enum class MatchConvolutionResult;
 /// present.
 MatchConvolutionResult
 isConvolutionInterfaceImpl(Operation *op,
-                           ConvolutionDimensions *dimensions = nullptr);
+                           ConvolutionDimensions *dimensions = nullptr,
+                           bool allowEmptyConvolvedDims = false);
 
 /// Returns the error message corresponding to the convolution checking return
 /// code.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 7e7aa87e93825..641ff7bd7e93f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -769,7 +769,8 @@ enum class MatchConvolutionResult {
 
 mlir::linalg::detail::MatchConvolutionResult
 mlir::linalg::detail::isConvolutionInterfaceImpl(
-    Operation *op, ConvolutionDimensions *dimensions) {
+    Operation *op, ConvolutionDimensions *dimensions,
+    bool allowEmptyConvolvedDims) {
   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
   if (!linalgOp)
     return MatchConvolutionResult::NotLinalgOp;
@@ -888,11 +889,12 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
     return MatchConvolutionResult::NonConvolutionLoop;
 
   if (dimensions) {
-    FailureOr<ConvolutionDimensions> res =
-        inferConvolutionDimsImpl(linalgOp, inputExprWalker,
-                                 /*allowEmptyConvolvedDims=*/false);
-    if (failed(res))
+    FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
+        linalgOp, inputExprWalker, allowEmptyConvolvedDims);
+    if (!allowEmptyConvolvedDims && failed(res))
       return MatchConvolutionResult::EmptyConvolvedDims;
+    if (allowEmptyConvolvedDims)
+      assert(succeeded(res) && "unexpected failure to infer convolution dims");
     *dimensions = *res;
   }
 
@@ -922,8 +924,10 @@ mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
   llvm_unreachable("unhandled MatchConvolutionResult case");
 }
 
-bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp) {
-  return linalg::detail::isConvolutionInterfaceImpl(linalgOp.getOperation()) ==
+bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp,
+                                             bool allowEmptyConvolvedDims) {
+  return linalg::detail::isConvolutionInterfaceImpl(
+             linalgOp.getOperation(), nullptr, allowEmptyConvolvedDims) ==
          linalg::detail::MatchConvolutionResult::Success;
 }
 

>From 5328b742608eb434781430208de73453e0db3b0f Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei" <yifei.zhang at intel.com>
Date: Mon, 5 Aug 2024 22:48:31 -0700
Subject: [PATCH 4/4] Check for empty convolved dims before inferring dimensions

---
 mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 641ff7bd7e93f..a38b20eed3a00 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -888,13 +888,13 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
   if (allLoopDims.size() != linalgOp.getNumLoops())
     return MatchConvolutionResult::NonConvolutionLoop;
 
+  if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
+    return MatchConvolutionResult::EmptyConvolvedDims;
+
   if (dimensions) {
     FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
         linalgOp, inputExprWalker, allowEmptyConvolvedDims);
-    if (!allowEmptyConvolvedDims && failed(res))
-      return MatchConvolutionResult::EmptyConvolvedDims;
-    if (allowEmptyConvolvedDims)
-      assert(succeeded(res) && "unexpected failure to infer convolution dims");
+    assert(succeeded(res) && "unexpected failure to infer convolution dims");
     *dimensions = *res;
   }
 



More information about the Mlir-commits mailing list