[Mlir-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)

Hsiangkai Wang llvmlistbot at llvm.org
Wed Aug 14 05:47:11 PDT 2024


================
@@ -2922,6 +3021,129 @@ LogicalResult WinogradInputTransformOp::verify() {
   return success();
 }
 
+SmallVector<Range>
+WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
+  // One loop per output dimension, skipping the leading two output dims.
+  // Each loop runs over [0, extent of that output dim) with unit stride.
+  Location loc = getLoc();
+  Value output = getOutput();
+  IntegerAttr zeroAttr = builder.getIndexAttr(0);
+  IntegerAttr oneAttr = builder.getIndexAttr(1);
+  SmallVector<Range> domain;
+  for (int64_t dim = 2, rank = getOutputOperandRank(); dim < rank; ++dim)
+    domain.push_back(
+        Range{zeroAttr, getDimValue(builder, loc, output, dim), oneAttr});
+  return domain;
+}
+
+SmallVector<utils::IteratorType>
+WinogradInputTransformOp::getLoopIteratorTypes() {
+  // Matches getIterationDomain: one loop per output dim after the first two,
+  // and every one of those loops is parallel.
+  return SmallVector<utils::IteratorType>(getOutputOperandRank() - 2,
+                                          utils::IteratorType::parallel);
+}
+
+LogicalResult WinogradInputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  // The result tile starts at 0 in the two leading result dims and spans a
+  // fixed alpha extent there; the four trailing result dims take the
+  // requested tile offsets/sizes unchanged.
+  const int64_t m = getM();
+  const int64_t r = getR();
+  const int64_t alpha = m + r - 1;
+  // An input spatial dim of extent 1 yields an alpha extent of 1 instead of
+  // m + r - 1.
+  auto alphaExtentFor = [alpha](int64_t inputExtent) -> int64_t {
+    return inputExtent == 1 ? 1 : alpha;
+  };
+
+  // Input layout is (N, H, W, C): indices 1 and 2 are H and W.
+  ArrayRef<int64_t> inputShape = getInputOperandType().getShape();
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+  IntegerAttr alphaHAttr =
+      builder.getI64IntegerAttr(alphaExtentFor(inputShape[1]));
+  IntegerAttr alphaWAttr =
+      builder.getI64IntegerAttr(alphaExtentFor(inputShape[2]));
+
+  resultOffsets.append(
+      {zeroAttr, zeroAttr, offsets[0], offsets[1], offsets[2], offsets[3]});
+  resultSizes.append(
+      {alphaHAttr, alphaWAttr, sizes[0], sizes[1], sizes[2], sizes[3]});
+  return success();
+}
+
+/// Implement tiling for winograd_input_transform.
+/// The input of winograd_input_transform is (N, H, W, C).
+/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N,
+/// C). Users can specify the tile sizes of tileH, tileW, N, and C. `offsets`
+/// are the values for the offsets of tileH, tileW, N, C for one tile. `sizes`
+/// are the values for the sizes of tileH, tileW, N, C for one tile.
+FailureOr<TilingResult>
+WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
+                                                 ArrayRef<OpFoldResult> offsets,
+                                                 ArrayRef<OpFoldResult> sizes) {
+  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+  ShapedType inputType = getInputOperandType();
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t m = getM();
+  int64_t r = getR();
+
+  Location loc = getLoc();
+  MLIRContext *context = builder.getContext();
+  auto offsetAffineMap =
+      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+  Value mappedOffsetH = builder.create<affine::AffineApplyOp>(
+      loc, offsetAffineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[0]));
+  Value mappedOffsetW = builder.create<affine::AffineApplyOp>(
+      loc, offsetAffineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[1]));
+  auto sizeAffineMap = AffineMap::get(
+      1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
+  Value mappedSizeH = builder.create<affine::AffineApplyOp>(
+      loc, sizeAffineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, sizes[0]));
+  Value mappedSizeW = builder.create<affine::AffineApplyOp>(
+      loc, sizeAffineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, sizes[1]));
----------------
Hsiangkai wrote:

Done.

https://github.com/llvm/llvm-project/pull/96184


More information about the Mlir-commits mailing list