[Mlir-commits] [mlir] db1d881 - [mlir][spirv] Fix bug for `vector.broadcast` op in `convert-vector-to-spirv` pass (#99928)

llvmlistbot at llvm.org
Mon Jul 22 15:57:14 PDT 2024


Author: Angel Zhang
Date: 2024-07-22T18:57:11-04:00
New Revision: db1d88137212fec6c884dcb0f76a8dfab4fcab98

URL: https://github.com/llvm/llvm-project/commit/db1d88137212fec6c884dcb0f76a8dfab4fcab98
DIFF: https://github.com/llvm/llvm-project/commit/db1d88137212fec6c884dcb0f76a8dfab4fcab98.diff

LOG: [mlir][spirv] Fix bug for `vector.broadcast` op in `convert-vector-to-spirv` pass (#99928)

This PR addresses
[#17976](https://github.com/iree-org/iree/issues/17976) by using the
converted `resultType` instead of the original result type obtained from
`castOp.getResultVectorType`. A new LIT test is also included.
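
To illustrate why the result type matters, here is a minimal toy model of the mismatch. It does not use the MLIR API: `ToyType`, `convertType`, `isValidConstruct`, and `broadcastIsValid` are hypothetical names invented for this sketch, and the only conversion assumed is `index` to `i32`, as seen in the new LIT test. The operands passed to the construct op are already converted, so a result type that still has `index` elements no longer matches them.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for a type (NOT an MLIR class).
struct ToyType {
  std::string elem; // element type name, e.g. "index" or "i32"
  int numElements;  // 0 for a scalar, N for a vector of N elements
};

// Hypothetical type converter: only maps index -> i32, as in the LIT test.
ToyType convertType(ToyType type) {
  if (type.elem == "index")
    type.elem = "i32";
  return type;
}

// Simplified validity rule for this sketch: one scalar operand per result
// element, each with the result's element type.
bool isValidConstruct(const ToyType &result,
                      const std::vector<ToyType> &operands) {
  if (static_cast<int>(operands.size()) != result.numElements)
    return false;
  for (const ToyType &operand : operands)
    if (operand.numElements != 0 || operand.elem != result.elem)
      return false;
  return true;
}

// Models the broadcast lowering: the source operand is always converted
// (like adaptor.getSource()); the flag selects which result type is used.
bool broadcastIsValid(const ToyType &originalResult, const ToyType &source,
                      bool useConvertedResultType) {
  std::vector<ToyType> operands(originalResult.numElements,
                                convertType(source));
  ToyType resultType = useConvertedResultType ? convertType(originalResult)
                                              : originalResult;
  return isValidConstruct(resultType, operands);
}
```

In this model, broadcasting `index` to a 4-element `index` vector is only well-typed when the converted result type is used. An `f32` broadcast passes either way, which may be why the existing `@broadcast` test did not expose the problem.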

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a4390447532a5..527fbe5cf628a 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -144,8 +144,8 @@ struct VectorBroadcastConvert final
 
     SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
                                  adaptor.getSource());
-    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
-        castOp, castOp.getResultVectorType(), source);
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(castOp, resultType,
+                                                             source);
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 667aad7645c51..edad208749930 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -150,6 +150,19 @@ func.func @broadcast(%arg0 : f32) -> (vector<4xf32>, vector<2xf32>) {
 
 // -----
 
+// CHECK-LABEL: @broadcast_index
+//  CHECK-SAME: %[[ARG0:.*]]: index
+//       CHECK:   %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : index to i32
+//       CHECK:   %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CAST0]], %[[CAST0]], %[[CAST0]], %[[CAST0]] : (i32, i32, i32, i32) -> vector<4xi32>
+//       CHECK:   %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex>
+//       CHECK:   return %[[CAST1]] : vector<4xindex>
+func.func @broadcast_index(%a: index) -> vector<4xindex> {
+  %0 = vector.broadcast %a : index to vector<4xindex>
+  return %0 : vector<4xindex>
+}
+
+// -----
+
 // CHECK-LABEL: @extract
 //  CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
 //       CHECK:   spirv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>


        

