[Mlir-commits] [mlir] [MLIR][Linalg] Fix winograd op lowering for types smaller than f32 (PR #158500)
    Hsiangkai Wang 
    llvmlistbot at llvm.org
       
    Thu Oct  2 01:27:23 PDT 2025
    
    
  
Hsiangkai wrote:
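Something like the following patch. It builds the transform-matrix constants as per-element float attributes of the requested type, and passes `elementType` instead of a hard-coded `builder.getF32Type()` at the call sites: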
```
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -201,12 +201,16 @@ Value create2DTransformMatrix(OpBuilder &builder, Location loc,
                               TransformMatrix transform, Type type) {
   ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
+  SmallVector<Attribute> constAttrVec;
+  for (float v : constVec)
+    constAttrVec.push_back(builder.getFloatAttr(type, v));
+
   return arith::ConstantOp::create(
       builder, loc,
       DenseFPElementsAttr::get(
           RankedTensorType::get(
               SmallVector<int64_t>{transform.rows, transform.cols}, type),
-          constVec));
+          constAttrVec));
 }
 /// Extract height x width data from 4D tensors.
@@ -552,7 +556,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
       Value BT =
-          create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
+          create2DTransformMatrix(builder, loc, BTMatrix, elementType);
       // Multiply BT x d.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{BT, matmulRetValue},
@@ -575,7 +579,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
       auto init =
           linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
       Value B =
-          create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
+          create2DTransformMatrix(builder, loc, BMatrix, elementType);
       // Multiply v = (BT x d) x B.
       auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
                                                ValueRange{matmulRetValue, B},
```
https://github.com/llvm/llvm-project/pull/158500
    
    
More information about the Mlir-commits
mailing list