[Mlir-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jul 19 08:02:21 PDT 2024
================
@@ -2813,6 +2813,100 @@ LogicalResult WinogradInputTransformOp::verify() {
return success();
}
+/// Builds the iteration domain of the op: one loop per dimension of the
+/// output tensor, each with offset 0, size equal to the output's extent in
+/// that dimension, and unit stride.
+SmallVector<Range>
+WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  IntegerAttr zeroAttr = builder.getIndexAttr(0);
+  IntegerAttr oneAttr = builder.getIndexAttr(1);
+  Value output = getOutput();
+  // Derive the loop count from the output type instead of a hard-coded
+  // literal so the domain always matches the output's rank.
+  int64_t outputRank = cast<ShapedType>(output.getType()).getRank();
+  SmallVector<Range> loopBounds(outputRank);
+  for (int64_t dim = 0; dim < outputRank; ++dim) {
+    loopBounds[dim].offset = zeroAttr;
+    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
+    loopBounds[dim].stride = oneAttr;
+  }
+  return loopBounds;
+}
+
+/// Returns the iterator type of each loop in the iteration domain: one
+/// `parallel` entry per output dimension.
+SmallVector<utils::IteratorType>
+WinogradInputTransformOp::getLoopIteratorTypes() {
+  // Use the output rank rather than a hard-coded count so this stays in
+  // agreement with the output type.
+  int64_t outputRank = cast<ShapedType>(getOutput().getType()).getRank();
+  return SmallVector<utils::IteratorType>(outputRank,
+                                          utils::IteratorType::parallel);
+}
+
+/// Computes the offsets and sizes in the result that correspond to the
+/// iteration-space tile described by `offsets` / `sizes`.
+///
+/// Iteration dimension i maps to result dimension i:
+///  - dims 0 and 1 always start at 0 and take `sizes[0]` / `sizes[1]`;
+///  - dims 2 and 3 take `offsets[2]` / `offsets[3]` with size 1, i.e. one
+///    tile per tiled op (presumably matching the single alphaH x alphaW
+///    input window extracted in getTiledImplementation -- confirm);
+///  - dims 4 and 5 follow the iteration offsets and sizes directly.
+/// `resultNumber` is not consulted. Always succeeds.
+LogicalResult WinogradInputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+
+  // Dims 4 and 5 are parallel and tileable, so their result offsets must
+  // follow the iteration offsets. Pinning them to 0 would make every tile
+  // along those dims overwrite the same leading slice of the result.
+  // NOTE(review): the input slice built in getTiledImplementation should
+  // likewise use offsets[4] / offsets[5] for its outermost/innermost dims.
+  resultOffsets.append(
+      {zeroAttr, zeroAttr, offsets[2], offsets[3], offsets[4], offsets[5]});
+  resultSizes.append(
+      {sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});
+
+  return success();
+}
+
+FailureOr<TilingResult>
+WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) {
+ IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+ IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+ Value input = getInput();
+ auto inputType = cast<ShapedType>(input.getType());
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+ int64_t inputH = inputShape[1];
+ int64_t inputW = inputShape[2];
+ int64_t m = getM();
+ int64_t r = getR();
+ int64_t alpha = m + r - 1;
+ int64_t alphaH = inputH != 1 ? alpha : 1;
+ int64_t alphaW = inputW != 1 ? alpha : 1;
+ IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
+ IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
+
+ Location loc = getLoc();
+ SmallVector<Value> tiledOperands;
+ SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+
+ MLIRContext *context = builder.getContext();
+ auto affineMap =
+ AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+ Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
+ loc, affineMap,
+ getValueOrCreateConstantIndexOp(builder, loc, offsets[2]));
+ Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
+ loc, affineMap,
+ getValueOrCreateConstantIndexOp(builder, loc, offsets[3]));
+
+ sliceOffsets.append({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
+ sliceSizes.append({sizes[4], alphaHAttr, alphaWAttr, sizes[5]});
+ SmallVector<OpFoldResult> inputStrides(4, oneAttr);
+ tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+ loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
+
+ SmallVector<OpFoldResult> resultOffsets, resultSizes;
+ if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
+ resultSizes)))
+ return failure();
+
+ SmallVector<OpFoldResult> outputStrides(6, oneAttr);
----------------
Max191 wrote:
Replace `6` with output rank
https://github.com/llvm/llvm-project/pull/96184
More information about the Mlir-commits mailing list