[Mlir-commits] [mlir] [mlir][spirv] Fix bug for `vector.broadcast` op in `convert-vector-to-spirv` pass (PR #99928)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 22 13:19:53 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Angel Zhang (angelz913)

<details>
<summary>Changes</summary>

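This PR addresses [iree-org/iree#17976](https://github.com/iree-org/iree/issues/17976). The `vector.broadcast` lowering now passes the converted `resultType` to `spirv::CompositeConstructOp`, instead of the original result type returned by `castOp.getResultVectorType()`. For an `index` vector such as `vector<4xindex>`, the original type is not a SPIR-V type. The converted type is `vector<4xi32>`, which matches the converted `i32` operands supplied by the adaptor. A new LIT test, `@broadcast_index`, covers this case.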

---
Full diff: https://github.com/llvm/llvm-project/pull/99928.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+2-2) 
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+13) 


``````````diff
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a4390447532a5..527fbe5cf628a 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -144,8 +144,8 @@ struct VectorBroadcastConvert final
 
     SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
                                  adaptor.getSource());
-    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
-        castOp, castOp.getResultVectorType(), source);
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
+                                                             source);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 667aad7645c51..edad208749930 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -150,6 +150,19 @@ func.func @broadcast(%arg0 : f32) -> (vector<4xf32>, vector<2xf32>) {
 
 // -----
 
+// CHECK-LABEL: @broadcast_index
+//  CHECK-SAME: %[[ARG0:.*]]: index
+//       CHECK:   %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : index to i32
+//       CHECK:   %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CAST0]], %[[CAST0]], %[[CAST0]], %[[CAST0]] : (i32, i32, i32, i32) -> vector<4xi32>
+//       CHECK:   %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex>
+//       CHECK:   return %[[CAST1]] : vector<4xindex>
+func.func @broadcast_index(%a: index) -> vector<4xindex> {
+  %0 = vector.broadcast %a : index to vector<4xindex>
+  return %0 : vector<4xindex>
+}
+
+// -----
+
 // CHECK-LABEL: @extract
 //  CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
 //       CHECK:   spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/99928
