[Mlir-commits] [mlir] [mlir][linalg] Fix isaConvolutionOpInterface logic (PR #102087)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Aug 5 18:28:19 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: None (yifeizh2)
<details>
<summary>Changes</summary>
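Tighten the convolution-op matching logic in `isConvolutionInterfaceImpl`: an op is now recognized as a convolution only if it has at least one output image loop dimension and at least one filter loop dimension; otherwise the new `MatchConvolutionResult::NoValidConvolvedDim` result is returned.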
---
Full diff: https://github.com/llvm/llvm-project/pull/102087.diff
1 File Affected:
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (+9-1)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 6ee1810c2ff2b..41143e0a5e347 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -762,7 +762,8 @@ enum class MatchConvolutionResult {
NotProjectedPermutations,
NonConvolutionLoop,
OutputDimsNotParallel,
- NonOutputDimNotReduction
+ NonOutputDimNotReduction,
+ NoValidConvolvedDim
};
} // namespace mlir::linalg::detail
@@ -810,6 +811,8 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
// - Depth multiplier : unconvolved in input, present in output, present in
// filter.
llvm::SmallDenseSet<int64_t> allLoopDims;
+ bool hasOutputImageDim = false;
+ bool hasFilterLoopDim = false;
for (auto outputExpr : indexingMaps.back().getResults()) {
int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
if (inputExprWalker.unConvolvedDims.count(outputDim) &&
@@ -825,6 +828,7 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
// Output image Loop dimension.
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
+ hasOutputImageDim = true;
allLoopDims.insert(outputDim);
continue;
}
@@ -862,6 +866,7 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
return MatchConvolutionResult::NonOutputDimNotReduction;
if (allLoopDims.count(filterDim))
return MatchConvolutionResult::NonConvolutionLoop;
+ hasFilterLoopDim = true;
allLoopDims.insert(filterDim);
continue;
}
@@ -886,6 +891,9 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
if (allLoopDims.size() != linalgOp.getNumLoops())
return MatchConvolutionResult::NonConvolutionLoop;
+ if (!hasOutputImageDim || !hasFilterLoopDim)
+ return MatchConvolutionResult::NoValidConvolvedDim;
+
if (dimensions) {
FailureOr<ConvolutionDimensions> res =
inferConvolutionDimsImpl(linalgOp, inputExprWalker,
``````````
</details>
https://github.com/llvm/llvm-project/pull/102087
More information about the Mlir-commits
mailing list