[Mlir-commits] [mlir] [MLIR][Linalg] Fixes for Winograd decomposition and for tiling (PR #123675)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 20 15:19:21 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Dmitriy Smirnov (d-smirnov)
<details>
<summary>Changes</summary>
This PR fixes the Winograd decomposition for 1 x r and r x 1 filters, and fixes tiling of the Winograd transform ops.
---
Patch is 35.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123675.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td (+4-2)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+22-14)
- (modified) mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (+19-15)
- (modified) mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir (+107-6)
- (modified) mlir/test/Dialect/Linalg/transform-tile-winograd.mlir (+34-28)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index e42fd5d2ce13c1..f8df828f74851b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -155,7 +155,7 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
}
def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
- [AllElementTypesMatch<["filter", "output"]>,
+ [AllElementTypesMatch<["filter", "output"]>, DestinationStyleOpInterface,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
@@ -220,12 +220,13 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
int64_t getFilterCDim() {
return 3;
}
+ MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];
let hasVerifier = 1;
}
def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
- [AllElementTypesMatch<["input", "output"]>,
+ [AllElementTypesMatch<["input", "output"]>, DestinationStyleOpInterface,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
@@ -308,6 +309,7 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
int64_t getOutputCDim() {
return 5;
}
+ MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..0649fbc8e9549b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3063,8 +3063,11 @@ LogicalResult WinogradInputTransformOp::verify() {
int m = getM();
int r = getR();
int64_t tileSize = m + r - 1;
- bool leftTransform = inputH != 1;
- bool rightTransform = inputW != 1;
+
+ auto outputType = cast<ShapedType>(getOutput().getType());
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
+ bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
SmallVector<int64_t> expectedOutputShape(6, inputH);
if (ShapedType::isDynamic(inputH)) {
@@ -3073,7 +3076,7 @@ LogicalResult WinogradInputTransformOp::verify() {
} else {
expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
expectedOutputShape[getOutputTileHDim()] =
- leftTransform ? (inputH - (r - 1)) / m : 1;
+ leftTransform ? (inputH - (r - 1)) / m : inputH;
}
if (ShapedType::isDynamic(inputW)) {
expectedOutputShape[getOutputAlphaWDim()] = tileSize;
@@ -3081,13 +3084,11 @@ LogicalResult WinogradInputTransformOp::verify() {
} else {
expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
expectedOutputShape[getOutputTileWDim()] =
- rightTransform ? (inputW - (r - 1)) / m : 1;
+ rightTransform ? (inputW - (r - 1)) / m : inputW;
}
expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
- auto outputType = cast<ShapedType>(getOutput().getType());
- ArrayRef<int64_t> outputShape = outputType.getShape();
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return emitOpError("the output shape is not expected");
}
@@ -3124,15 +3125,17 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) {
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
- ShapedType inputType = getInputOperandType();
- ArrayRef<int64_t> inputShape = inputType.getShape();
- int64_t inputH = inputShape[getInputHDim()];
- int64_t inputW = inputShape[getInputWDim()];
+ ShapedType outputType = getOutputOperandType();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
+ int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
+
int64_t m = getM();
int64_t r = getR();
int64_t alpha = m + r - 1;
- int64_t alphaH = inputH != 1 ? alpha : 1;
- int64_t alphaW = inputW != 1 ? alpha : 1;
+ int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
+ int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
+
IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
@@ -3165,6 +3168,11 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
int64_t m = getM();
int64_t r = getR();
+ ShapedType outputType = getOutputOperandType();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ int64_t alphaH = outputShape[getOutputAlphaHDim()];
+ int64_t alphaW = outputShape[getOutputAlphaWDim()];
+
Location loc = getLoc();
MLIRContext *context = builder.getContext();
auto offsetAffineMap =
@@ -3190,9 +3198,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
sliceOffsets.append(
{offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
OpFoldResult sizeH =
- inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
+ alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
OpFoldResult sizeW =
- inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
+ alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
sliceSizes.append(
{sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
int64_t inputRank = getInputOperandRank();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 79f77822116fd7..f1059ddf5da2cf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -514,12 +514,14 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
Value CIter = ivs[3];
auto context = builder.getContext();
+
+ auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
auto affineMap =
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
- Value heightOffset =
- builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
- Value widthOffset =
- builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
+ Value heightOffset = builder.create<affine::AffineApplyOp>(
+ loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
+ Value widthOffset = builder.create<affine::AffineApplyOp>(
+ loc, rightTransform ? affineMap : identityAffineMap, tileWIter);
// Extract (H, W) from (N, H, W, C).
auto extractInput =
@@ -753,12 +755,13 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
Value zero = builder.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementType));
+ auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
auto affineMap =
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
- Value heightOffset =
- builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
- Value widthOffset =
- builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
+ Value heightOffset = builder.create<affine::AffineApplyOp>(
+ loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
+ Value widthOffset = builder.create<affine::AffineApplyOp>(
+ loc, rightTransform ? affineMap : identityAffineMap, tileWIter);
Value outInitVal =
extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
@@ -1075,16 +1078,17 @@ FailureOr<Operation *>
decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
linalg::WinogradInputTransformOp op) {
Location loc = op.getLoc();
- Value input = op.getInput();
- auto inputType = cast<ShapedType>(input.getType());
- auto inputShape = inputType.getShape();
- int64_t inputH = inputShape[1];
- int64_t inputW = inputShape[2];
+ Value output = op.getOutput();
+ auto outputType = cast<ShapedType>(output.getType());
+ auto outputShape = outputType.getShape();
+
+ int64_t outputH = outputShape[0];
+ int64_t outputW = outputShape[1];
// For F(m x 1, r x 1), we only need to do left side transform.
- bool leftTransform = inputH != 1;
+ bool leftTransform = outputH != 1;
// For F(1 x m, 1 x r), we only need to do right side transform.
- bool rightTransform = inputW != 1;
+ bool rightTransform = outputW != 1;
Value transformedInput =
inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
op.getR(), leftTransform, rightTransform);
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 776dc5b748c846..a9af874ea40933 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -61,13 +61,12 @@ module attributes {transform.with_named_sequence} {
// CHECK: scf.yield %[[INSERTED_SLICE]]
// CHECK: scf.yield %[[S9]]
// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK: %[[S3:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK: %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]])
+// CHECK: %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
// CHECK: %[[S13:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
@@ -195,13 +194,12 @@ module attributes {transform.with_named_sequence} {
// CHECK: scf.yield %[[S9]] : tensor<6x6x5x2xf32>
// CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0]
// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK: %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK: %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]])
+// CHECK: %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S2]])
// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
// CHECK: %[[S13:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 6, 6, 1] [1, 1, 1, 1]
@@ -346,3 +344,106 @@ module attributes {transform.with_named_sequence} {
// CHECK: scf.yield %[[INSERTED_SLICE]]
// CHECK: scf.yield %[[S7]]
// CHECK: return %[[S6]]
+
+// -----
+
+func.func @conv2d_mx1_rx1_2(%arg0: tensor<2x6x2x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<6x1x5x2xf32>
+ %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+ %2 = tensor.empty() : tensor<6x1x1x2x2x5xf32>
+ %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x2x5xf32>) outs(%2 : tensor<6x1x1x2x2x5xf32>) -> tensor<6x1x1x2x2x5xf32>
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+ %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x2x2x5xf32> into tensor<6x4x5xf32>
+ %4 = tensor.empty() : tensor<6x4x2xf32>
+ %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+ %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x4x5xf32>, tensor<6x5x2xf32>) outs(%5 : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+ %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 2, 2, 2] : tensor<6x4x2xf32> into tensor<6x1x1x2x2x2xf32>
+ %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x2x2x2xf32>) outs(%arg2 : tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32>
+ return %7 : tensor<2x4x2x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+ %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+ %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+ %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+ %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d_mx1_rx1
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x2x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32> {
+// CHECK: %[[CST:.*]] = arith.constant 3.200000e+01 : f32
+// CHECK: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<4x6xf32>
+// CHECK: %[[CST_1:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<{{.*}}> : tensor<6x3xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[CST_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
+// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1]
+// CHECK: %[[S8:.*]] = tensor.empty() : tensor<6x1xf32>
+// CHECK: %[[S9:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S8]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK: %[[S10:.*]] = linalg.matmul ins(%[[CST_2]], %[[EXTRACTED_SLICE]] : tensor<6x3xf32>, tensor<3x1xf32>) outs(%[[S9]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S7]]
+// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x1x1x2x2x5xf32>
+// CHECK: %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
+// CHECK: %[[S8:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %8, 0] [2, 6, 1, 5] [1, 1, 1, 1]
+// CHECK: %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG4]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[EXTRACTED_SLICE_5]])
+// CHECK: %[[S10:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]])
+// CHECK: %[[EXTRACTED_SLICE_6:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG5]], 0, 0, %[[ARG7]]] [1, 6, 1, 1] [1, 1, 1, 1]
+// CHECK: %[[S11:.*]] = tensor.empty() : tensor<6x1xf32>
+// CHECK: %[[S12:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S11]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK: %[[S13:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_6]] : tensor<6x6xf32>, tensor<6x1xf32>) outs(%[[S12]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK: %[[INSERTED_SLICE_7:.*]] = tensor.insert_slice %[[S13]] into %[[ARG8]][0, 0, 0, 0, %[[ARG5]], %[[ARG7]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE_7]]
+// CHECK: scf.yield %[[S10]]
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG4]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK: %[[COLLAPSED_4:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK: %[[S4:.*]] = tensor.empty() : tensor<6x4x2xf32>
+// CHECK: %[[S5:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S4]] : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+// CHECK: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_4]], %[[COLLAPSED]] : tensor<6x4x5xf32>, tensor<6x5x2xf32>) outs(%[[S5]] : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 2, 2, 2]
+// CHECK: %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG4]][0, 0, 0, 0] [2, 4, 1, 2] [1, 1, 1, 1]
+// CHECK: %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[EXTRACTED_SLICE_5]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]])
+// CHECK: %[[EXTRACTED_SLICE_6:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/123675
More information about the Mlir-commits mailing list