[Mlir-commits] [mlir] [MLIR][Linalg] Fix winograd op lowering for types smaller than f32 (PR #158500)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Sep 14 11:02:41 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Isaac Nudelman (nuudlman)

<details>
<summary>Changes</summary>

The Winograd transform constant arrays are always stored as f32, but previously the element type of the original op was passed through when the constants were created. If that type was narrower than f32 (e.g. f16), attribute creation would fail with an assertion.

This patch fixes the problem by always creating the Winograd transform constants with an f32 element type, and adds a test for this case.

---
Full diff: https://github.com/llvm/llvm-project/pull/158500.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (+15-15) 
- (modified) mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir (+116) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index b80b27fe5fcc5..b875b24c8fda0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -186,11 +186,12 @@ constexpr float A_2x2_5x5[] = {
 
 /// Structure to keep information of constant transform matrices.
 struct TransformMatrix {
-  TransformMatrix(const float *table, int64_t rows, int64_t cols,
+  TransformMatrix(ArrayRef<float> table, int64_t rows, int64_t cols,
                   int64_t scalarFactor = 1)
-      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
+      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {
+  }
 
-  const float *table;
+  ArrayRef<float> table;
   int64_t rows;
   int64_t cols;
   int64_t scalarFactor;
@@ -198,15 +199,14 @@ struct TransformMatrix {
 
 /// Utility function to convert constant array to arith.constant Value.
 Value create2DTransformMatrix(OpBuilder &builder, Location loc,
-                              TransformMatrix transform, Type type) {
-  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
-
+                              TransformMatrix transform) {
+  assert(transform.table.size() == static_cast<size_t>(transform.rows * transform.cols));
+  ArrayRef<float> constVec(transform.table.data(), transform.rows * transform.cols);
+  SmallVector<int64_t, 2> shape{transform.rows, transform.cols};
   return arith::ConstantOp::create(
       builder, loc,
       DenseFPElementsAttr::get(
-          RankedTensorType::get(
-              SmallVector<int64_t>{transform.rows, transform.cols}, type),
-          constVec));
+          RankedTensorType::get(shape, builder.getF32Type()), constVec));
 }
 
 /// Extract height x width data from 4D tensors.
@@ -404,7 +404,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
       auto init =
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
 
-      Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
+      Value G = create2DTransformMatrix(builder, loc, GMatrix);
       // Multiply G x g.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{G, extractFilter},
@@ -427,7 +427,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
       auto init =
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
 
-      Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
+      Value GT = create2DTransformMatrix(builder, loc, GTMatrix);
       // Multiply u = (G x g) x GT.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{matmulRetValue, GT},
@@ -552,7 +552,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
 
       Value BT =
-          create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
+          create2DTransformMatrix(builder, loc, BTMatrix);
       // Multiply BT x d.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{BT, matmulRetValue},
@@ -575,7 +575,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
       auto init =
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
       Value B =
-          create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
+          create2DTransformMatrix(builder, loc, BMatrix);
       // Multiply v = (BT x d) x B.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{matmulRetValue, B},
@@ -783,7 +783,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
         init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
       }
 
-      Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
+      Value AT = create2DTransformMatrix(builder, loc, ATMatrix);
       // Multiply AT x m.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{AT, matmulRetValue},
@@ -802,7 +802,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
         init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
       }
 
-      Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
+      Value A = create2DTransformMatrix(builder, loc, AMatrix);
       // Multiply y = (AT x m) x A.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{matmulRetValue, A},
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index c7b0bd51308ba..4bcb9b0c2c465 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -127,3 +127,119 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
 // CHECK-NEXT:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
 // CHECK-NEXT:   return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
 // CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3x5xf16>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<6x6x5x2xf16>
+  %1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x3x5xf16>) outs(%0 : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16> // no-crash
+  %2 = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+  %3 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x6x6x5xf16>) outs(%2 : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16> // no-crash
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+  %4 = tensor.empty() : tensor<36x2x2xf32>
+  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+  %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%5 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+  %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+  %7 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x6x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %7 : tensor<2x4x4x2xf32>
+}
+
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL:   func.func @conv2d_type_promotion(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<2x6x6x5xf16>,
+// CHECK-SAME:      %[[ARG1:.*]]: tensor<2x3x3x5xf16>,
+// CHECK-SAME:      %[[ARG2:.*]]: tensor<1xf32>,
+// CHECK-SAME:      %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK-DAG:           %[[VAL_0:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK-DAG:           %[[VAL_1:.*]] = arith.constant dense<{{\[\[}}1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
+// CHECK-DAG:           %[[VAL_2:.*]] = arith.constant dense<{{\[\[}}1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
+// CHECK-DAG:           %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:           %[[VAL_4:.*]] = arith.constant dense<{{\[\[}}2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:           %[[VAL_5:.*]] = arith.constant dense<{{\[\[}}1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32>
+// CHECK-DAG:           %[[VAL_6:.*]] = arith.constant dense<{{\[\[}}1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32>
+// CHECK-DAG:           %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG:           %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK-DAG:           %[[VAL_9:.*]] = arith.constant 5 : index
+// CHECK-DAG:           %[[VAL_10:.*]] = arith.constant 2 : index
+// CHECK-DAG:           %[[VAL_11:.*]] = arith.constant 0 : index
+// CHECK-DAG:           %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_13:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
+// CHECK-NEXT:           %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (tensor<6x6x5x2xf16>) {
+// CHECK-NEXT:             %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_11]] to %[[VAL_9]] step %[[VAL_8]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (tensor<6x6x5x2xf16>) {
+// CHECK-NEXT:               %[[VAL_20:.*]] = tensor.extract_slice %[[ARG1]]{{\[}}%[[VAL_15]], %[[VAL_11]], %[[VAL_11]], %[[VAL_18]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf16> to tensor<3x3xf16>
+// CHECK-NEXT:               %[[VAL_21:.*]] = tensor.empty() : tensor<6x3xf16>
+// CHECK-NEXT:               %[[VAL_22:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_21]] : tensor<6x3xf16>) -> tensor<6x3xf16>
+// CHECK-NEXT:               %[[VAL_23:.*]] = linalg.matmul ins(%[[VAL_6]], %[[VAL_20]] : tensor<6x3xf32>, tensor<3x3xf16>) outs(%[[VAL_22]] : tensor<6x3xf16>) -> tensor<6x3xf16>
+// CHECK-NEXT:               %[[VAL_24:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT:               %[[VAL_25:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_24]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT:               %[[VAL_26:.*]] = linalg.matmul ins(%[[VAL_23]], %[[VAL_5]] : tensor<6x3xf16>, tensor<3x6xf32>) outs(%[[VAL_25]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT:               %[[VAL_27:.*]] = tensor.insert_slice %[[VAL_26]] into %[[VAL_19]]{{\[}}%[[VAL_11]], %[[VAL_11]], %[[VAL_18]], %[[VAL_15]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf16> into tensor<6x6x5x2xf16>
+// CHECK-NEXT:               scf.yield %[[VAL_27]] : tensor<6x6x5x2xf16>
+// CHECK-NEXT:             }
+// CHECK-NEXT:             scf.yield %[[VAL_17]] : tensor<6x6x5x2xf16>
+// CHECK-NEXT:           }
+// CHECK-NEXT:           %[[VAL_28:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:           %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_31:.*]] = %[[VAL_28]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT:             %[[VAL_32:.*]] = scf.for %[[VAL_33:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_34:.*]] = %[[VAL_31]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT:               %[[VAL_35:.*]] = scf.for %[[VAL_36:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_37:.*]] = %[[VAL_34]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT:                 %[[VAL_38:.*]] = scf.for %[[VAL_39:.*]] = %[[VAL_11]] to %[[VAL_9]] step %[[VAL_8]] iter_args(%[[VAL_40:.*]] = %[[VAL_37]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT:                   %[[VAL_41:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_30]])
+// CHECK-NEXT:                   %[[VAL_42:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_33]])
+// CHECK-NEXT:                   %[[VAL_43:.*]] = tensor.extract_slice %[[ARG0]]{{\[}}%[[VAL_36]], %[[VAL_41]], %[[VAL_42]], %[[VAL_39]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf16> to tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_44:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_45:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_44]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_46:.*]] = linalg.matmul ins(%[[VAL_4]], %[[VAL_43]] : tensor<6x6xf32>, tensor<6x6xf16>) outs(%[[VAL_45]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_47:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_48:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_47]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_49:.*]] = linalg.matmul ins(%[[VAL_46]], %[[VAL_3]] : tensor<6x6xf16>, tensor<6x6xf32>) outs(%[[VAL_48]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_50:.*]] = tensor.insert_slice %[[VAL_49]] into %[[VAL_40]][0, 0, %[[VAL_30]], %[[VAL_33]], %[[VAL_36]], %[[VAL_39]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf16> into tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:                   scf.yield %[[VAL_50]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:                 }
+// CHECK-NEXT:                 scf.yield %[[VAL_38]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:               }
+// CHECK-NEXT:               scf.yield %[[VAL_35]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:             }
+// CHECK-NEXT:             scf.yield %[[VAL_32]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:           }
+// CHECK-NEXT:           %[[VAL_51:.*]] = tensor.collapse_shape %[[VAL_14]] {{\[\[}}0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+// CHECK-NEXT:           %[[VAL_52:.*]] = tensor.collapse_shape %[[VAL_29]] {{\[\[}}0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+// CHECK-NEXT:           %[[VAL_53:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:           %[[VAL_54:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_53]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:           %[[VAL_55:.*]] = linalg.batch_matmul ins(%[[VAL_52]], %[[VAL_51]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[VAL_54]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:           %[[VAL_56:.*]] = tensor.expand_shape %[[VAL_55]] {{\[\[}}0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:           %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_59:.*]] = %[[ARG3]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:             %[[VAL_60:.*]] = scf.for %[[VAL_61:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_62:.*]] = %[[VAL_59]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:               %[[VAL_63:.*]] = scf.for %[[VAL_64:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_65:.*]] = %[[VAL_62]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:                 %[[VAL_66:.*]] = scf.for %[[VAL_67:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_68:.*]] = %[[VAL_65]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:                   %[[VAL_69:.*]] = tensor.extract_slice %[[VAL_56]][0, 0, %[[VAL_58]], %[[VAL_61]], %[[VAL_64]], %[[VAL_67]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x2xf32> to tensor<6x6xf32>
+// CHECK-NEXT:                   %[[VAL_70:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_58]])
+// CHECK-NEXT:                   %[[VAL_71:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_61]])
+// CHECK-NEXT:                   %[[VAL_72:.*]] = tensor.extract_slice %[[VAL_68]]{{\[}}%[[VAL_64]], %[[VAL_70]], %[[VAL_71]], %[[VAL_67]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<2x4x4x2xf32> to tensor<4x4xf32>
+// CHECK-NEXT:                   %[[VAL_73:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT:                   %[[VAL_74:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_73]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:                   %[[VAL_75:.*]] = linalg.matmul ins(%[[VAL_2]], %[[VAL_69]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[VAL_74]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:                   %[[VAL_76:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:                   %[[VAL_77:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_76]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:                   %[[VAL_78:.*]] = linalg.matmul ins(%[[VAL_75]], %[[VAL_1]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[VAL_77]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:                   %[[VAL_79:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_2]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]], %[[VAL_78]] : f32, tensor<4x4xf32>) outs(%[[VAL_72]] : tensor<4x4xf32>) {
+// CHECK-NEXT:                   ^bb0(%[[VAL_80:.*]]: f32, %[[VAL_81:.*]]: f32, %[[VAL_82:.*]]: f32):
+// CHECK-NEXT:                     %[[VAL_83:.*]] = arith.mulf %[[VAL_80]], %[[VAL_81]] : f32
+// CHECK-NEXT:                     %[[VAL_84:.*]] = arith.addf %[[VAL_83]], %[[VAL_82]] : f32
+// CHECK-NEXT:                     linalg.yield %[[VAL_84]] : f32
+// CHECK-NEXT:                   } -> tensor<4x4xf32>
+// CHECK-NEXT:                   %[[VAL_85:.*]] = tensor.insert_slice %[[VAL_79]] into %[[VAL_68]]{{\[}}%[[VAL_64]], %[[VAL_70]], %[[VAL_71]], %[[VAL_67]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x4x4x2xf32>
+// CHECK-NEXT:                   scf.yield %[[VAL_85]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:                 }
+// CHECK-NEXT:                 scf.yield %[[VAL_66]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:               }
+// CHECK-NEXT:               scf.yield %[[VAL_63]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:             }
+// CHECK-NEXT:             scf.yield %[[VAL_60]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:           }
+// CHECK-NEXT:           return %[[VAL_57]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:         }

``````````

</details>


https://github.com/llvm/llvm-project/pull/158500


More information about the Mlir-commits mailing list