[Mlir-commits] [mlir] [mlir][linalg] Fix isaConvolutionOpInterface logic (PR #102087)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Aug 5 18:54:16 PDT 2024
https://github.com/yifeizh2 updated https://github.com/llvm/llvm-project/pull/102087
>From bbdce7f27cf0c03e9279ef54c9bd18687820c348 Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei" <yifei.zhang at intel.com>
Date: Mon, 5 Aug 2024 06:29:58 -0700
Subject: [PATCH 1/2] [mlir][linalg] Fix isaConvolutionOpInterface logic
---
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 6ee1810c2ff2b..41143e0a5e347 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -762,7 +762,8 @@ enum class MatchConvolutionResult {
NotProjectedPermutations,
NonConvolutionLoop,
OutputDimsNotParallel,
- NonOutputDimNotReduction
+ NonOutputDimNotReduction,
+ NoValidConvolvedDim
};
} // namespace mlir::linalg::detail
@@ -810,6 +811,8 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
// - Depth multiplier : unconvolved in input, present in output, present in
// filter.
llvm::SmallDenseSet<int64_t> allLoopDims;
+ bool hasOutputImageDim = false;
+ bool hasFilterLoopDim = false;
for (auto outputExpr : indexingMaps.back().getResults()) {
int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
if (inputExprWalker.unConvolvedDims.count(outputDim) &&
@@ -825,6 +828,7 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
// Output image Loop dimension.
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
+ hasOutputImageDim = true;
allLoopDims.insert(outputDim);
continue;
}
@@ -862,6 +866,7 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
return MatchConvolutionResult::NonOutputDimNotReduction;
if (allLoopDims.count(filterDim))
return MatchConvolutionResult::NonConvolutionLoop;
+ hasFilterLoopDim = true;
allLoopDims.insert(filterDim);
continue;
}
@@ -886,6 +891,9 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
if (allLoopDims.size() != linalgOp.getNumLoops())
return MatchConvolutionResult::NonConvolutionLoop;
+ if (!hasOutputImageDim || !hasFilterLoopDim)
+ return MatchConvolutionResult::NoValidConvolvedDim;
+
if (dimensions) {
FailureOr<ConvolutionDimensions> res =
inferConvolutionDimsImpl(linalgOp, inputExprWalker,
>From effa82a9b4f6975da9f403bb7b8a7039e33fcc33 Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei" <yifei.zhang at intel.com>
Date: Mon, 5 Aug 2024 18:53:44 -0700
Subject: [PATCH 2/2] Reject empty convolved dims via inferConvolutionDimsImpl instead of local flags
---
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 14 ++++----------
1 file changed, 4 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 41143e0a5e347..7e7aa87e93825 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -763,7 +763,7 @@ enum class MatchConvolutionResult {
NonConvolutionLoop,
OutputDimsNotParallel,
NonOutputDimNotReduction,
- NoValidConvolvedDim
+ EmptyConvolvedDims
};
} // namespace mlir::linalg::detail
@@ -811,8 +811,6 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
// - Depth multiplier : unconvolved in input, present in output, present in
// filter.
llvm::SmallDenseSet<int64_t> allLoopDims;
- bool hasOutputImageDim = false;
- bool hasFilterLoopDim = false;
for (auto outputExpr : indexingMaps.back().getResults()) {
int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
if (inputExprWalker.unConvolvedDims.count(outputDim) &&
@@ -828,7 +826,6 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
// Output image Loop dimension.
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
- hasOutputImageDim = true;
allLoopDims.insert(outputDim);
continue;
}
@@ -866,7 +863,6 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
return MatchConvolutionResult::NonOutputDimNotReduction;
if (allLoopDims.count(filterDim))
return MatchConvolutionResult::NonConvolutionLoop;
- hasFilterLoopDim = true;
allLoopDims.insert(filterDim);
continue;
}
@@ -891,14 +887,12 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
if (allLoopDims.size() != linalgOp.getNumLoops())
return MatchConvolutionResult::NonConvolutionLoop;
- if (!hasOutputImageDim || !hasFilterLoopDim)
- return MatchConvolutionResult::NoValidConvolvedDim;
-
if (dimensions) {
FailureOr<ConvolutionDimensions> res =
inferConvolutionDimsImpl(linalgOp, inputExprWalker,
- /*allowEmptyConvolvedDims=*/true);
- assert(succeeded(res) && "unexpected failure to infer convolution dims");
+ /*allowEmptyConvolvedDims=*/false);
+ if (failed(res))
+ return MatchConvolutionResult::EmptyConvolvedDims;
*dimensions = *res;
}
More information about the Mlir-commits
mailing list