[Mlir-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 19 08:02:23 PDT 2024


================
@@ -2855,6 +2949,92 @@ LogicalResult WinogradOutputTransformOp::verify() {
   return success();
 }
 
+SmallVector<Range>
+WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  IndexType indexType = builder.getIndexType();
+  IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
+  IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
+  Value value = getValue();
+  SmallVector<Range> loopBounds(6);
+  for (unsigned dim = 0; dim < 6; ++dim) {
+    loopBounds[dim].offset = zeroAttr;
+    loopBounds[dim].size = getDimValue(builder, loc, value, dim);
+    loopBounds[dim].stride = oneAttr;
+  }
+  return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradOutputTransformOp::getLoopIteratorTypes() {
+  SmallVector<utils::IteratorType> iteratorTypes(6,
+                                                 utils::IteratorType::parallel);
+  return iteratorTypes;
+}
+
+LogicalResult WinogradOutputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+  Value output = getOutput();
+  auto outputType = cast<ShapedType>(output.getType());
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t outputH = outputShape[1];
+  int64_t outputW = outputShape[2];
+  int64_t m = getM();
+  IntegerAttr heightM = builder.getI64IntegerAttr(outputH != 1 ? m : 1);
+  IntegerAttr widthM = builder.getI64IntegerAttr(outputW != 1 ? m : 1);
+
+  Location loc = getLoc();
+  MLIRContext *context = builder.getContext();
+  auto affineMap =
+      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+  Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[2]));
+  Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap,
+      getValueOrCreateConstantIndexOp(builder, loc, offsets[3]));
+
+  resultOffsets.append({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
+  resultSizes.append({sizes[4], heightM, widthM, sizes[5]});
+  return success();
+}
+
+FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
+    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
+  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
+  Location loc = getLoc();
+  SmallVector<Value> tiledOperands;
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+
+  sliceOffsets.append(
+      {zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr});
+  sliceSizes.append({sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});
+  SmallVector<OpFoldResult> sliceStrides(6, oneAttr);
----------------
Max191 wrote:

Use the input rank instead of the hard-coded `6` here (the size passed to `sliceStrides`).

https://github.com/llvm/llvm-project/pull/96184


More information about the Mlir-commits mailing list