[llvm-branch-commits] [mlir] [mlir][linalg] Implement TilingInterface for winograd operators (PR #96184)

Oleksandr Alex Zinenko via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jul 11 05:00:52 PDT 2024


================
@@ -2810,9 +2819,117 @@ LogicalResult WinogradInputTransformOp::verify() {
   if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
     return emitOpError("the output shape is not expected");
   }
+
   return success();
 }
 
+/// Builds the iteration domain of the op: one range per dimension of the
+/// output value, for six dimensions in total. Every range starts at offset 0,
+/// has stride 1, and has the size reported by `getDimValue` for the matching
+/// output dimension.
+SmallVector<Range>
+WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
+  // Number of loops in the domain; the output is indexed with dims 0..5.
+  constexpr unsigned kNumLoops = 6;
+
+  Location loc = getLoc();
+  IntegerAttr zero = builder.getIntegerAttr(builder.getIndexType(), 0);
+  IntegerAttr one = builder.getIntegerAttr(builder.getIndexType(), 1);
+  Value output = getOutput();
+
+  SmallVector<Range> domain;
+  domain.reserve(kNumLoops);
+  for (unsigned dim = 0; dim < kNumLoops; ++dim) {
+    Range &bound = domain.emplace_back();
+    bound.offset = zero;
+    bound.size = getDimValue(builder, loc, output, dim);
+    bound.stride = one;
+  }
+  return domain;
+}
+
+/// Reports the iterator kind of each loop: six entries, all of them
+/// `utils::IteratorType::parallel`.
+SmallVector<utils::IteratorType>
+WinogradInputTransformOp::getLoopIteratorTypes() {
+  return SmallVector<utils::IteratorType>(6, utils::IteratorType::parallel);
+}
+
+/// Appends the offsets and sizes of the result tile that corresponds to the
+/// iteration-space tile described by `offsets` / `sizes`.
+///
+/// Six entries are appended to each of `resultOffsets` and `resultSizes`:
+///   - offsets: 0, 0, offsets[2], offsets[3], 0, 0
+///   - sizes:   sizes[0], sizes[1], 1, 1, sizes[4], sizes[5]
+/// `resultNumber` is not consulted. Always returns success.
+LogicalResult WinogradInputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  IntegerAttr zero = builder.getI64IntegerAttr(0);
+  IntegerAttr one = builder.getI64IntegerAttr(1);
+
+  // NOTE(review): the incoming offsets for dims 0, 1, 4, 5 and the incoming
+  // sizes for dims 2, 3 are ignored here -- presumably only dims 2 and 3 are
+  // expected to be tiled, one element at a time; confirm against callers.
+  resultOffsets.append({zero, zero, offsets[2], offsets[3], zero, zero});
+  resultSizes.append({sizes[0], sizes[1], one, one, sizes[4], sizes[5]});
+
+  return success();
+}
+
+FailureOr<TilingResult>
+WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
+                                                 ArrayRef<OpFoldResult> offsets,
+                                                 ArrayRef<OpFoldResult> sizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+  auto zeroAttr = builder.getI64IntegerAttr(0);
+  Value input = getInput();
+  auto inputType = cast<ShapedType>(input.getType());
+  auto inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  int64_t alphaH = inputH != 1 ? alpha : 1;
+  int64_t alphaW = inputW != 1 ? alpha : 1;
+  auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
+
+  Location loc = getLoc();
+  SmallVector<Value> tiledOperands;
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+
+  auto context = builder.getContext();
+  auto affineMap =
+      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
+  Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
+  Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
+      loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
+
+  sliceOffsets.push_back(zeroAttr);
+  sliceOffsets.push_back(mappedOffset1);
+  sliceOffsets.push_back(mappedOffset2);
+  sliceOffsets.push_back(zeroAttr);
+  sliceSizes.push_back(sizes[4]);
+  sliceSizes.push_back(alphaHAttr);
+  sliceSizes.push_back(alphaWAttr);
+  sliceSizes.push_back(sizes[5]);
+  SmallVector<OpFoldResult> inputStrides(4, oneAttr);
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
+
+  sliceOffsets.clear();
+  sliceSizes.clear();
----------------
ftynse wrote:

I'd rather declare new vectors for this.

https://github.com/llvm/llvm-project/pull/96184


More information about the llvm-branch-commits mailing list