[Mlir-commits] [mlir] [mlir][Vector] Allow elementwise/broadcast swap to handle mixed types (PR #151274)

Krzysztof Drewniak llvmlistbot at llvm.org
Wed Jul 30 09:22:21 PDT 2025


https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/151274

>From 15d432f01a0edc6d157bc91bbdbb3938e07e619f Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <krzysdrewniak at gmail.com>
Date: Tue, 29 Jul 2025 22:20:25 -0700
Subject: [PATCH 1/3] [mlir][Vector] Allow elementwise/broadcast swap to handle
 mixed types

This patch extends the pattern that rewrites elementwise operations
whose inputs are all broadcast from the same shape so that it handles
mixed types, such as when the result and input types don't match, or
when the inputs have different element types.

PR #150867 failed to check for the possibility of type mismatches when
rewriting splat constants. In order to fix that issue, we add support
for mixed-type operations more generally.
---
 .../Vector/Transforms/VectorTransforms.cpp    | 70 ++++++++++-------
 mlir/test/Dialect/Vector/vector-sink.mlir     | 75 +++++++++++++++++--
 2 files changed, 113 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index c51c7b7270fae..5ade4d6c22a39 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -965,6 +965,28 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
   std::function<bool(BitCastOp)> controlFn;
 };
 
+/// Returns true if `t` and `u` are both non-vector types, or both vectors with
+/// the same shape and scalable dims. Element types are not compared.
+static bool haveSameShapeAndScaling(Type t, Type u) {
+  auto tVec = dyn_cast<VectorType>(t);
+  auto uVec = dyn_cast<VectorType>(u);
+  if (!tVec)
+    return !uVec;
+  if (!uVec)
+    return false;
+  return tVec.getShape() == uVec.getShape() &&
+         tVec.getScalableDims() == uVec.getScalableDims();
+}
+
+/// If `type` is a ShapedType, return a clone of it with the same shape and
+/// element type `newElementType`. Otherwise `type` is treated as a scalar and
+/// `newElementType` itself is returned.
+static Type cloneOrReplace(Type type, Type newElementType) {
+  if (auto shapedType = dyn_cast<ShapedType>(type))
+    return shapedType.clone(newElementType);
+  return newElementType;
+}
+
 /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
 ///
 /// Example:
@@ -988,16 +1010,14 @@ struct ReorderElementwiseOpsOnBroadcast final
                                 PatternRewriter &rewriter) const override {
     if (op->getNumResults() != 1)
       return failure();
-    if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
+    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
+    if (!resultType)
       return failure();
     if (!OpTrait::hasElementwiseMappableTraits(op))
       return rewriter.notifyMatchFailure(
           op, "Op doesn't have ElementwiseMappableTraits");
     if (op->getNumOperands() == 0)
       return failure();
-    if (op->getResults()[0].getType() != op->getOperand(0).getType())
-      return rewriter.notifyMatchFailure(op,
-                                         "result and operand type mismatch");
     if (isa<vector::FMAOp>(op)) {
       return rewriter.notifyMatchFailure(
           op,
@@ -1005,6 +1025,7 @@ struct ReorderElementwiseOpsOnBroadcast final
           "might be a scalar");
     }
 
+    Type resultElemType = resultType.getElementType();
     // Get the type of the first non-constant operand
     Operation *firstBroadcastOrSplat = nullptr;
     for (Value operand : op->getOperands()) {
@@ -1020,24 +1041,23 @@ struct ReorderElementwiseOpsOnBroadcast final
     }
     if (!firstBroadcastOrSplat)
       return failure();
-    Type firstBroadcastOrSplatType =
-        firstBroadcastOrSplat->getOperand(0).getType();
+    Type unbroadcastResultType = cloneOrReplace(
+        firstBroadcastOrSplat->getOperand(0).getType(), resultElemType);
 
-    // Make sure that all operands are broadcast from identical types:
+    // Make sure that all operands are broadcast from identically-shaped types:
     //  * scalar (`vector.broadcast` + `vector.splat`), or
     //  * vector (`vector.broadcast`).
     // Otherwise the re-ordering wouldn't be safe.
-    if (!llvm::all_of(
-            op->getOperands(), [&firstBroadcastOrSplatType](Value val) {
-              if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
-                return (bcastOp.getOperand().getType() ==
-                        firstBroadcastOrSplatType);
-              if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
-                return (splatOp.getOperand().getType() ==
-                        firstBroadcastOrSplatType);
-              SplatElementsAttr splatConst;
-              return matchPattern(val, m_Constant(&splatConst));
-            })) {
+    if (!llvm::all_of(op->getOperands(), [&unbroadcastResultType](Value val) {
+          if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
+            return haveSameShapeAndScaling(bcastOp.getOperand().getType(),
+                                           unbroadcastResultType);
+          if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
+            return haveSameShapeAndScaling(splatOp.getOperand().getType(),
+                                           unbroadcastResultType);
+          SplatElementsAttr splatConst;
+          return matchPattern(val, m_Constant(&splatConst));
+        })) {
       return failure();
     }
 
@@ -1048,15 +1068,16 @@ struct ReorderElementwiseOpsOnBroadcast final
       SplatElementsAttr splatConst;
       if (matchPattern(operand, m_Constant(&splatConst))) {
         Attribute newConst;
-        if (auto shapedTy = dyn_cast<ShapedType>(firstBroadcastOrSplatType)) {
-          newConst = splatConst.resizeSplat(shapedTy);
+        Type elementType = getElementTypeOrSelf(operand.getType());
+        Type newType = cloneOrReplace(unbroadcastResultType, elementType);
+        if (auto shapedTy = dyn_cast<ShapedType>(unbroadcastResultType)) {
+          newConst = splatConst.resizeSplat(cast<ShapedType>(newType));
         } else {
           newConst = splatConst.getSplatValue<Attribute>();
         }
         Operation *newConstOp =
             operand.getDefiningOp()->getDialect()->materializeConstant(
-                rewriter, newConst, firstBroadcastOrSplatType,
-                operand.getLoc());
+                rewriter, newConst, newType, operand.getLoc());
         srcValues.push_back(newConstOp->getResult(0));
       } else {
         srcValues.push_back(operand.getDefiningOp()->getOperand(0));
@@ -1066,12 +1087,11 @@ struct ReorderElementwiseOpsOnBroadcast final
     // Create the "elementwise" Op
     Operation *elementwiseOp =
         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
-                        firstBroadcastOrSplatType, op->getAttrs());
+                        unbroadcastResultType, op->getAttrs());
 
     // Replace the original Op with the elementwise Op
-    auto vectorType = op->getResultTypes()[0];
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-        op, vectorType, elementwiseOp->getResults());
+        op, resultType, elementwiseOp->getResults());
 
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index f8638ab843ecb..d161197e4bfe4 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -180,13 +180,14 @@ func.func @negative_not_elementwise() -> vector<2x2xf32> {
 
 // -----
 
-// The source and the result for arith.cmp have different types - not supported
-
-// CHECK-LABEL: func.func @negative_source_and_result_mismatch
-//       CHECK:   %[[BROADCAST:.+]] = vector.broadcast
-//       CHECK:   %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]]
-//       CHECK:   return %[[RETURN]]
-func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
+// The source and result of arith.cmpf have different element types (f32 vs i1)
+
+// CHECK-LABEL: func.func @source_and_result_mismatch(
+//  CHECK-SAME: %[[ARG0:.+]]: f32)
+//       CHECK:   %[[COMPARE:.+]] = arith.cmpf uno, %[[ARG0]], %[[ARG0]]
+//       CHECK:   %[[BROADCAST:.+]] = vector.broadcast %[[COMPARE]] : i1 to vector<1xi1>
+//       CHECK:   return %[[BROADCAST]]
+func.func @source_and_result_mismatch(%arg0 : f32) -> vector<1xi1> {
   %0 = vector.broadcast %arg0 : f32 to vector<1xf32>
   %1 = arith.cmpf uno, %0, %0 : vector<1xf32>
   return %1 : vector<1xi1>
@@ -321,6 +322,66 @@ func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xi
   return %2 : vector<1x4xindex>
 }
 
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_scalar_mixed_type(
+// CHECK-SAME:     %[[SRC:.*]]: f16) -> vector<1x4xf32> {
+// CHECK:           %[[EXT:.*]] = arith.extf %[[SRC]] : f16 to f32
+// CHECK:           %[[RES:.*]] = vector.broadcast %[[EXT]] : f32 to vector<1x4xf32>
+// CHECK:           return %[[RES]] : vector<1x4xf32>
+// Scalar broadcast feeding a type-changing op (f16 -> f32 extension).
+func.func @broadcast_scalar_mixed_type(%arg0: f16) -> vector<1x4xf32> {
+  %bcast = vector.broadcast %arg0 : f16 to vector<1x4xf16>
+  %ext = arith.extf %bcast : vector<1x4xf16> to vector<1x4xf32>
+  return %ext : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_vector_mixed_type(
+// CHECK-SAME:     %[[SRC:.*]]: vector<4xf16>) -> vector<3x4xf32> {
+// CHECK:           %[[EXT:.*]] = arith.extf %[[SRC]] : vector<4xf16> to vector<4xf32>
+// CHECK:           %[[RES:.*]] = vector.broadcast %[[EXT]] : vector<4xf32> to vector<3x4xf32>
+// CHECK:           return %[[RES]] : vector<3x4xf32>
+// Vector broadcast feeding a type-changing op (f16 -> f32 extension).
+func.func @broadcast_vector_mixed_type(%arg0: vector<4xf16>) -> vector<3x4xf32> {
+  %bcast = vector.broadcast %arg0 : vector<4xf16> to vector<3x4xf16>
+  %ext = arith.extf %bcast : vector<3x4xf16> to vector<3x4xf32>
+  return %ext : vector<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_scalar_and_splat_const_mixed_type(
+// CHECK-SAME:     %[[SRC:.*]]: f32) -> vector<1x4xf32> {
+// CHECK:           %[[EXP:.*]] = arith.constant 3 : i32
+// CHECK:           %[[POW:.*]] = math.fpowi %[[SRC]], %[[EXP]] : f32, i32
+// CHECK:           %[[RES:.*]] = vector.broadcast %[[POW]] : f32 to vector<1x4xf32>
+// CHECK:           return %[[RES]] : vector<1x4xf32>
+// The i32 splat exponent must be rebuilt with its own element type (i32).
+func.func @broadcast_scalar_and_splat_const_mixed_type(%arg0: f32) -> vector<1x4xf32> {
+  %bcast = vector.broadcast %arg0 : f32 to vector<1x4xf32>
+  %exp = arith.constant dense<3> : vector<1x4xi32>
+  %pow = math.fpowi %bcast, %exp : vector<1x4xf32>, vector<1x4xi32>
+  return %pow : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_vector_and_splat_const_mixed_type(
+// CHECK-SAME:     %[[SRC:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK:           %[[EXP:.*]] = arith.constant dense<3> : vector<4xi32>
+// CHECK:           %[[POW:.*]] = math.fpowi %[[SRC]], %[[EXP]] : vector<4xf32>, vector<4xi32>
+// CHECK:           %[[RES:.*]] = vector.broadcast %[[POW]] : vector<4xf32> to vector<3x4xf32>
+// CHECK:           return %[[RES]] : vector<3x4xf32>
+// The i32 splat exponent is resized to vector<4xi32>, keeping its element type.
+func.func @broadcast_vector_and_splat_const_mixed_type(%arg0: vector<4xf32>) -> vector<3x4xf32> {
+  %bcast = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
+  %exp = arith.constant dense<3> : vector<3x4xi32>
+  %pow = math.fpowi %bcast, %exp : vector<3x4xf32>, vector<3x4xi32>
+  return %pow : vector<3x4xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // [Pattern: ReorderElementwiseOpsOnTranspose]
 //===----------------------------------------------------------------------===//

>From d1e49839a3d3ac657c1fe340ce790736cba2d031 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <krzysdrewniak at gmail.com>
Date: Wed, 30 Jul 2025 09:07:40 -0700
Subject: [PATCH 2/3] Review comments

---
 mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 5ade4d6c22a39..7500bf7d1d9eb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1070,8 +1070,8 @@ struct ReorderElementwiseOpsOnBroadcast final
         Attribute newConst;
         Type elementType = getElementTypeOrSelf(operand.getType());
         Type newType = cloneOrReplace(unbroadcastResultType, elementType);
-        if (auto shapedTy = dyn_cast<ShapedType>(unbroadcastResultType)) {
-          newConst = splatConst.resizeSplat(cast<ShapedType>(newType));
+        if (auto newTypeShaped = dyn_cast<ShapedType>(newType)) {
+          newConst = splatConst.resizeSplat(newTypeShaped);
         } else {
           newConst = splatConst.getSplatValue<Attribute>();
         }

>From 99c5d2316eddf7ae083274e6e3a5e8705d212ac4 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <krzysdrewniak at gmail.com>
Date: Wed, 30 Jul 2025 09:21:40 -0700
Subject: [PATCH 3/3] Reorder tests

---
 mlir/test/Dialect/Vector/vector-sink.mlir | 94 +++++++++++------------
 1 file changed, 47 insertions(+), 47 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index d161197e4bfe4..ef881ba05a416 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -211,53 +211,6 @@ func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> {
   return %1 : vector<1xf32>
 }
 
-//===----------------------------------------------------------------------===//
-// [Pattern: ReorderCastOpsOnBroadcast]
-//
-// Reorder casting ops and vector ops. The casting ops have almost identical
-// pattern, so only arith.extsi op is tested.
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
-  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
-  // CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32>
-  %b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
-  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
-  return %r : vector<2x4xi32>
-}
-
-// -----
-
-func.func @broadcast_vector_extsi_scalable(%a : vector<[4]xi8>) -> vector<2x[4]xi32> {
-  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<[4]xi8> to vector<[4]xi32>
-  // CHECK: vector.broadcast %[[EXT:.+]] : vector<[4]xi32> to vector<2x[4]xi32>
-  %b = vector.broadcast %a : vector<[4]xi8> to vector<2x[4]xi8>
-  %r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32>
-  return %r : vector<2x[4]xi32>
-}
-
-// -----
-
-func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
-  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
-  // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
-  %b = vector.broadcast %a : i8 to vector<2x4xi8>
-  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
-  return %r : vector<2x4xi32>
-}
-
-// -----
-
-func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> {
-  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
-  // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x[4]xi32>
-  %b = vector.broadcast %a : i8 to vector<2x[4]xi8>
-  %r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32>
-  return %r : vector<2x[4]xi32>
-}
-
 // -----
 
 // CHECK-LABEL:   func.func @broadcast_scalar_and_splat_const(
@@ -382,6 +335,53 @@ func.func @broadcast_vector_and_splat_const_mixed_type(%arg0: vector<4xf32>) ->
   return %2 : vector<3x4xf32>
 }
 
+//===----------------------------------------------------------------------===//
+// [Pattern: ReorderCastOpsOnBroadcast]
+//
+// Reorder casting ops and vector ops. The casting ops have almost identical
+// pattern, so only arith.extsi op is tested.
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
+  // CHECK: vector.broadcast %[[EXT]] : vector<4xi32> to vector<2x4xi32>
+  %b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
+  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+  return %r : vector<2x4xi32>
+}
+
+// -----
+
+func.func @broadcast_vector_extsi_scalable(%a : vector<[4]xi8>) -> vector<2x[4]xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<[4]xi8> to vector<[4]xi32>
+  // CHECK: vector.broadcast %[[EXT]] : vector<[4]xi32> to vector<2x[4]xi32>
+  %b = vector.broadcast %a : vector<[4]xi8> to vector<2x[4]xi8>
+  %r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32>
+  return %r : vector<2x[4]xi32>
+}
+
+// -----
+
+func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
+  // CHECK: %[[WIDE:.+]] = arith.extsi %{{.+}} : i8 to i32
+  // CHECK: vector.broadcast %[[WIDE]] : i32 to vector<2x4xi32>
+  %bcast = vector.broadcast %a : i8 to vector<2x4xi8>
+  %ext = arith.extsi %bcast : vector<2x4xi8> to vector<2x4xi32>
+  return %ext : vector<2x4xi32>
+}
+
+// -----
+
+func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> {
+  // CHECK: %[[WIDE:.+]] = arith.extsi %{{.+}} : i8 to i32
+  // CHECK: vector.broadcast %[[WIDE]] : i32 to vector<2x[4]xi32>
+  %bcast = vector.broadcast %a : i8 to vector<2x[4]xi8>
+  %ext = arith.extsi %bcast : vector<2x[4]xi8> to vector<2x[4]xi32>
+  return %ext : vector<2x[4]xi32>
+}
+
 //===----------------------------------------------------------------------===//
 // [Pattern: ReorderElementwiseOpsOnTranspose]
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list