[Mlir-commits] [mlir] [MLIR][Linalg] Fix winograd op lowering for types smaller than f32 (PR #158500)
Hsiangkai Wang
llvmlistbot at llvm.org
Thu Oct 2 01:27:23 PDT 2025
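Hsiangkai wrote:
Something like the following, which builds the Winograd transform-matrix constants directly in the operation's element type (converting each float entry with `getFloatAttr`) instead of always creating them as f32: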
```
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -201,12 +201,16 @@ Value create2DTransformMatrix(OpBuilder &builder, Location loc,
TransformMatrix transform, Type type) {
ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
+ SmallVector<Attribute> constAttrVec;
+ for (float v : constVec)
+ constAttrVec.push_back(builder.getFloatAttr(type, v));
+
return arith::ConstantOp::create(
builder, loc,
DenseFPElementsAttr::get(
RankedTensorType::get(
SmallVector<int64_t>{transform.rows, transform.cols}, type),
- constVec));
+ constAttrVec));
}
/// Extract height x width data from 4D tensors.
@@ -552,7 +556,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
Value BT =
- create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
+ create2DTransformMatrix(builder, loc, BTMatrix, elementType);
// Multiply BT x d.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{BT, matmulRetValue},
@@ -575,7 +579,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
auto init =
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
Value B =
- create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
+ create2DTransformMatrix(builder, loc, BMatrix, elementType);
// Multiply v = (BT x d) x B.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{matmulRetValue, B},
```
https://github.com/llvm/llvm-project/pull/158500
More information about the Mlir-commits mailing list