[Mlir-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Aug 6 08:06:31 PDT 2024
================
@@ -2922,6 +3021,129 @@ LogicalResult WinogradInputTransformOp::verify() {
return success();
}
+SmallVector<Range>
+WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
+ Location loc = getLoc();
+ IntegerAttr zeroAttr = builder.getIndexAttr(0);
+ IntegerAttr oneAttr = builder.getIndexAttr(1);
+ Value output = getOutput();
+ int64_t outputRank = getOutputOperandRank();
+ SmallVector<Range> loopBounds(outputRank - 2);
+ for (unsigned dim = 2; dim < outputRank; ++dim) {
+ loopBounds[dim - 2].offset = zeroAttr;
+ loopBounds[dim - 2].size = getDimValue(builder, loc, output, dim);
+ loopBounds[dim - 2].stride = oneAttr;
+ }
+ return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradInputTransformOp::getLoopIteratorTypes() {
+ int64_t outputRank = getOutputOperandRank();
+ SmallVector<utils::IteratorType> iteratorTypes(outputRank - 2,
+ utils::IteratorType::parallel);
+ return iteratorTypes;
+}
+
+LogicalResult WinogradInputTransformOp::getResultTilePosition(
+ OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+ SmallVector<OpFoldResult> &resultSizes) {
+ IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+ ShapedType inputType = getInputOperandType();
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+ int64_t inputH = inputShape[1];
+ int64_t inputW = inputShape[2];
----------------
Max191 wrote:
If this could use accessor methods declared in `extraClassDeclaration` instead of the literal indices `1` and `2`, it would be easier to see what is going on. This is one of the issues with the IREE implementation that we can hopefully fix here before landing. It would also make it easier to extend to other layouts such as NCHW, since that would only require updating the `extraClassDeclaration` rather than adding conditionals around every magic number.
https://github.com/llvm/llvm-project/pull/96184
More information about the Mlir-commits
mailing list