[Mlir-commits] [mlir] 6ed921f - Reland "[mlir][vector] Use vector.broadcast in place of vector.splat" (#150138)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 23 06:01:03 PDT 2025


Author: James Newling
Date: 2025-07-23T06:00:59-07:00
New Revision: 6ed921f9675b7f1bb840f115d81cede4d182cdc2

URL: https://github.com/llvm/llvm-project/commit/6ed921f9675b7f1bb840f115d81cede4d182cdc2
DIFF: https://github.com/llvm/llvm-project/commit/6ed921f9675b7f1bb840f115d81cede4d182cdc2.diff

LOG: Reland "[mlir][vector] Use vector.broadcast in place of vector.splat" (#150138)

This reverts commit 228c45f13dc92546661b6825b7b32c3808b0d2eb (PR
#148937). Now that #148027 has landed, I think it is safe to "reland"
the original PR: #148028.

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
    mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
    mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index f14264e2f55f3..55b757c136127 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -123,7 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
                                vector::OuterProductOp, vector::ScanOp>(
       [&](Operation *op) { return converter.isLegal(op); });
   target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
-                    arith::ConstantOp, vector::SplatOp>();
+                    arith::ConstantOp, vector::SplatOp, vector::BroadcastOp>();
 }
 
 void EmulateUnsupportedFloatsPass::runOnOperation() {

diff  --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 5d253c1199dc0..f5f0bfa4128aa 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -688,7 +688,7 @@ Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
 
   Type elementType = getElementTypeOrSelf(memref.getType());
   auto vt = VectorType::get(vectorShape, elementType);
-  Value res = vector::SplatOp::create(b, loc, vt, loads[0]);
+  Value res = vector::BroadcastOp::create(b, loc, vt, loads[0]);
   foreachIndividualVectorElement(
       res,
       /*applyFn=*/

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index de67098d397f4..0d44415595cb8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -438,7 +438,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
           Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                      i * blockedChunkSize);
           Value incVec =
-              vector::SplatOp::create(rewriter, loc, indiceType, inc);
+              vector::BroadcastOp::create(rewriter, loc, indiceType, inc);
           Value offsetIndice =
               arith::AddIOp::create(rewriter, loc, indice, incVec);
 

diff  --git a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
index 07e03f3b8473d..bbe27fe1b99d9 100644
--- a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
+++ b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
@@ -20,14 +20,14 @@ func.func @matmul_16x8x4xf32_global(
 // CHECK:           %[[VAL_7:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_8:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_0]][%[[VAL_7]], %[[VAL_8]]] : memref<16x4xf32>
-// CHECK:           %[[VAL_10:.*]] = vector.splat %[[VAL_6]] : vector<2x1xf32>
+// CHECK:           %[[VAL_10:.*]] = vector.broadcast %[[VAL_6]] : f32 to vector<2x1xf32>
 // CHECK:           %[[VAL_11:.*]] = vector.insert %[[VAL_6]], %[[VAL_10]] [0, 0] : f32 into vector<2x1xf32>
 // CHECK:           %[[LHS:.*]] = vector.insert %[[VAL_9]], %[[VAL_11]] [1, 0] : f32 into vector<2x1xf32>
 //
 // CHECK:           %[[VAL_13:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_14:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_1]][%[[VAL_13]], %[[VAL_14]]] : memref<4x8xf32>
-// CHECK:           %[[VAL_16:.*]] = vector.splat %[[VAL_15]] : vector<1x1xf32>
+// CHECK:           %[[VAL_16:.*]] = vector.broadcast %[[VAL_15]] : f32 to vector<1x1xf32>
 // CHECK:           %[[RHS:.*]] = vector.insert %[[VAL_15]], %[[VAL_16]] [0, 0] : f32 into vector<1x1xf32>
 //
 // CHECK:           %[[VAL_18:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
@@ -42,7 +42,7 @@ func.func @matmul_16x8x4xf32_global(
 // CHECK:           %[[VAL_27:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_28:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_29:.*]] = memref.load %[[VAL_2]][%[[VAL_27]], %[[VAL_28]]] : memref<16x8xf32>
-// CHECK:           %[[VAL_30:.*]] = vector.splat %[[VAL_20]] : vector<2x2xf32>
+// CHECK:           %[[VAL_30:.*]] = vector.broadcast %[[VAL_20]] : f32 to vector<2x2xf32>
 // CHECK:           %[[VAL_31:.*]] = vector.insert %[[VAL_20]], %[[VAL_30]] [0, 0] : f32 into vector<2x2xf32>
 // CHECK:           %[[VAL_32:.*]] = vector.insert %[[VAL_23]], %[[VAL_31]] [0, 1] : f32 into vector<2x2xf32>
 // CHECK:           %[[VAL_33:.*]] = vector.insert %[[VAL_26]], %[[VAL_32]] [1, 0] : f32 into vector<2x2xf32>


        


More information about the Mlir-commits mailing list