[Mlir-commits] [mlir] [mlir][Linalg] Promote lhs/rhs when vectorizing conv1D as outerproduct (PR #179883)
Abhishek Varma
llvmlistbot at llvm.org
Thu Feb 5 00:30:43 PST 2026
https://github.com/Abhishek-Varma created https://github.com/llvm/llvm-project/pull/179883
-- vector.outerproduct requires lhs/rhs to have the same element type as the
result.
-- This commit adds a fix to promote lhs/rhs to the result's element
type when vectorizing a conv1D slice to vector.outerproduct.
-- This is along similar lines to what happens when we are
vectorizing a conv1D slice to vector.contract - the corresponding
CHECK line was incorrect and this commit fixes that too.
Signed-off-by: Abhishek Varma <abhvarma at amd.com>
>From a9a794fc8103d5546d9abf2e446293b9c447d4c5 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 5 Feb 2026 08:19:21 +0000
Subject: [PATCH] [mlir][Linalg] Promote lhs/rhs when vectorizing conv1D as
outerproduct
-- vector.outerproduct requires lhs/rhs to have the same element type as the
result.
-- This commit adds a fix to promote lhs/rhs to the result's element
type when vectorizing a conv1D slice to vector.outerproduct.
-- This is along similar lines to what happens when we are
vectorizing a conv1D slice to vector.contract - the corresponding
CHECK line was incorrect and this commit fixes that too.
Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
.../Linalg/Transforms/Vectorization.cpp | 10 +++-
.../convolution-with-patterns.mlir | 59 ++++++++++++++++++-
2 files changed, 65 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 24579f6aa0217..7ac759f635f87 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -3848,8 +3848,12 @@ struct Conv1DGenerator
const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
- const Type dstType =
- cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
+ // Handle both shaped as well as scalar types.
+ Type dstType;
+ if (auto shapedType = dyn_cast<ShapedType>(val.getType()))
+ dstType = shapedType.cloneWith(std::nullopt, dstElementType);
+ else
+ dstType = dstElementType;
if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
return arith::SIToFPOp::create(rewriter, loc, dstType, val);
@@ -3888,6 +3892,8 @@ struct Conv1DGenerator
// convolution.
Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
Value lhs, Value rhs, Value res) {
+ lhs = promote(rewriter, loc, lhs, res.getType());
+ rhs = promote(rewriter, loc, rhs, res.getType());
return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
rhs, res, vector::CombiningKind::ADD);
}
diff --git a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
index f8781ff5452d9..97b27befd44e2 100644
--- a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
@@ -678,6 +678,59 @@ module attributes {transform.with_named_sequence} {
// -----
+// Test for mixed precision handling of 1D non-channeled convolution.
+func.func @conv1d_mixed_precision_bf16_f32(%input: tensor<5xbf16>, %filter: tensor<2xbf16>, %output: tensor<4xf32>) -> tensor<4xf32> {
+ %0 = linalg.conv_1d ins(%input, %filter : tensor<5xbf16>, tensor<2xbf16>)
+ outs(%output : tensor<4xf32>) -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+
+// CHECK: func @conv1d_mixed_precision_bf16_f32
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<5xbf16>, %[[FILTER:.+]]: tensor<2xbf16>, %[[OUTPUT:.+]]: tensor<4xf32>)
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[BF0:.+]] = arith.constant 0.000000e+00 : bf16
+
+/// Read the whole data in one shot.
+// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]]], %[[BF0]]
+// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]]], %[[BF0]]
+// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]]], %[[F0]]
+
+// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0], sizes = [4], strides = [1]} : vector<5xbf16> to vector<4xbf16>
+// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<5xbf16> to vector<4xbf16>
+
+// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : bf16 from vector<2xbf16>
+// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : bf16 from vector<2xbf16>
+
+/// Extend input and filter to f32 and then perform outerproduct.
+/// kw == 0
+// CHECK: %[[V_INPUT_0_F32:.+]] = arith.extf %[[V_INPUT_0]] : vector<4xbf16> to vector<4xf32>
+// CHECK: %[[V_FILTER_0_F32:.+]] = arith.extf %[[V_FILTER_0]] : bf16 to f32
+// CHECK: %[[RES_0:.+]] = vector.outerproduct %[[V_INPUT_0_F32]], %[[V_FILTER_0_F32]], %[[V_OUTPUT_R]] {kind = #vector.kind<add>}
+// CHECK-SAME: : vector<4xf32>, f32
+/// kw == 1
+// CHECK: %[[V_INPUT_1_F32:.+]] = arith.extf %[[V_INPUT_1]] : vector<4xbf16> to vector<4xf32>
+// CHECK: %[[V_FILTER_1_F32:.+]] = arith.extf %[[V_FILTER_1]] : bf16 to f32
+// CHECK: %[[RES_1:.+]] = vector.outerproduct %[[V_INPUT_1_F32]], %[[V_FILTER_1_F32]], %[[RES_0]] {kind = #vector.kind<add>}
+// CHECK-SAME: : vector<4xf32>, f32
+
+// Write the result back in one shot.
+// CHECK: vector.transfer_write %[[RES_1]], %[[OUTPUT]][%[[C0]]]
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_1d", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
linalg.depthwise_conv_1d_nwc_wc
{dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
@@ -801,8 +854,10 @@ func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter:
// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<3x2xf16> from vector<1x3x2xf16>
-// CHECK: %[[CONT:.*]] = vector.contract
-// {{.*}} %[[V_INPUT_R]], %[[V_FILTER_1]], %[[V_OUTPUT_R]] : vector<1x2x3xf16>, vector<3x2xf16> into vector<1x2x2xf32>
+// CHECK: %[[V_INPUT_F32:.+]] = arith.extf %[[V_INPUT_R]] : vector<1x2x3xf16> to vector<1x2x3xf32>
+// CHECK: %[[V_FILTER_F32:.+]] = arith.extf %[[V_FILTER_1]] : vector<3x2xf16> to vector<3x2xf32>
+// CHECK: %[[CONT:.+]] = vector.contract
+// CHECK-SAME: %[[V_INPUT_F32]], %[[V_FILTER_F32]], %[[V_OUTPUT_R]] : vector<1x2x3xf32>, vector<3x2xf32> into vector<1x2x2xf32>
// CHECK: vector.transfer_write %[[CONT]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
module attributes {transform.with_named_sequence} {
More information about the Mlir-commits
mailing list