[Mlir-commits] [mlir] [MLIR][Linalg] Fixes for Winograd decomposition and for tiling (PR #123675)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 20 15:19:21 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Dmitriy Smirnov (d-smirnov)

<details>
<summary>Changes</summary>

This PR fixes issues in the Winograd decomposition for 1 x r and r x 1 filters, and in the tiling of the Winograd transform ops.

---

Patch is 35.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123675.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td (+4-2) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+22-14) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (+19-15) 
- (modified) mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir (+107-6) 
- (modified) mlir/test/Dialect/Linalg/transform-tile-winograd.mlir (+34-28) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index e42fd5d2ce13c1..f8df828f74851b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -155,7 +155,7 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
 }
 
 def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
-    [AllElementTypesMatch<["filter", "output"]>,
+    [AllElementTypesMatch<["filter", "output"]>, DestinationStyleOpInterface,
      DeclareOpInterfaceMethods<TilingInterface,
       ["getIterationDomain",
        "getLoopIteratorTypes",
@@ -220,12 +220,13 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
     int64_t getFilterCDim() {
       return 3;
     }
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
   }];
   let hasVerifier = 1;
 }
 
 def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
-    [AllElementTypesMatch<["input", "output"]>,
+    [AllElementTypesMatch<["input", "output"]>, DestinationStyleOpInterface,
      DeclareOpInterfaceMethods<TilingInterface,
       ["getIterationDomain",
        "getLoopIteratorTypes",
@@ -308,6 +309,7 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
     int64_t getOutputCDim() {
       return 5;
     }
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..0649fbc8e9549b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3063,8 +3063,11 @@ LogicalResult WinogradInputTransformOp::verify() {
   int m = getM();
   int r = getR();
   int64_t tileSize = m + r - 1;
-  bool leftTransform = inputH != 1;
-  bool rightTransform = inputW != 1;
+
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
+  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
 
   SmallVector<int64_t> expectedOutputShape(6, inputH);
   if (ShapedType::isDynamic(inputH)) {
@@ -3073,7 +3076,7 @@ LogicalResult WinogradInputTransformOp::verify() {
   } else {
     expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
     expectedOutputShape[getOutputTileHDim()] =
-        leftTransform ? (inputH - (r - 1)) / m : 1;
+        leftTransform ? (inputH - (r - 1)) / m : inputH;
   }
   if (ShapedType::isDynamic(inputW)) {
     expectedOutputShape[getOutputAlphaWDim()] = tileSize;
@@ -3081,13 +3084,11 @@ LogicalResult WinogradInputTransformOp::verify() {
   } else {
     expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
     expectedOutputShape[getOutputTileWDim()] =
-        rightTransform ? (inputW - (r - 1)) / m : 1;
+        rightTransform ? (inputW - (r - 1)) / m : inputW;
   }
   expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
   expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
 
-  auto outputType = cast<ShapedType>(getOutput().getType());
-  ArrayRef<int64_t> outputShape = outputType.getShape();
   if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
     return emitOpError("the output shape is not expected");
   }
@@ -3124,15 +3125,17 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
     SmallVector<OpFoldResult> &resultSizes) {
   IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
-  ShapedType inputType = getInputOperandType();
-  ArrayRef<int64_t> inputShape = inputType.getShape();
-  int64_t inputH = inputShape[getInputHDim()];
-  int64_t inputW = inputShape[getInputWDim()];
+  ShapedType outputType = getOutputOperandType();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
+  int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
+
   int64_t m = getM();
   int64_t r = getR();
   int64_t alpha = m + r - 1;
-  int64_t alphaH = inputH != 1 ? alpha : 1;
-  int64_t alphaW = inputW != 1 ? alpha : 1;
+  int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
+  int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
+
   IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
   IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
 
@@ -3165,6 +3168,11 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
   int64_t m = getM();
   int64_t r = getR();
 
+  ShapedType outputType = getOutputOperandType();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t alphaH = outputShape[getOutputAlphaHDim()];
+  int64_t alphaW = outputShape[getOutputAlphaWDim()];
+
   Location loc = getLoc();
   MLIRContext *context = builder.getContext();
   auto offsetAffineMap =
@@ -3190,9 +3198,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
   sliceOffsets.append(
       {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
   OpFoldResult sizeH =
-      inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
+      alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
   OpFoldResult sizeW =
-      inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
+      alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
   sliceSizes.append(
       {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
   int64_t inputRank = getInputOperandRank();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 79f77822116fd7..f1059ddf5da2cf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -514,12 +514,14 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
     Value CIter = ivs[3];
 
     auto context = builder.getContext();
+
+    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
     auto affineMap =
         AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
-    Value heightOffset =
-        builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
-    Value widthOffset =
-        builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
+    Value heightOffset = builder.create<affine::AffineApplyOp>(
+        loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
+    Value widthOffset = builder.create<affine::AffineApplyOp>(
+        loc, rightTransform ? affineMap : identityAffineMap, tileWIter);
 
     // Extract (H, W) from (N, H, W, C).
     auto extractInput =
@@ -753,12 +755,13 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
     Value zero = builder.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(elementType));
 
+    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
     auto affineMap =
         AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
-    Value heightOffset =
-        builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
-    Value widthOffset =
-        builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
+    Value heightOffset = builder.create<affine::AffineApplyOp>(
+        loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
+    Value widthOffset = builder.create<affine::AffineApplyOp>(
+        loc, rightTransform ? affineMap : identityAffineMap, tileWIter);
 
     Value outInitVal =
         extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
@@ -1075,16 +1078,17 @@ FailureOr<Operation *>
 decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradInputTransformOp op) {
   Location loc = op.getLoc();
-  Value input = op.getInput();
-  auto inputType = cast<ShapedType>(input.getType());
-  auto inputShape = inputType.getShape();
-  int64_t inputH = inputShape[1];
-  int64_t inputW = inputShape[2];
+  Value output = op.getOutput();
+  auto outputType = cast<ShapedType>(output.getType());
+  auto outputShape = outputType.getShape();
+
+  int64_t outputH = outputShape[0];
+  int64_t outputW = outputShape[1];
 
   // For F(m x 1, r x 1), we only need to do left side transform.
-  bool leftTransform = inputH != 1;
+  bool leftTransform = outputH != 1;
   // For F(1 x m, 1 x r), we only need to do right side transform.
-  bool rightTransform = inputW != 1;
+  bool rightTransform = outputW != 1;
   Value transformedInput =
       inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
                      op.getR(), leftTransform, rightTransform);
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 776dc5b748c846..a9af874ea40933 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -61,13 +61,12 @@ module attributes {transform.with_named_sequence} {
 // CHECK:      scf.yield %[[INSERTED_SLICE]]
 // CHECK:    scf.yield %[[S9]]
 // CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK:  %[[S3:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK:  %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]])
+// CHECK:  %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
 // CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
 // CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
 // CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
 // CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
 // CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
 // CHECK:        %[[S13:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
 // CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
@@ -195,13 +194,12 @@ module attributes {transform.with_named_sequence} {
 // CHECK:    scf.yield %[[S9]] : tensor<6x6x5x2xf32>
 // CHECK:  %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0]
 // CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK:  %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK:  %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]])
+// CHECK:  %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S2]])
 // CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
 // CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
 // CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
 // CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
 // CHECK:      %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
 // CHECK:        %[[S13:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
 // CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 6, 6, 1] [1, 1, 1, 1]
@@ -346,3 +344,106 @@ module attributes {transform.with_named_sequence} {
 // CHECK:       scf.yield %[[INSERTED_SLICE]]
 // CHECK:     scf.yield %[[S7]]
 // CHECK:   return %[[S6]]
+
+// -----
+
+func.func @conv2d_mx1_rx1_2(%arg0: tensor<2x6x2x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<6x1x5x2xf32>
+  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+  %2 = tensor.empty() : tensor<6x1x1x2x2x5xf32>
+  %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x2x5xf32>) outs(%2 : tensor<6x1x1x2x2x5xf32>) -> tensor<6x1x1x2x2x5xf32>
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x2x2x5xf32> into tensor<6x4x5xf32>
+  %4 = tensor.empty() : tensor<6x4x2xf32>
+  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+  %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x4x5xf32>, tensor<6x5x2xf32>) outs(%5 : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+  %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 2, 2, 2] : tensor<6x4x2xf32> into tensor<6x1x1x2x2x2xf32>
+  %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x2x2x2xf32>) outs(%arg2 : tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32>
+  return %7 : tensor<2x4x2x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+    %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+    %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+    %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+    %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d_mx1_rx1
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x2x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32> {
+// CHECK:   %[[CST:.*]] = arith.constant 3.200000e+01 : f32
+// CHECK:  %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<4x6xf32>
+// CHECK:  %[[CST_1:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
+// CHECK:  %[[CST_2:.*]] = arith.constant dense<{{.*}}> : tensor<6x3xf32>
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[CST_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
+// CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1]
+// CHECK:       %[[S8:.*]] = tensor.empty() : tensor<6x1xf32>
+// CHECK:       %[[S9:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S8]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK:       %[[S10:.*]] = linalg.matmul ins(%[[CST_2]], %[[EXTRACTED_SLICE]] : tensor<6x3xf32>, tensor<3x1xf32>) outs(%[[S9]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
+// CHECK:       scf.yield %[[INSERTED_SLICE]]
+// CHECK:     scf.yield %[[S7]]
+// CHECK:   %[[S2:.*]] = tensor.empty() : tensor<6x1x1x2x2x5xf32>
+// CHECK:   %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
+// CHECK:     %[[S8:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:     %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %8, 0] [2, 6, 1, 5] [1, 1, 1, 1]
+// CHECK:     %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG4]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:     %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[EXTRACTED_SLICE_5]])
+// CHECK:     %[[S10:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]])
+// CHECK:       %[[EXTRACTED_SLICE_6:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG5]], 0, 0, %[[ARG7]]] [1, 6, 1, 1] [1, 1, 1, 1]
+// CHECK:       %[[S11:.*]] = tensor.empty() : tensor<6x1xf32>
+// CHECK:       %[[S12:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S11]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK:       %[[S13:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_6]] : tensor<6x6xf32>, tensor<6x1xf32>) outs(%[[S12]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK:       %[[INSERTED_SLICE_7:.*]] = tensor.insert_slice %[[S13]] into %[[ARG8]][0, 0, 0, 0, %[[ARG5]], %[[ARG7]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:       scf.yield %[[INSERTED_SLICE_7]]
+// CHECK:     scf.yield %[[S10]]
+// CHECK:    %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG4]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:    scf.yield %[[INSERTED_SLICE]]
+// CHECK:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK:   %[[COLLAPSED_4:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK:   %[[S4:.*]] = tensor.empty() : tensor<6x4x2xf32>
+// CHECK:   %[[S5:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S4]] : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+// CHECK:   %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_4]], %[[COLLAPSED]] : tensor<6x4x5xf32>, tensor<6x5x2xf32>) outs(%[[S5]] : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+// CHECK:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 2, 2, 2]
+// CHECK:   %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
+// CHECK:     %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK:     %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG4]][0, 0, 0, 0] [2, 4, 1, 2] [1, 1, 1, 1]
+// CHECK:     %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[EXTRACTED_SLICE_5]])
+// CHECK:       %[[S9:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]])
+// CHECK:       %[[EXTRACTED_SLICE_6:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/123675


More information about the Mlir-commits mailing list