[Mlir-commits] [mlir] [mlir][Linalg] Promote lhs/rhs when vectorizing conv1D as outerproduct (PR #179883)
Abhishek Varma
llvmlistbot at llvm.org
Thu Feb 5 22:14:14 PST 2026
================
@@ -3848,8 +3848,12 @@ struct Conv1DGenerator
const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
- const Type dstType =
- cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
+ // Handle both shaped as well as scalar types.
+ Type dstType;
+ if (auto shapedType = dyn_cast<ShapedType>(val.getType()))
+ dstType = shapedType.cloneWith(std::nullopt, dstElementType);
+ else
+ dstType = dstElementType;
----------------
Abhishek-Varma wrote:
For the filter we use [extractConvFilterSlices](https://github.com/llvm/llvm-project/blob/22c5c2583dc9fcd9df7c648c810b4a3233cfc42e/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp#L147), which extracts a scalar value for a non-channeled conv op.
We didn't need this earlier because none of the non-channeled lit tests used mixed precision. For example, [@conv1d_8_tensor](https://github.com/llvm/llvm-project/blob/43358cb3d67cbb94566f094ca9c8ec1c4d76975a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir#L619), which is already checked in, demonstrates the scalar case: https://github.com/llvm/llvm-project/blob/43358cb3d67cbb94566f094ca9c8ec1c4d76975a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir#L645
However, because the element types of its input/filter and output match, the scalar operand never needed promotion, so the need never arose.
This PR therefore adds the mixed-precision case, together with the proposed fix and [a lit test](https://github.com/Abhishek-Varma/llvm-project/blob/a9a794fc8103d5546d9abf2e446293b9c447d4c5/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir#L705-L711) demonstrating it.
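To illustrate the shaped-vs-scalar dispatch in the diff without pulling in MLIR, here is a minimal standalone C++ model. `ScalarType`, `ShapedType`, `Type`, `cloneWithElementType`, and `promotedType` are all hypothetical stand-ins, not MLIR's API; in the actual patch the same role is played by `dyn_cast<ShapedType>` and `ShapedType::cloneWith(std::nullopt, dstElementType)`.

```cpp
#include <cassert>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-ins for MLIR types: these are NOT MLIR's API, only a
// model of the shaped-vs-scalar dispatch added in the diff above.
struct ScalarType {
  std::string name;  // element type name, e.g. "i8" or "i32"
};

struct ShapedType {
  std::vector<long> shape;
  ScalarType elementType;

  // Models ShapedType::cloneWith(std::nullopt, newElt): keep the shape,
  // replace only the element type.
  ShapedType cloneWithElementType(const ScalarType &newElt) const {
    return ShapedType{shape, newElt};
  }
};

// A value's type is either a scalar (e.g. a filter element extracted for a
// non-channeled conv) or a shaped type (e.g. a vector slice).
using Type = std::variant<ScalarType, ShapedType>;

// Destination type when promoting a value to dstElementType.
Type promotedType(const Type &valType, const ScalarType &dstElementType) {
  // Shaped operand: keep the shape and swap in the wider element type.
  if (const auto *shaped = std::get_if<ShapedType>(&valType))
    return shaped->cloneWithElementType(dstElementType);
  // Scalar operand: the destination type is just the element type.
  // Unconditionally treating the operand as shaped (like the old
  // cast<ShapedType>) would fail here.
  return dstElementType;
}
```

Using a checked dispatch (`dyn_cast` in the patch, `std::get_if` in this model) rather than an unconditional cast is what lets one helper promote both shaped and scalar operands in the mixed-precision case.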
https://github.com/llvm/llvm-project/pull/179883