[Mlir-commits] [mlir] [MLIR][Linalg] Fix winograd op lowering for types smaller than f32 (PR #158500)

Isaac Nudelman llvmlistbot at llvm.org
Sun Sep 14 11:05:46 PDT 2025


https://github.com/nuudlman updated https://github.com/llvm/llvm-project/pull/158500

>From d4d2121b41d16a9658b8972dc47b79c15b50ca5f Mon Sep 17 00:00:00 2001
From: Isaac Nudelman <isaac.nudelman at utexas.edu>
Date: Fri, 12 Sep 2025 20:44:34 -0500
Subject: [PATCH 1/5] add a concept of a fix

---
 .../Linalg/Transforms/WinogradConv2D.cpp      | 47 ++++++++++---------
 .../Linalg/winograd-conv2d-rewrite.mlir       | 18 +++++++
 2 files changed, 43 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index b80b27fe5fcc5..288c8ada0c8eb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -46,46 +46,46 @@ namespace {
 ///   BTMatrices, BMatrices, ATMatrices, or AMatrices map.
 /// 3. Add a enum value F_m_r to WinogradConv2DFmr enum.
 ///
-constexpr float G_2x2_3x3[] = {
+constexpr double G_2x2_3x3[] = {
    -1,     0,   0,
  1./2, -1./2, 1./2,
  1./2,  1./2, 1./2,
     0,     0,    1
 };
 
-constexpr float GT_2x2_3x3[] = {
+constexpr double GT_2x2_3x3[] = {
    -1,  1./2, 1./2, 0,
     0, -1./2, 1./2, 0,
     0,  1./2, 1./2, 1
 };
 
-constexpr float BT_2x2_3x3[] = {
+constexpr double BT_2x2_3x3[] = {
    -1,    0,   1,   0,
     0,   -1,   1,   0,
     0,    1,   1,   0,
     0,   -1,   0,   1
 };
 
-constexpr float B_2x2_3x3[] = {
+constexpr double B_2x2_3x3[] = {
    -1,    0,   0,   0,
     0,   -1,   1,  -1,
     1,    1,   1,   0,
     0,    0,   0,   1
 };
 
-constexpr float AT_2x2_3x3[] = {
+constexpr double AT_2x2_3x3[] = {
     1,    1,   1,   0,
     0,   -1,   1,   1
 };
 
-constexpr float A_2x2_3x3[] = {
+constexpr double A_2x2_3x3[] = {
     1,    0,
     1,   -1,
     1,    1,
     0,    1
 };
 
-constexpr float G_4x4_3x3[] = {
+constexpr double G_4x4_3x3[] = {
      1,     0,     0,
  -1./3,  1./3, -1./3,
  -1./3, -1./3, -1./3,
@@ -94,13 +94,13 @@ constexpr float G_4x4_3x3[] = {
      0,     0,     1
 };
 
-constexpr float GT_4x4_3x3[] = {
+constexpr double GT_4x4_3x3[] = {
  1,  -1./3, -1./3, 1./12, 1./12, 0,
  0,   1./3, -1./3, -1./6,  1./6, 0,
  0,  -1./3, -1./3,  1./3,  1./3, 1
 };
 
-constexpr float BT_4x4_3x3[] = {
+constexpr double BT_4x4_3x3[] = {
  1./4,     0, -5./16,      0, 1./16,     0,
     0,  1./4,  -1./4, -1./16, 1./16,     0,
     0, -1./4,  -1./4,  1./16, 1./16,     0,
@@ -109,7 +109,7 @@ constexpr float BT_4x4_3x3[] = {
     0,  1./4,      0, -5./16,     0, 1./16
 };
 
-constexpr float B_4x4_3x3[] = {
+constexpr double B_4x4_3x3[] = {
    1./4,      0,     0,     0,     0,      0,
       0,   1./4, -1./4,  1./4, -1./4,   1./4,
  -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
@@ -118,14 +118,14 @@ constexpr float B_4x4_3x3[] = {
       0,      0,     0,     0,     0,  1./16
 };
 
-constexpr float AT_4x4_3x3[] = {
+constexpr double AT_4x4_3x3[] = {
  1./8,  1./4, 1./4,  1./8, 1./8,    0,
     0, -1./4, 1./4, -1./4, 1./4,    0,
     0,  1./4, 1./4,  1./2, 1./2,    0,
     0, -1./4, 1./4,    -1,    1, 1./2
 };
 
-constexpr float A_4x4_3x3[] = {
+constexpr double A_4x4_3x3[] = {
   1./8,     0,    0,     0,
   1./4, -1./4, 1./4, -1./4,
   1./4,  1./4, 1./4,  1./4,
@@ -134,7 +134,7 @@ constexpr float A_4x4_3x3[] = {
      0,     0,    0,  1./2
 };
 
-constexpr float G_2x2_5x5[] = {
+constexpr double G_2x2_5x5[] = {
      1,     0,      0,      0,      0,
   1./6, -1./6,   1./6,  -1./6,   1./6,
  -1./6, -1./6,  -1./6,  -1./6,  -1./6,
@@ -143,7 +143,7 @@ constexpr float G_2x2_5x5[] = {
      0,     0,      0,      0,      1
 };
 
-constexpr float GT_2x2_5x5[] = {
+constexpr double GT_2x2_5x5[] = {
    1,  1./6, -1./6, -4./15, 1./60, 0,
    0, -1./6, -1./6,  2./15, 1./30, 0,
    0,  1./6, -1./6, -1./15, 1./15, 0,
@@ -151,7 +151,7 @@ constexpr float GT_2x2_5x5[] = {
    0,  1./6, -1./6, -1./60, 4./15, 1
 };
 
-constexpr float BT_2x2_5x5[] = {
+constexpr double BT_2x2_5x5[] = {
  1./8,  3./16,  -1./4,  -3./16,   1./8,    0,
     0,   1./8,  1./16,  -5./16,   1./8,    0,
     0,  -1./8, -5./16,  -1./16,   1./8,    0,
@@ -160,7 +160,7 @@ constexpr float BT_2x2_5x5[] = {
     0,   1./8,  3./16,   -1./4, -3./16, 1./8
 };
 
-constexpr float B_2x2_5x5[] = {
+constexpr double B_2x2_5x5[] = {
    1./8,      0,      0,     0,     0,      0,
   3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
   -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
@@ -169,12 +169,12 @@ constexpr float B_2x2_5x5[] = {
       0,      0,      0,     0,     0,   1./8
 };
 
-constexpr float AT_2x2_5x5[] = {
+constexpr double AT_2x2_5x5[] = {
   1./2,  1, 1,  2, 1,    0,
      0, -1, 1, -1, 2, 1./2
 };
 
-constexpr float A_2x2_5x5[] = {
+constexpr double A_2x2_5x5[] = {
  1./2,    0,
     1,   -1,
     1,    1,
@@ -186,11 +186,12 @@ constexpr float A_2x2_5x5[] = {
 
 /// Structure to keep information of constant transform matrices.
 struct TransformMatrix {
-  TransformMatrix(const float *table, int64_t rows, int64_t cols,
+  TransformMatrix(ArrayRef<double> table, int64_t rows, int64_t cols,
                   int64_t scalarFactor = 1)
-      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
+      : table(llvm::map_to_vector(table, [](double val) { return APFloat(val); })), rows(rows), cols(cols), scalarFactor(scalarFactor) {
+  }
 
-  const float *table;
+  SmallVector<APFloat> table;
   int64_t rows;
   int64_t cols;
   int64_t scalarFactor;
@@ -199,7 +200,9 @@ struct TransformMatrix {
 /// Utility function to convert constant array to arith.constant Value.
 Value create2DTransformMatrix(OpBuilder &builder, Location loc,
                               TransformMatrix transform, Type type) {
-  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
+  assert(type.isFloat());
+  assert(transform.table.size() == (transform.rows * transform.cols));
+  ArrayRef<APFloat> constVec(transform.table.data(), transform.rows * transform.cols);
 
   return arith::ConstantOp::create(
       builder, loc,
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index c7b0bd51308ba..0c7d4e1d23f34 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -127,3 +127,21 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
 // CHECK-NEXT:   %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
 // CHECK-NEXT:   return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
 // CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3x5xf16>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<6x6x5x2xf16>
+  %1 = linalg.winograd_filter_transform fmr(F_4_3) ins(%arg1 : tensor<2x3x3x5xf16>) outs(%0 : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16> // no-crash
+  %2 = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+  %3 = linalg.winograd_input_transform fmr(F_4_3) ins(%arg0 : tensor<2x6x6x5xf16>) outs(%2 : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16> // no-crash
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+  %4 = tensor.empty() : tensor<36x2x2xf32>
+  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+  %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%5 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+  %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+  %7 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x6x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %7 : tensor<2x4x4x2xf32>
+}

>From 4ca18ad309a18282460bbc1b472330012fec1cd7 Mon Sep 17 00:00:00 2001
From: Isaac Nudelman <isaac.nudelman at utexas.edu>
Date: Sun, 14 Sep 2025 11:04:22 -0500
Subject: [PATCH 2/5] Always promote winograd lowering to f32

---
 .../Linalg/Transforms/WinogradConv2D.cpp      | 69 +++++++++----------
 1 file changed, 34 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 288c8ada0c8eb..2e884f6f79ef6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -46,46 +46,46 @@ namespace {
 ///   BTMatrices, BMatrices, ATMatrices, or AMatrices map.
 /// 3. Add a enum value F_m_r to WinogradConv2DFmr enum.
 ///
-constexpr double G_2x2_3x3[] = {
+constexpr float G_2x2_3x3[] = {
    -1,     0,   0,
  1./2, -1./2, 1./2,
  1./2,  1./2, 1./2,
     0,     0,    1
 };
 
-constexpr double GT_2x2_3x3[] = {
+constexpr float GT_2x2_3x3[] = {
    -1,  1./2, 1./2, 0,
     0, -1./2, 1./2, 0,
     0,  1./2, 1./2, 1
 };
 
-constexpr double BT_2x2_3x3[] = {
+constexpr float BT_2x2_3x3[] = {
    -1,    0,   1,   0,
     0,   -1,   1,   0,
     0,    1,   1,   0,
     0,   -1,   0,   1
 };
 
-constexpr double B_2x2_3x3[] = {
+constexpr float B_2x2_3x3[] = {
    -1,    0,   0,   0,
     0,   -1,   1,  -1,
     1,    1,   1,   0,
     0,    0,   0,   1
 };
 
-constexpr double AT_2x2_3x3[] = {
+constexpr float AT_2x2_3x3[] = {
     1,    1,   1,   0,
     0,   -1,   1,   1
 };
 
-constexpr double A_2x2_3x3[] = {
+constexpr float A_2x2_3x3[] = {
     1,    0,
     1,   -1,
     1,    1,
     0,    1
 };
 
-constexpr double G_4x4_3x3[] = {
+constexpr float G_4x4_3x3[] = {
      1,     0,     0,
  -1./3,  1./3, -1./3,
  -1./3, -1./3, -1./3,
@@ -94,13 +94,13 @@ constexpr double G_4x4_3x3[] = {
      0,     0,     1
 };
 
-constexpr double GT_4x4_3x3[] = {
+constexpr float GT_4x4_3x3[] = {
  1,  -1./3, -1./3, 1./12, 1./12, 0,
  0,   1./3, -1./3, -1./6,  1./6, 0,
  0,  -1./3, -1./3,  1./3,  1./3, 1
 };
 
-constexpr double BT_4x4_3x3[] = {
+constexpr float BT_4x4_3x3[] = {
  1./4,     0, -5./16,      0, 1./16,     0,
     0,  1./4,  -1./4, -1./16, 1./16,     0,
     0, -1./4,  -1./4,  1./16, 1./16,     0,
@@ -109,7 +109,7 @@ constexpr double BT_4x4_3x3[] = {
     0,  1./4,      0, -5./16,     0, 1./16
 };
 
-constexpr double B_4x4_3x3[] = {
+constexpr float B_4x4_3x3[] = {
    1./4,      0,     0,     0,     0,      0,
       0,   1./4, -1./4,  1./4, -1./4,   1./4,
  -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
@@ -118,14 +118,14 @@ constexpr double B_4x4_3x3[] = {
       0,      0,     0,     0,     0,  1./16
 };
 
-constexpr double AT_4x4_3x3[] = {
+constexpr float AT_4x4_3x3[] = {
  1./8,  1./4, 1./4,  1./8, 1./8,    0,
     0, -1./4, 1./4, -1./4, 1./4,    0,
     0,  1./4, 1./4,  1./2, 1./2,    0,
     0, -1./4, 1./4,    -1,    1, 1./2
 };
 
-constexpr double A_4x4_3x3[] = {
+constexpr float A_4x4_3x3[] = {
   1./8,     0,    0,     0,
   1./4, -1./4, 1./4, -1./4,
   1./4,  1./4, 1./4,  1./4,
@@ -134,7 +134,7 @@ constexpr double A_4x4_3x3[] = {
      0,     0,    0,  1./2
 };
 
-constexpr double G_2x2_5x5[] = {
+constexpr float G_2x2_5x5[] = {
      1,     0,      0,      0,      0,
   1./6, -1./6,   1./6,  -1./6,   1./6,
  -1./6, -1./6,  -1./6,  -1./6,  -1./6,
@@ -143,7 +143,7 @@ constexpr double G_2x2_5x5[] = {
      0,     0,      0,      0,      1
 };
 
-constexpr double GT_2x2_5x5[] = {
+constexpr float GT_2x2_5x5[] = {
    1,  1./6, -1./6, -4./15, 1./60, 0,
    0, -1./6, -1./6,  2./15, 1./30, 0,
    0,  1./6, -1./6, -1./15, 1./15, 0,
@@ -151,7 +151,7 @@ constexpr double GT_2x2_5x5[] = {
    0,  1./6, -1./6, -1./60, 4./15, 1
 };
 
-constexpr double BT_2x2_5x5[] = {
+constexpr float BT_2x2_5x5[] = {
  1./8,  3./16,  -1./4,  -3./16,   1./8,    0,
     0,   1./8,  1./16,  -5./16,   1./8,    0,
     0,  -1./8, -5./16,  -1./16,   1./8,    0,
@@ -160,7 +160,7 @@ constexpr double BT_2x2_5x5[] = {
     0,   1./8,  3./16,   -1./4, -3./16, 1./8
 };
 
-constexpr double B_2x2_5x5[] = {
+constexpr float B_2x2_5x5[] = {
    1./8,      0,      0,     0,     0,      0,
   3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
   -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
@@ -169,12 +169,12 @@ constexpr double B_2x2_5x5[] = {
       0,      0,      0,     0,     0,   1./8
 };
 
-constexpr double AT_2x2_5x5[] = {
+constexpr float AT_2x2_5x5[] = {
   1./2,  1, 1,  2, 1,    0,
      0, -1, 1, -1, 2, 1./2
 };
 
-constexpr double A_2x2_5x5[] = {
+constexpr float A_2x2_5x5[] = {
  1./2,    0,
     1,   -1,
     1,    1,
@@ -186,12 +186,12 @@ constexpr double A_2x2_5x5[] = {
 
 /// Structure to keep information of constant transform matrices.
 struct TransformMatrix {
-  TransformMatrix(ArrayRef<double> table, int64_t rows, int64_t cols,
+  TransformMatrix(ArrayRef<float> table, int64_t rows, int64_t cols,
                   int64_t scalarFactor = 1)
-      : table(llvm::map_to_vector(table, [](double val) { return APFloat(val); })), rows(rows), cols(cols), scalarFactor(scalarFactor) {
+      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {
   }
 
-  SmallVector<APFloat> table;
+  ArrayRef<float> table;
   int64_t rows;
   int64_t cols;
   int64_t scalarFactor;
@@ -199,17 +199,14 @@ struct TransformMatrix {
 
 /// Utility function to convert constant array to arith.constant Value.
 Value create2DTransformMatrix(OpBuilder &builder, Location loc,
-                              TransformMatrix transform, Type type) {
-  assert(type.isFloat());
-  assert(transform.table.size() == (transform.rows * transform.cols));
-  ArrayRef<APFloat> constVec(transform.table.data(), transform.rows * transform.cols);
-
+                              TransformMatrix transform) {
+  assert(transform.table.size() == static_cast<size_t>(transform.rows * transform.cols));
+  ArrayRef<float> constVec(transform.table.data(), transform.rows * transform.cols);
+  SmallVector<int64_t, 2> shape{transform.rows, transform.cols};
   return arith::ConstantOp::create(
       builder, loc,
       DenseFPElementsAttr::get(
-          RankedTensorType::get(
-              SmallVector<int64_t>{transform.rows, transform.cols}, type),
-          constVec));
+          RankedTensorType::get(shape, builder.getF32Type()), constVec));
 }
 
 /// Extract height x width data from 4D tensors.
@@ -407,7 +404,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
       auto init =
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
 
-      Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
+      Value G = create2DTransformMatrix(builder, loc, GMatrix);
       // Multiply G x g.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{G, extractFilter},
@@ -430,7 +427,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
       auto init =
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
 
-      Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
+      Value GT = create2DTransformMatrix(builder, loc, GTMatrix);
       // Multiply u = (G x g) x GT.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{matmulRetValue, GT},
@@ -500,6 +497,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
   std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   auto inputType = cast<ShapedType>(input.getType());
   Type elementType = inputType.getElementType();
+  // assert(elementType.isF32() && "NYI: support non-f32");
   auto inputShape = inputType.getShape(); // N, H, W, C
   int64_t inputN = inputShape[0];
   int64_t inputC = inputShape[3];
@@ -555,7 +553,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
 
       Value BT =
-          create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
+          create2DTransformMatrix(builder, loc, BTMatrix);
       // Multiply BT x d.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{BT, matmulRetValue},
@@ -578,7 +576,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
       auto init =
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
       Value B =
-          create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
+          create2DTransformMatrix(builder, loc, BMatrix);
       // Multiply v = (BT x d) x B.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{matmulRetValue, B},
@@ -723,6 +721,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
   std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   auto valueType = cast<ShapedType>(value.getType());
   Type elementType = valueType.getElementType();
+  // assert(elementType.isF32() && "NYI: support non-f32");
   auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
   int64_t valueH = valueShape[0];
   int64_t valueW = valueShape[1];
@@ -786,7 +785,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
         init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
       }
 
-      Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
+      Value AT = create2DTransformMatrix(builder, loc, ATMatrix);
       // Multiply AT x m.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{AT, matmulRetValue},
@@ -805,7 +804,7 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
         init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
       }
 
-      Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
+      Value A = create2DTransformMatrix(builder, loc, AMatrix);
       // Multiply y = (AT x m) x A.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{matmulRetValue, A},

>From f94946b4a99d56b30aea8d7b0bf663d881c744a8 Mon Sep 17 00:00:00 2001
From: Isaac Nudelman <isaac.nudelman at utexas.edu>
Date: Sun, 14 Sep 2025 12:25:12 -0500
Subject: [PATCH 3/5] Fix test

---
 .../Linalg/winograd-conv2d-rewrite.mlir       | 98 +++++++++++++++++++
 1 file changed, 98 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
index 0c7d4e1d23f34..4bcb9b0c2c465 100644
--- a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir
@@ -145,3 +145,101 @@ func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3
   %7 = linalg.winograd_output_transform fmr(F_4_3) ins(%expanded : tensor<6x6x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
   return %7 : tensor<2x4x4x2xf32>
 }
+
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL:   func.func @conv2d_type_promotion(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<2x6x6x5xf16>,
+// CHECK-SAME:      %[[ARG1:.*]]: tensor<2x3x3x5xf16>,
+// CHECK-SAME:      %[[ARG2:.*]]: tensor<1xf32>,
+// CHECK-SAME:      %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+// CHECK-DAG:           %[[VAL_0:.*]] = arith.constant 1.024000e+03 : f32
+// CHECK-DAG:           %[[VAL_1:.*]] = arith.constant dense<{{\[\[}}1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32>
+// CHECK-DAG:           %[[VAL_2:.*]] = arith.constant dense<{{\[\[}}1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32>
+// CHECK-DAG:           %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:           %[[VAL_4:.*]] = arith.constant dense<{{\[\[}}2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32>
+// CHECK-DAG:           %[[VAL_5:.*]] = arith.constant dense<{{\[\[}}1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32>
+// CHECK-DAG:           %[[VAL_6:.*]] = arith.constant dense<{{\[\[}}1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32>
+// CHECK-DAG:           %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG:           %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK-DAG:           %[[VAL_9:.*]] = arith.constant 5 : index
+// CHECK-DAG:           %[[VAL_10:.*]] = arith.constant 2 : index
+// CHECK-DAG:           %[[VAL_11:.*]] = arith.constant 0 : index
+// CHECK-DAG:           %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_13:.*]] = tensor.empty() : tensor<6x6x5x2xf16>
+// CHECK-NEXT:           %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (tensor<6x6x5x2xf16>) {
+// CHECK-NEXT:             %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_11]] to %[[VAL_9]] step %[[VAL_8]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (tensor<6x6x5x2xf16>) {
+// CHECK-NEXT:               %[[VAL_20:.*]] = tensor.extract_slice %[[ARG1]]{{\[}}%[[VAL_15]], %[[VAL_11]], %[[VAL_11]], %[[VAL_18]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf16> to tensor<3x3xf16>
+// CHECK-NEXT:               %[[VAL_21:.*]] = tensor.empty() : tensor<6x3xf16>
+// CHECK-NEXT:               %[[VAL_22:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_21]] : tensor<6x3xf16>) -> tensor<6x3xf16>
+// CHECK-NEXT:               %[[VAL_23:.*]] = linalg.matmul ins(%[[VAL_6]], %[[VAL_20]] : tensor<6x3xf32>, tensor<3x3xf16>) outs(%[[VAL_22]] : tensor<6x3xf16>) -> tensor<6x3xf16>
+// CHECK-NEXT:               %[[VAL_24:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT:               %[[VAL_25:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_24]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT:               %[[VAL_26:.*]] = linalg.matmul ins(%[[VAL_23]], %[[VAL_5]] : tensor<6x3xf16>, tensor<3x6xf32>) outs(%[[VAL_25]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT:               %[[VAL_27:.*]] = tensor.insert_slice %[[VAL_26]] into %[[VAL_19]]{{\[}}%[[VAL_11]], %[[VAL_11]], %[[VAL_18]], %[[VAL_15]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf16> into tensor<6x6x5x2xf16>
+// CHECK-NEXT:               scf.yield %[[VAL_27]] : tensor<6x6x5x2xf16>
+// CHECK-NEXT:             }
+// CHECK-NEXT:             scf.yield %[[VAL_17]] : tensor<6x6x5x2xf16>
+// CHECK-NEXT:           }
+// CHECK-NEXT:           %[[VAL_28:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:           %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_31:.*]] = %[[VAL_28]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT:             %[[VAL_32:.*]] = scf.for %[[VAL_33:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_34:.*]] = %[[VAL_31]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT:               %[[VAL_35:.*]] = scf.for %[[VAL_36:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_37:.*]] = %[[VAL_34]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT:                 %[[VAL_38:.*]] = scf.for %[[VAL_39:.*]] = %[[VAL_11]] to %[[VAL_9]] step %[[VAL_8]] iter_args(%[[VAL_40:.*]] = %[[VAL_37]]) -> (tensor<6x6x1x1x2x5xf16>) {
+// CHECK-NEXT:                   %[[VAL_41:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_30]])
+// CHECK-NEXT:                   %[[VAL_42:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_33]])
+// CHECK-NEXT:                   %[[VAL_43:.*]] = tensor.extract_slice %[[ARG0]]{{\[}}%[[VAL_36]], %[[VAL_41]], %[[VAL_42]], %[[VAL_39]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf16> to tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_44:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_45:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_44]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_46:.*]] = linalg.matmul ins(%[[VAL_4]], %[[VAL_43]] : tensor<6x6xf32>, tensor<6x6xf16>) outs(%[[VAL_45]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_47:.*]] = tensor.empty() : tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_48:.*]] = linalg.fill ins(%[[VAL_7]] : f16) outs(%[[VAL_47]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_49:.*]] = linalg.matmul ins(%[[VAL_46]], %[[VAL_3]] : tensor<6x6xf16>, tensor<6x6xf32>) outs(%[[VAL_48]] : tensor<6x6xf16>) -> tensor<6x6xf16>
+// CHECK-NEXT:                   %[[VAL_50:.*]] = tensor.insert_slice %[[VAL_49]] into %[[VAL_40]][0, 0, %[[VAL_30]], %[[VAL_33]], %[[VAL_36]], %[[VAL_39]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf16> into tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:                   scf.yield %[[VAL_50]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:                 }
+// CHECK-NEXT:                 scf.yield %[[VAL_38]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:               }
+// CHECK-NEXT:               scf.yield %[[VAL_35]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:             }
+// CHECK-NEXT:             scf.yield %[[VAL_32]] : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:           }
+// CHECK-NEXT:           %[[VAL_51:.*]] = tensor.collapse_shape %[[VAL_14]] {{\[\[}}0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+// CHECK-NEXT:           %[[VAL_52:.*]] = tensor.collapse_shape %[[VAL_29]] {{\[\[}}0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+// CHECK-NEXT:           %[[VAL_53:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:           %[[VAL_54:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_53]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:           %[[VAL_55:.*]] = linalg.batch_matmul ins(%[[VAL_52]], %[[VAL_51]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[VAL_54]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:           %[[VAL_56:.*]] = tensor.expand_shape %[[VAL_55]] {{\[\[}}0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:           %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_59:.*]] = %[[ARG3]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:             %[[VAL_60:.*]] = scf.for %[[VAL_61:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_8]] iter_args(%[[VAL_62:.*]] = %[[VAL_59]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:               %[[VAL_63:.*]] = scf.for %[[VAL_64:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_65:.*]] = %[[VAL_62]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:                 %[[VAL_66:.*]] = scf.for %[[VAL_67:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_8]] iter_args(%[[VAL_68:.*]] = %[[VAL_65]]) -> (tensor<2x4x4x2xf32>) {
+// CHECK-NEXT:                   %[[VAL_69:.*]] = tensor.extract_slice %[[VAL_56]][0, 0, %[[VAL_58]], %[[VAL_61]], %[[VAL_64]], %[[VAL_67]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x2x2xf32> to tensor<6x6xf32>
+// CHECK-NEXT:                   %[[VAL_70:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_58]])
+// CHECK-NEXT:                   %[[VAL_71:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_61]])
+// CHECK-NEXT:                   %[[VAL_72:.*]] = tensor.extract_slice %[[VAL_68]]{{\[}}%[[VAL_64]], %[[VAL_70]], %[[VAL_71]], %[[VAL_67]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<2x4x4x2xf32> to tensor<4x4xf32>
+// CHECK-NEXT:                   %[[VAL_73:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT:                   %[[VAL_74:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_73]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:                   %[[VAL_75:.*]] = linalg.matmul ins(%[[VAL_2]], %[[VAL_69]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[VAL_74]] : tensor<4x6xf32>) -> tensor<4x6xf32>
+// CHECK-NEXT:                   %[[VAL_76:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-NEXT:                   %[[VAL_77:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_76]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:                   %[[VAL_78:.*]] = linalg.matmul ins(%[[VAL_75]], %[[VAL_1]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[VAL_77]] : tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT:                   %[[VAL_79:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_2]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]], %[[VAL_78]] : f32, tensor<4x4xf32>) outs(%[[VAL_72]] : tensor<4x4xf32>) {
+// CHECK-NEXT:                   ^bb0(%[[VAL_80:.*]]: f32, %[[VAL_81:.*]]: f32, %[[VAL_82:.*]]: f32):
+// CHECK-NEXT:                     %[[VAL_83:.*]] = arith.mulf %[[VAL_80]], %[[VAL_81]] : f32
+// CHECK-NEXT:                     %[[VAL_84:.*]] = arith.addf %[[VAL_83]], %[[VAL_82]] : f32
+// CHECK-NEXT:                     linalg.yield %[[VAL_84]] : f32
+// CHECK-NEXT:                   } -> tensor<4x4xf32>
+// CHECK-NEXT:                   %[[VAL_85:.*]] = tensor.insert_slice %[[VAL_79]] into %[[VAL_68]]{{\[}}%[[VAL_64]], %[[VAL_70]], %[[VAL_71]], %[[VAL_67]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x4x4x2xf32>
+// CHECK-NEXT:                   scf.yield %[[VAL_85]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:                 }
+// CHECK-NEXT:                 scf.yield %[[VAL_66]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:               }
+// CHECK-NEXT:               scf.yield %[[VAL_63]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:             }
+// CHECK-NEXT:             scf.yield %[[VAL_60]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:           }
+// CHECK-NEXT:           return %[[VAL_57]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT:         }

>From eec3881e3220661e2c28e0b05d1901b1b85b4ded Mon Sep 17 00:00:00 2001
From: Isaac Nudelman <isaac.nudelman at utexas.edu>
Date: Sun, 14 Sep 2025 12:56:16 -0500
Subject: [PATCH 4/5] Remove debug asserts

---
 mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 2e884f6f79ef6..b875b24c8fda0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -497,7 +497,6 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
   std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   auto inputType = cast<ShapedType>(input.getType());
   Type elementType = inputType.getElementType();
-  // assert(elementType.isF32() && "NYI: support non-f32");
   auto inputShape = inputType.getShape(); // N, H, W, C
   int64_t inputN = inputShape[0];
   int64_t inputC = inputShape[3];
@@ -721,7 +720,6 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
   std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   auto valueType = cast<ShapedType>(value.getType());
   Type elementType = valueType.getElementType();
-  // assert(elementType.isF32() && "NYI: support non-f32");
   auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F
   int64_t valueH = valueShape[0];
   int64_t valueW = valueShape[1];

>From 7025a8e0e1ac02e97970aeef3740d7f4561e359d Mon Sep 17 00:00:00 2001
From: Isaac Nudelman <isaac.nudelman at utexas.edu>
Date: Sun, 14 Sep 2025 13:05:34 -0500
Subject: [PATCH 5/5] Fix formatting

---
 .../Dialect/Linalg/Transforms/WinogradConv2D.cpp  | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index b875b24c8fda0..860f97a29a260 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -188,8 +188,7 @@ constexpr float A_2x2_5x5[] = {
 struct TransformMatrix {
   TransformMatrix(ArrayRef<float> table, int64_t rows, int64_t cols,
                   int64_t scalarFactor = 1)
-      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {
-  }
+      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
 
   ArrayRef<float> table;
   int64_t rows;
@@ -200,8 +199,10 @@ struct TransformMatrix {
 /// Utility function to convert constant array to arith.constant Value.
 Value create2DTransformMatrix(OpBuilder &builder, Location loc,
                               TransformMatrix transform) {
-  assert(transform.table.size() == static_cast<size_t>(transform.rows * transform.cols));
-  ArrayRef<float> constVec(transform.table.data(), transform.rows * transform.cols);
+  assert(transform.table.size() ==
+         static_cast<size_t>(transform.rows * transform.cols));
+  ArrayRef<float> constVec(transform.table.data(),
+                           transform.rows * transform.cols);
   SmallVector<int64_t, 2> shape{transform.rows, transform.cols};
   return arith::ConstantOp::create(
       builder, loc,
@@ -551,8 +552,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
       auto init =
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
 
-      Value BT =
-          create2DTransformMatrix(builder, loc, BTMatrix);
+      Value BT = create2DTransformMatrix(builder, loc, BTMatrix);
       // Multiply BT x d.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{BT, matmulRetValue},
@@ -574,8 +574,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
                        .getResult();
       auto init =
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
-      Value B =
-          create2DTransformMatrix(builder, loc, BMatrix);
+      Value B = create2DTransformMatrix(builder, loc, BMatrix);
       // Multiply v = (BT x d) x B.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{matmulRetValue, B},



More information about the Mlir-commits mailing list