[Mlir-commits] [mlir] [mlir][linalg] Implement Winograd Conv2D. (PR #94470)
Hsiangkai Wang
llvmlistbot at llvm.org
Thu Jun 20 02:56:03 PDT 2024
================
@@ -2737,6 +2737,336 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
return SmallVector<Value>{result};
}
+//===----------------------------------------------------------------------===//
+// WinogradFilterTransformOp
+//===----------------------------------------------------------------------===//
+/// Returns the four-loop iteration domain of the filter transform. All loops
+/// start at 0 with unit stride. Loops 0 and 1 iterate on the output domain,
+/// sized by getOutputHeight()/getOutputWidth(). Loops 2 and 3 span
+/// dimensions 2 and 3 of the output operand.
+SmallVector<Range>
+WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  OpFoldResult zero = builder.getIndexAttr(0);
+  OpFoldResult one = builder.getIndexAttr(1);
+  Value output = getOutput();
+
+  SmallVector<Range> loopBounds(4);
+  // Sizes of loops 0 and 1 come from the op itself. Do not emit dim queries
+  // for those output dimensions, since they would only be discarded.
+  loopBounds[0] = Range{zero, builder.getIndexAttr(getOutputHeight()), one};
+  loopBounds[1] = Range{zero, builder.getIndexAttr(getOutputWidth()), one};
+  for (int64_t dim = 2; dim < 4; ++dim)
+    loopBounds[dim] = Range{zero, getDimValue(builder, loc, output, dim), one};
+  return loopBounds;
+}
+
+/// Reports the iterator type of each of the four loops of the iteration
+/// domain: every loop is parallel, none is a reduction.
+SmallVector<utils::IteratorType>
+WinogradFilterTransformOp::getLoopIteratorTypes() {
+  return SmallVector<utils::IteratorType>(/*Size=*/4,
+                                          utils::IteratorType::parallel);
+}
+
+/// Materializes `opFoldResult` as an SSA value.
+/// - A Value is returned as-is.
+/// - An integer attribute becomes an `arith.constant` of `index` type, so
+///   the result can be used as an `affine.apply` operand.
+/// - A null OpFoldResult yields a null Value.
+static Value getValueFromOpFoldResult(OpFoldResult opFoldResult,
+                                      OpBuilder &builder, Location loc) {
+  if (auto val = opFoldResult.dyn_cast<Value>())
+    return val;
+  if (auto attr = opFoldResult.dyn_cast<Attribute>()) {
+    // The attribute may have any integer type (e.g. i64). Normalize to index
+    // instead of creating a constant of the attribute's own type.
+    APInt value = cast<IntegerAttr>(attr).getValue();
+    return builder.create<arith::ConstantIndexOp>(loc, value.getSExtValue());
+  }
+  // Only reachable when `opFoldResult` is null.
+  return nullptr;
+}
+
+/// Computes the result slice produced for the iteration-space tile at
+/// `offsets`/`sizes`, appending it to `resultOffsets`/`resultSizes`.
+/// - Dimensions 0 and 1: the offset is mapped to
+///   `(offset floordiv m) * alpha`, with `alpha = m + r - 1`. The size is
+///   `alpha`, or 1 when the matching filter dimension (1 or 2) is 1.
+/// - Dimensions 2 and 3: the offset is 0 and the size is `sizes[2]` /
+///   `sizes[3]`.
+LogicalResult WinogradFilterTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  ArrayRef<int64_t> filterShape =
+      cast<ShapedType>(getFilter().getType()).getShape();
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  int64_t alphaH = filterShape[1] != 1 ? alpha : 1;
+  int64_t alphaW = filterShape[2] != 1 ? alpha : 1;
+
+  AffineExpr d0 = builder.getAffineDimExpr(0);
+  AffineMap offsetMap =
+      AffineMap::get(1, 0, {d0.floorDiv(m) * alpha}, builder.getContext());
+
+  // Apply the map on the OpFoldResults directly. Attribute offsets are never
+  // materialized as (possibly non-index) constants, every affine.apply
+  // operand is index-typed, and constant offsets fold to attributes.
+  Location loc = getLoc();
+  OpFoldResult zero = builder.getIndexAttr(0);
+  resultOffsets.push_back(affine::makeComposedFoldedAffineApply(
+      builder, loc, offsetMap, {offsets[0]}));
+  resultOffsets.push_back(affine::makeComposedFoldedAffineApply(
+      builder, loc, offsetMap, {offsets[1]}));
+  resultOffsets.push_back(zero);
+  resultOffsets.push_back(zero);
+  resultSizes.push_back(builder.getIndexAttr(alphaH));
+  resultSizes.push_back(builder.getIndexAttr(alphaW));
+  resultSizes.push_back(sizes[2]);
+  resultSizes.push_back(sizes[3]);
+  return success();
+}
+
+/// Builds the tiled filter transform for the iteration-space tile at
+/// `offsets`/`sizes`:
+/// - the filter operand is used whole;
+/// - the output operand is sliced to the position from getResultTilePosition;
+/// - the op is cloned onto these operands, with the slice type as its single
+///   result type.
+FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
+    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+  Location loc = getLoc();
+
+  // The op has exactly one result, so query the tile position of result 0.
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+  if (failed(getResultTilePosition(builder, /*resultNumber=*/0, offsets, sizes,
+                                   sliceOffsets, sliceSizes)))
+    return failure();
+
+  SmallVector<OpFoldResult> strides(4, builder.getIndexAttr(1));
+  SmallVector<Value> tiledOperands;
+  tiledOperands.emplace_back(getFilter());
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), sliceOffsets, sliceSizes, strides));
+
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradInputTransformOp
+//===----------------------------------------------------------------------===//
+/// Returns the four-loop iteration domain of the input transform. All loops
+/// start at 0 with unit stride. Loops 0 and 1 are sized by
+/// getOutputHeight()/getOutputWidth(). Loops 2 and 3 span dimensions 2 and
+/// 3 of the output operand.
+SmallVector<Range>
+WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  OpFoldResult zero = builder.getIndexAttr(0);
+  OpFoldResult one = builder.getIndexAttr(1);
+  Value output = getOutput();
+
+  SmallVector<Range> loopBounds(4);
+  // Sizes of loops 0 and 1 come from the op itself. Do not emit dim queries
+  // for those output dimensions, since they would only be discarded.
+  loopBounds[0] = Range{zero, builder.getIndexAttr(getOutputHeight()), one};
+  loopBounds[1] = Range{zero, builder.getIndexAttr(getOutputWidth()), one};
+  for (int64_t dim = 2; dim < 4; ++dim)
+    loopBounds[dim] = Range{zero, getDimValue(builder, loc, output, dim), one};
+  return loopBounds;
+}
+
+/// Reports the iterator type of each of the four loops of the iteration
+/// domain: every loop is parallel, none is a reduction.
+SmallVector<utils::IteratorType>
+WinogradInputTransformOp::getLoopIteratorTypes() {
+  return SmallVector<utils::IteratorType>(/*Size=*/4,
+                                          utils::IteratorType::parallel);
+}
+
+/// Computes the result slice produced for the iteration-space tile at
+/// `offsets`/`sizes`, appending it to `resultOffsets`/`resultSizes`.
+/// - Dimensions 0 and 1: the offset is mapped to
+///   `(offset floordiv m) * alpha`, with `alpha = m + r - 1`. The size is
+///   `alpha`, or 1 when the matching input dimension (1 or 2) is 1.
+/// - Dimensions 2 and 3: the offset is 0 and the size is `sizes[2]` /
+///   `sizes[3]`.
+LogicalResult WinogradInputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  ArrayRef<int64_t> inputShape =
+      cast<ShapedType>(getInput().getType()).getShape();
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  int64_t alphaH = inputShape[1] != 1 ? alpha : 1;
+  int64_t alphaW = inputShape[2] != 1 ? alpha : 1;
+
+  AffineExpr d0 = builder.getAffineDimExpr(0);
+  AffineMap offsetMap =
+      AffineMap::get(1, 0, {d0.floorDiv(m) * alpha}, builder.getContext());
+
+  // Apply the map on the OpFoldResults directly. Attribute offsets are never
+  // materialized as (possibly non-index) constants, every affine.apply
+  // operand is index-typed, and constant offsets fold to attributes.
+  Location loc = getLoc();
+  OpFoldResult zero = builder.getIndexAttr(0);
+  resultOffsets.push_back(affine::makeComposedFoldedAffineApply(
+      builder, loc, offsetMap, {offsets[0]}));
+  resultOffsets.push_back(affine::makeComposedFoldedAffineApply(
+      builder, loc, offsetMap, {offsets[1]}));
+  resultOffsets.push_back(zero);
+  resultOffsets.push_back(zero);
+  resultSizes.push_back(builder.getIndexAttr(alphaH));
+  resultSizes.push_back(builder.getIndexAttr(alphaW));
+  resultSizes.push_back(sizes[2]);
+  resultSizes.push_back(sizes[3]);
+  return success();
+}
+
+FailureOr<TilingResult>
+WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) {
+ auto oneAttr = builder.getI64IntegerAttr(1);
+ auto zeroAttr = builder.getI64IntegerAttr(0);
+ Value input = getInput();
+ auto inputType = cast<ShapedType>(input.getType());
+ auto inputShape = inputType.getShape();
+ int64_t inputH = inputShape[1];
+ int64_t inputW = inputShape[2];
+ int64_t m = getM();
+ int64_t r = getR();
+ int64_t alpha = m + r - 1;
+ int64_t alphaH = inputH != 1 ? alpha : 1;
+ int64_t alphaW = inputW != 1 ? alpha : 1;
+ auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
+ auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
+
+ Location loc = getLoc();
+ SmallVector<OpFoldResult> strides(4, oneAttr);
+ SmallVector<Value> tiledOperands;
+ SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+
+ sliceOffsets.push_back(zeroAttr);
+ sliceOffsets.push_back(offsets[0]);
+ sliceOffsets.push_back(offsets[1]);
+ sliceOffsets.push_back(zeroAttr);
+ sliceSizes.push_back(sizes[2]);
+ sliceSizes.push_back(alphaHAttr);
+ sliceSizes.push_back(alphaWAttr);
+ sliceSizes.push_back(sizes[3]);
----------------
Hsiangkai wrote:
I updated my patch to use the number of tiles as additional dimensions.
https://github.com/llvm/llvm-project/pull/94470
More information about the Mlir-commits
mailing list