[Mlir-commits] [mlir] [mlir][linalg] Implement Conv2D using Winograd Conv2D algorithm (PR #96181)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 8 09:00:02 PDT 2024
================
@@ -2764,40 +2776,35 @@ LogicalResult WinogradInputTransformOp::verify() {
ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t inputH = inputShape[1];
int64_t inputW = inputShape[2];
- auto outputType = cast<ShapedType>(getOutput().getType());
- ArrayRef<int64_t> outputShape = outputType.getShape();
- int64_t outputH = outputShape[0];
- int64_t outputW = outputShape[1];
- int64_t outputTileH = outputShape[2];
- int64_t outputTileW = outputShape[3];
int m = getM();
int r = getR();
+ int64_t tileSize = m + r - 1;
bool leftTransform = inputH != 1;
bool rightTransform = inputW != 1;
- if (!leftTransform && !rightTransform)
- return failure();
-
- if (leftTransform) {
- int64_t tileH = (inputH - (r - 1)) / m;
- if (inputH != tileH * m + (r - 1))
- return emitOpError("input height cannot be tiled in full tile size");
- if (tileH != outputTileH)
- return emitOpError("number of output height tiles is not correct");
- if (outputH != m + r - 1)
- return emitOpError("expect output height equals to tile size");
+ SmallVector<int64_t> expectedOutputShape(6, inputH);
+ if (ShapedType::isDynamic(inputH)) {
+ expectedOutputShape[0] = tileSize;
+ expectedOutputShape[2] = -1;
+ } else {
+ expectedOutputShape[0] = leftTransform ? tileSize : 1;
+ expectedOutputShape[2] = leftTransform ? (inputH - (r - 1)) / m : 1;
}
-
- if (rightTransform) {
- int64_t tileW = (inputW - (r - 1)) / m;
- if (inputW != tileW * m + (r - 1))
- return emitOpError("input width cannot be tiled in full tile size");
- if (tileW != outputTileW)
- return emitOpError("number of output width tiles is not correct");
- if (outputW != m + r - 1)
- return emitOpError("expect output width equals to tile size");
+ if (ShapedType::isDynamic(inputW)) {
+ expectedOutputShape[1] = tileSize;
+ expectedOutputShape[3] = -1;
----------------
Max191 wrote:
Ditto: use `ShapedType::kDynamic` here instead of the literal `-1`.
https://github.com/llvm/llvm-project/pull/96181
More information about the Mlir-commits mailing list