[Mlir-commits] [mlir] cc14b58 - [MLIR][Linalg] Fix winograd op lowering for types smaller than f32 (#158500)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 8 14:09:54 PDT 2025
Author: Isaac Nudelman
Date: 2025-10-08T16:09:50-05:00
New Revision: cc14b589659b6c6c9fb65de7a274287f2490d345
URL: https://github.com/llvm/llvm-project/commit/cc14b589659b6c6c9fb65de7a274287f2490d345
DIFF: https://github.com/llvm/llvm-project/commit/cc14b589659b6c6c9fb65de7a274287f2490d345.diff
LOG: [MLIR][Linalg] Fix winograd op lowering for types smaller than f32 (#158500)
The Winograd transform constant arrays are always stored as f32, but
previously their creation passed the original element type through
unchanged. If that type was narrower than f32 (for example f16), attribute
creation failed with an assertion.
This patch fixes the problem by ensuring that the types match, and adds a
test for this case.
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index ef172c131be3b..37bdd8b12af4d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -186,11 +186,11 @@ constexpr float A_2x2_5x5[] = {
/// Structure to keep information of constant transform matrices.
struct TransformMatrix {
- TransformMatrix(const float *table, int64_t rows, int64_t cols,
+ TransformMatrix(ArrayRef<float> table, int64_t rows, int64_t cols,
int64_t scalarFactor = 1)
: table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
- const float *table;
+ ArrayRef<float> table;
int64_t rows;
int64_t cols;
int64_t scalarFactor;
@@ -199,14 +199,20 @@ struct TransformMatrix {
/// Utility function to convert constant array to arith.constant Value.
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
TransformMatrix transform, Type type) {
- ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
-
+ assert(transform.table.size() ==
+ static_cast<size_t>(transform.rows * transform.cols));
+ assert(type.isFloat() && "Only floats are supported by Winograd");
+ ArrayRef<float> constVec(transform.table.data(),
+ transform.rows * transform.cols);
+ auto constAttrVec =
+ llvm::map_to_vector<>(constVec, [&](const float v) -> Attribute {
+ return builder.getFloatAttr(type, v);
+ });
+ SmallVector<int64_t, 2> shape{transform.rows, transform.cols};
return arith::ConstantOp::create(
builder, loc,
- DenseFPElementsAttr::get(
- RankedTensorType::get(
- SmallVector<int64_t>{transform.rows, transform.cols}, type),
- constVec));
+ DenseFPElementsAttr::get(RankedTensorType::get(shape, type),
+ constAttrVec));
}
/// Extract height x width data from 4D tensors.
@@ -551,8 +557,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
auto init =
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
- Value BT =
- create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
+ Value BT = create2DTransformMatrix(builder, loc, BTMatrix, elementType);
// Multiply BT x d.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{BT, matmulRetValue},
@@ -574,8 +579,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
.getResult();
auto init =
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
- Value B =
- create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
+ Value B = create2DTransformMatrix(builder, loc, BMatrix, elementType);
// Multiply v = (BT x d) x B.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{matmulRetValue, B},
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index c7b0bd51308ba..8465e553166f1 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -127,3 +127,119 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3x5xf16>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<6x6x5x2xf16>
+ %1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x3x5xf16>) outs(%0 : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16> // no-crash
+ %2 = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+ %3 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x6x6x5xf16>) outs(%2 : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16> // no-crash
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+ %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+ %4 = tensor.empty() : tensor<36x2x2xf32>
+ %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+ %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%5 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+ %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+ %7 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x6x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+ return %7 : tensor<2x4x4x2xf32>
+}
+
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d_type_promotion(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x6x6x5xf16>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<2x3x3x5xf16>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<1xf32>,
+// CHECK-SAME: %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK-DAG: %[[VAL_0:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<{{\[\[}}1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<{{\[\[}}1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf16>
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<{{\[\[}}2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf16>
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<{{\[\[}}1.000000e+00, -3.332520e-01, -3.332520e-01, 8.331300e-02, 8.331300e-02, 0.000000e+00], [0.000000e+00, 3.332520e-01, -3.332520e-01, -1.666260e-01, 1.666260e-01, 0.000000e+00], [0.000000e+00, -3.332520e-01, -3.332520e-01, 3.332520e-01, 3.332520e-01, 1.000000e+00]]> : tensor<3x6xf16>
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<{{\[\[}}1.000000e+00, 0.000000e+00, 0.000000e+00], [-3.332520e-01, 3.332520e-01, -3.332520e-01], [-3.332520e-01, -3.332520e-01, -3.332520e-01], [8.331300e-02, -1.666260e-01, 3.332520e-01], [8.331300e-02, 1.666260e-01, 3.332520e-01], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf16>
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_13:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
+// CHECK-NEXT: %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (tensor<6x6x5x2xf16>) {
+// CHECK-NEXT: %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_11]] to %[[VAL_9]] step %[[VAL_8]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (tensor<6x6x5x2xf16>) {
+// CHECK-NEXT: %[[VAL_20:.*]] = tensor.extract_slice %[[ARG1]]{{\[}}%[[VAL_15]], %[[VAL_11]], %[[VAL_11]], %[[VAL_18]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf16> to tensor<3x3xf16>
+// CHECK-NEXT: %[[VAL_21:.*]] = tensor.empty() : tensor<6x3xf16>
+// CHECK-NEXT: %[[VAL_22:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_21]] : tensor<6x3xf16>) -> tensor<6x3xf16>
+// CHECK-NEXT: %[[VAL_23:.*]] = linalg.matmul ins(%[[VAL_6]], %[[VAL_20]] : tensor<6x3xf16>, tensor<3x3xf16>) outs(%[[VAL_22]] : tensor<6x3xf16>) -> tensor<6x3xf16>
+// CHECK-NEXT: %[[VAL_24:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_25:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_24]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_26:.*]] = linalg.matmul ins(%[[VAL_23]], %[[VAL_5]] : tensor<6x3xf16>, tensor<3x6xf16>) outs(%[[VAL_25]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_27:.*]] = tensor.insert_slice %[[VAL_26]] into %[[VAL_19]]{{\[}}%[[VAL_11]], %[[VAL_11]], %[[VAL_18]], %[[VAL_15]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf16> into tensor<6x6x5x2xf16>
+// CHECK-NEXT: scf.yield %[[VAL_27]] : tensor<6x6x5x2xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_17]] : tensor<6x6x5x2xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[VAL_28:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_31:.*]] = %[[VAL_28]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT: %[[VAL_32:.*]] = scf.for %[[VAL_33:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_34:.*]] = %[[VAL_31]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT: %[[VAL_35:.*]] = scf.for %[[VAL_36:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_37:.*]] = %[[VAL_34]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT: %[[VAL_38:.*]] = scf.for %[[VAL_39:.*]] = %[[VAL_11]] to %[[VAL_9]] step %[[VAL_8]] iter_args(%[[VAL_40:.*]] = %[[VAL_37]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT: %[[VAL_41:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_30]])
+// CHECK-NEXT: %[[VAL_42:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_33]])
+// CHECK-NEXT: %[[VAL_43:.*]] = tensor.extract_slice %[[ARG0]]{{\[}}%[[VAL_36]], %[[VAL_41]], %[[VAL_42]], %[[VAL_39]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf16> to tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_44:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_45:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_44]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_46:.*]] = linalg.matmul ins(%[[VAL_4]], %[[VAL_43]] : tensor<6x6xf16>, tensor<6x6xf16>) outs(%[[VAL_45]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_47:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_48:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_47]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_49:.*]] = linalg.matmul ins(%[[VAL_46]], %[[VAL_3]] : tensor<6x6xf16>, tensor<6x6xf16>) outs(%[[VAL_48]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT: %[[VAL_50:.*]] = tensor.insert_slice %[[VAL_49]] into %[[VAL_40]][0, 0, %[[VAL_30]], %[[VAL_33]], %[[VAL_36]], %[[VAL_39]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf16> into tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: scf.yield %[[VAL_50]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_38]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_35]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_32]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[VAL_51:.*]] = tensor.collapse_shape %[[VAL_14]] {{\[\[}}0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+// CHECK-NEXT: %[[VAL_52:.*]] = tensor.collapse_shape %[[VAL_29]] {{\[\[}}0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+// CHECK-NEXT: %[[VAL_53:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT: %[[VAL_54:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_53]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[VAL_55:.*]] = linalg.batch_matmul ins(%[[VAL_52]], %[[VAL_51]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[VAL_54]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT: %[[VAL_56:.*]] = tensor.expand_shape %[[VAL_55]] {{\[\[}}0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT: %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_59:.*]] = %[[ARG3]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT: %[[VAL_60:.*]] = scf.for %[[VAL_61:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_62:.*]] = %[[VAL_59]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT: %[[VAL_63:.*]] = scf.for %[[VAL_64:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_65:.*]] = %[[VAL_62]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT: %[[VAL_66:.*]] = scf.for %[[VAL_67:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_68:.*]] = %[[VAL_65]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT: %[[VAL_69:.*]] = tensor.extract_slice %[[VAL_56]][0, 0, %[[VAL_58]], %[[VAL_61]], %[[VAL_64]], %[[VAL_67]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x2xf32> to tensor<6x6xf32>
+// CHECK-NEXT: %[[VAL_70:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_58]])
+// CHECK-NEXT: %[[VAL_71:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_61]])
+// CHECK-NEXT: %[[VAL_72:.*]] = tensor.extract_slice %[[VAL_68]]{{\[}}%[[VAL_64]], %[[VAL_70]], %[[VAL_71]], %[[VAL_67]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<2x4x4x2xf32> to tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_73:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT: %[[VAL_74:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_73]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT: %[[VAL_75:.*]] = linalg.matmul ins(%[[VAL_2]], %[[VAL_69]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[VAL_74]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT: %[[VAL_76:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_77:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_76]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_78:.*]] = linalg.matmul ins(%[[VAL_75]], %[[VAL_1]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[VAL_77]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_79:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_2]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]], %[[VAL_78]] : f32, tensor<4x4xf32>) outs(%[[VAL_72]] : tensor<4x4xf32>) {
+// CHECK-NEXT: ^bb0(%[[VAL_80:.*]]: f32, %[[VAL_81:.*]]: f32, %[[VAL_82:.*]]: f32):
+// CHECK-NEXT: %[[VAL_83:.*]] = arith.mulf %[[VAL_80]], %[[VAL_81]] : f32
+// CHECK-NEXT: %[[VAL_84:.*]] = arith.addf %[[VAL_83]], %[[VAL_82]] : f32
+// CHECK-NEXT: linalg.yield %[[VAL_84]] : f32
+// CHECK-NEXT: } -> tensor<4x4xf32>
+// CHECK-NEXT: %[[VAL_85:.*]] = tensor.insert_slice %[[VAL_79]] into %[[VAL_68]]{{\[}}%[[VAL_64]], %[[VAL_70]], %[[VAL_71]], %[[VAL_67]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x4x4x2xf32>
+// CHECK-NEXT: scf.yield %[[VAL_85]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_66]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_63]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: scf.yield %[[VAL_60]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[VAL_57]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
\ No newline at end of file
More information about the Mlir-commits mailing list