[Mlir-commits] [mlir] [mlir][linalg] Implement Conv2D using Winograd Conv2D algorithm (PR #96181)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 8 09:00:02 PDT 2024


================
@@ -2764,40 +2776,35 @@ LogicalResult WinogradInputTransformOp::verify() {
   ArrayRef<int64_t> inputShape = inputType.getShape();
   int64_t inputH = inputShape[1];
   int64_t inputW = inputShape[2];
-  auto outputType = cast<ShapedType>(getOutput().getType());
-  ArrayRef<int64_t> outputShape = outputType.getShape();
-  int64_t outputH = outputShape[0];
-  int64_t outputW = outputShape[1];
-  int64_t outputTileH = outputShape[2];
-  int64_t outputTileW = outputShape[3];
   int m = getM();
   int r = getR();
+  int64_t tileSize = m + r - 1;
   bool leftTransform = inputH != 1;
   bool rightTransform = inputW != 1;
 
-  if (!leftTransform && !rightTransform)
-    return failure();
-
-  if (leftTransform) {
-    int64_t tileH = (inputH - (r - 1)) / m;
-    if (inputH != tileH * m + (r - 1))
-      return emitOpError("input height cannot be tiled in full tile size");
-    if (tileH != outputTileH)
-      return emitOpError("number of output height tiles is not correct");
-    if (outputH != m + r - 1)
-      return emitOpError("expect output height equals to tile size");
+  SmallVector<int64_t> expectedOutputShape(6, inputH);
+  if (ShapedType::isDynamic(inputH)) {
+    expectedOutputShape[0] = tileSize;
+    expectedOutputShape[2] = -1;
+  } else {
+    expectedOutputShape[0] = leftTransform ? tileSize : 1;
+    expectedOutputShape[2] = leftTransform ? (inputH - (r - 1)) / m : 1;
   }
-
-  if (rightTransform) {
-    int64_t tileW = (inputW - (r - 1)) / m;
-    if (inputW != tileW * m + (r - 1))
-      return emitOpError("input width cannot be tiled in full tile size");
-    if (tileW != outputTileW)
-      return emitOpError("number of output width tiles is not correct");
-    if (outputW != m + r - 1)
-      return emitOpError("expect output width equals to tile size");
+  if (ShapedType::isDynamic(inputW)) {
+    expectedOutputShape[1] = tileSize;
+    expectedOutputShape[3] = -1;
----------------
Max191 wrote:

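Ditto: use `ShapedType::kDynamic` instead of the literal `-1` to mark a dynamic dimension, so the intent is explicit and does not depend on the sentinel's numeric value.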

https://github.com/llvm/llvm-project/pull/96181


More information about the Mlir-commits mailing list