[Mlir-commits] [mlir] 6aa9d92 - [mlir][spirv] Add pattern matching for arith.index_cast i1 to index for ArithToSPIRV (#155729)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 3 11:04:39 PDT 2025


Author: Ian Li
Date: 2025-09-03T14:04:35-04:00
New Revision: 6aa9d928a86019ab8997fa9fb7c5533a67ed1a8d

URL: https://github.com/llvm/llvm-project/commit/6aa9d928a86019ab8997fa9fb7c5533a67ed1a8d
DIFF: https://github.com/llvm/llvm-project/commit/6aa9d928a86019ab8997fa9fb7c5533a67ed1a8d.diff

LOG: [mlir][spirv] Add pattern matching for arith.index_cast i1 to index for ArithToSPIRV (#155729)

Currently, `arith.index_cast` gets converted to `OpSConvert`:
https://github.com/llvm/llvm-project/blob/9bf5bf3baf3c7aec82cdd235c6a2fd57b4dd55ab/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp#L1331
[OpSConvert requires its operand to be of integer
type](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSConvert),
which poses an issue for `i1` since SPIR-V distinguishes between booleans
and integers. As a result, the following example doesn't get converted,
leaving behind illegal ops:
```
%0 = arith.index_cast %arg0 : i1 to index
```
This PR adds a pattern that converts `arith.index_cast` ops to the SPIR-V
dialect when casting from `i1` to `index`, lowering them to `spirv.Select`.
Converting `index_cast`s from `index` to `i1` is part of
https://github.com/llvm/llvm-project/pull/156031.

Added: 
    

Modified: 
    mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
    mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 026fd0f0c4774..b99a8a3fe17b1 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -611,7 +611,7 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
 // IndexCastOp
 //===----------------------------------------------------------------------===//
 
-// Converts arith.index_cast to spirv.INotEqual if the target type is i1.
+/// Converts arith.index_cast to spirv.INotEqual if the target type is i1.
 struct IndexCastIndexI1Pattern final
     : public OpConversionPattern<arith::IndexCastOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -635,6 +635,30 @@ struct IndexCastIndexI1Pattern final
   }
 };
 
+/// Converts an i1 (or vector of i1) arith.index_cast to spirv.Select.
+struct IndexCastI1IndexPattern final
+    : public OpConversionPattern<arith::IndexCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+  /// Rewrites `index_cast %b` as `spirv.Select %b, 1, 0` in the dst type.
+  LogicalResult
+  matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!isBoolScalarOrVector(adaptor.getIn().getType()))
+      return failure();
+
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+    // NOTE(review): arith.index_cast sign-extends, so true may need to be -1.
+    Location loc = op.getLoc();
+    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, adaptor.getIn(),
+                                                 one, zero);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ExtSIOp
 //===----------------------------------------------------------------------===//
@@ -1356,7 +1380,8 @@ void mlir::arith::populateArithToSPIRVPatterns(
     TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
     TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
     TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
-    TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, IndexCastIndexI1Pattern,
+    TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
+    IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
     TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
     CmpIOpBooleanPattern, CmpIOpPattern,

diff  --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 0d2bda9c74807..3cb5294598994 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -759,6 +759,34 @@ func.func @index_castindexi1_3(%arg0: vector<3xindex>) {
   return
 }
 
+// CHECK-LABEL: index_casti1index_1
+func.func @index_casti1index_1(%arg0 : i1) {
+  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+  // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
+  // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : i1, i32
+  %0 = arith.index_cast %arg0 : i1 to index
+  return
+}
+
+// CHECK-LABEL: index_casti1index_2
+func.func @index_casti1index_2(%arg0 : vector<1xi1>) -> vector<1xindex> {
+  // SPIR-V has no 1-element vectors, so this lowers to scalar i1 -> i32.
+  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+  // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
+  // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : i1, i32
+  %0 = arith.index_cast %arg0 : vector<1xi1> to vector<1xindex>
+  return %0 : vector<1xindex>
+}
+
+// CHECK-LABEL: index_casti1index_3
+func.func @index_casti1index_3(%arg0 : vector<3xi1>) {
+  // CHECK: %[[ZERO:.+]] = spirv.Constant dense<0> : vector<3xi32>
+  // CHECK: %[[ONE:.+]] = spirv.Constant dense<1> : vector<3xi32>
+  // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : vector<3xi1>, vector<3xi32>
+  %0 = arith.index_cast %arg0 : vector<3xi1> to vector<3xindex>
+  return
+}
+
 // CHECK-LABEL: @bit_cast
 func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
   // CHECK: spirv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>


        


More information about the Mlir-commits mailing list