[Mlir-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)

Hsiangkai Wang llvmlistbot at llvm.org
Wed Aug 14 05:46:59 PDT 2024


================
@@ -2922,6 +3021,129 @@ LogicalResult WinogradInputTransformOp::verify() {
   return success();
 }
 
+/// Returns the iteration domain used for tiling this op: one Range per output
+/// dimension, skipping output dimensions 0 and 1.
+/// Each Range has offset 0, stride 1, and a size obtained from getDimValue on
+/// the output operand for the corresponding dimension.
+/// NOTE(review): the reason output dims 0 and 1 are excluded is not visible
+/// here (presumably they are not meant to be tiled) — confirm. This also
+/// assumes the output rank is at least 2; otherwise `outputRank - 2` would be
+/// negative when sizing `loopBounds` — presumably guaranteed by verify().
+SmallVector<Range>
+WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  IntegerAttr zeroAttr = builder.getIndexAttr(0);
+  IntegerAttr oneAttr = builder.getIndexAttr(1);
+  Value output = getOutput();
+  int64_t outputRank = getOutputOperandRank();
+  SmallVector<Range> loopBounds(outputRank - 2);
+  for (unsigned dim = 2; dim < outputRank; ++dim) {
+    // Output dimension `dim` maps to loop index `dim - 2`.
+    loopBounds[dim - 2].offset = zeroAttr;
+    loopBounds[dim - 2].size = getDimValue(builder, loc, output, dim);
+    loopBounds[dim - 2].stride = oneAttr;
+  }
+  return loopBounds;
+}
+
+/// Returns the iterator type of each tiling loop. There are (output rank - 2)
+/// loops, the same count as the ranges built by getIterationDomain, and every
+/// loop is marked parallel.
+SmallVector<utils::IteratorType>
+WinogradInputTransformOp::getLoopIteratorTypes() {
+  int64_t outputRank = getOutputOperandRank();
+  SmallVector<utils::IteratorType> iteratorTypes(outputRank - 2,
+                                                 utils::IteratorType::parallel);
+  return iteratorTypes;
+}
+
+LogicalResult WinogradInputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+  ShapedType inputType = getInputOperandType();
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
----------------
Hsiangkai wrote:

Done.

https://github.com/llvm/llvm-project/pull/96184


More information about the Mlir-commits mailing list