[Mlir-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
Hsiangkai Wang
llvmlistbot at llvm.org
Wed Jul 17 22:56:50 PDT 2024
================
@@ -2810,9 +2819,117 @@ LogicalResult WinogradInputTransformOp::verify() {
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return emitOpError("the output shape is not expected");
}
+
return success();
}
+/// Returns the iteration domain used for tiling: six ranges, one per loop,
+/// each starting at offset 0 with stride 1. The size of range `dim` is
+/// whatever `getDimValue` yields for dimension `dim` of the op's output
+/// (presumably that dimension's extent -- confirm against the helper).
+SmallVector<Range>
+WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
+ constexpr unsigned kNumLoops = 6;
+ Location loc = getLoc();
+ auto indexType = builder.getIndexType();
+ auto zero = builder.getIntegerAttr(indexType, 0);
+ auto one = builder.getIntegerAttr(indexType, 1);
+ Value output = getOutput();
+
+ SmallVector<Range> domain;
+ domain.reserve(kNumLoops);
+ // Ranges are emitted in increasing dimension order so that any IR created
+ // by getDimValue appears in the same order as the loops.
+ for (unsigned dim = 0; dim != kNumLoops; ++dim)
+ domain.push_back(Range{zero, getDimValue(builder, loc, output, dim), one});
+ return domain;
+}
+
+/// Returns the iterator kind of every loop in the iteration domain: all six
+/// loops are reported as parallel.
+SmallVector<utils::IteratorType>
+WinogradInputTransformOp::getLoopIteratorTypes() {
+ constexpr unsigned kNumLoops = 6;
+ return SmallVector<utils::IteratorType>(kNumLoops,
+ utils::IteratorType::parallel);
+}
+
+/// Computes the offsets and sizes of the result tile that corresponds to the
+/// iteration-space tile described by `offsets` / `sizes`.
+///
+/// Six entries are appended to each of `resultOffsets` and `resultSizes`
+/// (existing contents are kept, not cleared):
+///   - offsets: 0, 0, offsets[2], offsets[3], 0, 0
+///   - sizes:   sizes[0], sizes[1], 1, 1, sizes[4], sizes[5]
+/// `resultNumber` is not consulted, and the function always returns success.
+///
+/// NOTE(review): offsets[0], offsets[1], offsets[4], offsets[5] and sizes[2],
+/// sizes[3] are ignored; this looks like it assumes tiling only along
+/// dimensions 2 and 3 with unit tile size -- confirm with the callers.
+LogicalResult WinogradInputTransformOp::getResultTilePosition(
+ OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+ SmallVector<OpFoldResult> &resultSizes) {
+ auto zero = builder.getI64IntegerAttr(0);
+ auto one = builder.getI64IntegerAttr(1);
+
+ // Only dimensions 2 and 3 take their offset from the incoming tile.
+ resultOffsets.append({zero, zero, offsets[2], offsets[3], zero, zero});
+ // Dimensions 2 and 3 are pinned to size one; the rest forward `sizes`.
+ resultSizes.append({sizes[0], sizes[1], one, one, sizes[4], sizes[5]});
+
+ return success();
+}
+
+FailureOr<TilingResult>
+WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) {
+ auto oneAttr = builder.getI64IntegerAttr(1);
+ auto zeroAttr = builder.getI64IntegerAttr(0);
+ Value input = getInput();
+ auto inputType = cast<ShapedType>(input.getType());
+ auto inputShape = inputType.getShape();
----------------
Hsiangkai wrote:
Done.
https://github.com/llvm/llvm-project/pull/96184
More information about the Mlir-commits
mailing list