[Mlir-commits] [mlir] [mlir][spirv] Fix bug for `vector.broadcast` op in `convert-vector-to-spirv` pass (PR #99928)

Angel Zhang llvmlistbot at llvm.org
Mon Jul 22 13:19:25 PDT 2024


https://github.com/angelz913 created https://github.com/llvm/llvm-project/pull/99928

This PR addresses [#17976](https://github.com/iree-org/iree/issues/17976) by using the converted `resultType`, instead of the original result type obtained from `castOp.getResultVectorType()`, as the result type of the generated `spirv.CompositeConstruct` op. A new LIT test is also included.

>From c9cf1072b00bcb61e36d5a3ffb65bdae2a0041a0 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 22 Jul 2024 20:12:06 +0000
Subject: [PATCH] [mlir][spirv] Fix bug for vector.broadcast op in
 convert-vector-to-spirv pass

---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp |  4 ++--
 .../Conversion/VectorToSPIRV/vector-to-spirv.mlir   | 13 +++++++++++++
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a4390447532a5..527fbe5cf628a 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -144,8 +144,8 @@ struct VectorBroadcastConvert final
 
     SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
                                  adaptor.getSource());
-    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
-        castOp, castOp.getResultVectorType(), source);
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
+                                                             source);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 667aad7645c51..edad208749930 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -150,6 +150,19 @@ func.func @broadcast(%arg0 : f32) -> (vector<4xf32>, vector<2xf32>) {
 
 // -----
 
+// CHECK-LABEL: @broadcast_index
+//  CHECK-SAME: %[[ARG0:.*]]: index
+//       CHECK:   %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : index to i32
+//       CHECK:   %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CAST0]], %[[CAST0]], %[[CAST0]], %[[CAST0]] : (i32, i32, i32, i32) -> vector<4xi32>
+//       CHECK:   %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex>
+//       CHECK:   return %[[CAST1]] : vector<4xindex>
+func.func @broadcast_index(%a: index) -> vector<4xindex> {
+  %0 = vector.broadcast %a : index to vector<4xindex>
+  return %0 : vector<4xindex>
+}
+
+// -----
+
 // CHECK-LABEL: @extract
 //  CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
 //       CHECK:   spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>



More information about the Mlir-commits mailing list