[Mlir-commits] [mlir] [mlir][spirv] Add pattern matching for arith.index_cast index to i1 for ArithToSPIRV (PR #156031)

llvmlistbot at llvm.org
Fri Aug 29 07:53:57 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir-spirv

Author: Ian Li (ianayl)

<details>
<summary>Changes</summary>

Currently, `arith.index_cast` is converted to `OpSConvert`: https://github.com/llvm/llvm-project/blob/9bf5bf3baf3c7aec82cdd235c6a2fd57b4dd55ab/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp#L1331. However, [OpSConvert requires its operands to be of integer type](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSConvert), which is a problem for `i1` because SPIR-V distinguishes booleans (`OpTypeBool`) from integers. As a result, the following example is not converted and leaves illegal ops behind:
```mlir
%0 = arith.index_cast %arg0 : index to i1
```
This PR adds a pattern that converts `arith.index_cast` ops from `index` to `i1` into the SPIR-V dialect. The reverse direction (`i1` to `index`) is submitted separately as https://github.com/llvm/llvm-project/pull/155729.
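The problem boils down to a dispatch on the converted types. The sketch below is a hypothetical, simplified C++ model: `SpvKind` and `chooseIndexCastLowering` are illustrative names, not MLIR APIs, and the model only distinguishes SPIR-V integer types from the boolean type, ignoring bit widths and vectors.

```cpp
#include <string_view>

// Hypothetical, simplified view of the converted types; not an MLIR API.
// SPIR-V models booleans as OpTypeBool, a type distinct from OpTypeInt.
enum class SpvKind { Bool, Int };

// Picks a lowering for arith.index_cast from the converted operand and
// result kinds. OpSConvert only accepts integer types, so a bool (i1)
// result needs a different op sequence.
constexpr std::string_view chooseIndexCastLowering(SpvKind src, SpvKind dst) {
  if (src == SpvKind::Int && dst == SpvKind::Int)
    return "spirv.SConvert";
  if (src == SpvKind::Int && dst == SpvKind::Bool)
    return "spirv.IEqual + spirv.Select";  // index -> i1, this PR
  return "unhandled";  // i1 -> index is addressed separately in PR #155729
}

// Compile-time usage example: an index -> i1 cast cannot use SConvert.
static_assert(chooseIndexCastLowering(SpvKind::Int, SpvKind::Bool) !=
                  std::string_view("spirv.SConvert"),
              "bool results must not be lowered through OpSConvert");
```

Because the function is `constexpr`, the checks run at compile time and no `main` is required.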

---
Full diff: https://github.com/llvm/llvm-project/pull/156031.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+31-1) 
- (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir (+11) 


``````````diff
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 265293b83f84c..de43b5e7fb176 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -607,6 +607,36 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// IndexCastOp
+//===----------------------------------------------------------------------===//
+
+/// Converts arith.index_cast from index to i1 using spirv.Select.
+struct IndexCastIndexI1Pattern final : public OpConversionPattern<arith::IndexCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = adaptor.getOperands().front().getType();
+    // The index operand has already been converted to the SPIR-V index type.
+    Type indexType = getTypeConverter<SPIRVTypeConverter>()->getIndexType();
+    if (srcType != indexType || !op.getType().isInteger(1))
+      return failure();
+
+    Type dstType = rewriter.getI1Type();
+    Location loc = op.getLoc();
+    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+    Value zeroIdx = spirv::ConstantOp::getZero(srcType, loc, rewriter);
+    auto isZero = spirv::IEqualOp::create(
+        rewriter, loc, dstType, zeroIdx, adaptor.getOperands().front());
+    // spirv.IEqual yields a bool; spirv.Select maps x == 0 to false, else true.
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isZero, zero, one);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ExtSIOp
 //===----------------------------------------------------------------------===//
@@ -1328,7 +1358,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
     TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
     TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
     TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
-    TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
+    TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, IndexCastIndexI1Pattern,
     TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
     CmpIOpBooleanPattern, CmpIOpPattern,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 6e2352e706acc..3109edf5d87d6 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -734,6 +734,17 @@ func.func @index_castui4(%arg0: index) {
   return
 }
 
+// CHECK-LABEL: index_castindexi1
+func.func @index_castindexi1(%arg0 : index) {
+  // CHECK: %[[FALSE:.+]] = spirv.Constant false
+  // CHECK: %[[TRUE:.+]] = spirv.Constant true
+  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+  // CHECK: %[[IS_ZERO:.+]] = spirv.IEqual %[[ZERO]], %{{.+}} : i32
+  // CHECK: spirv.Select %[[IS_ZERO]], %[[FALSE]], %[[TRUE]] : i1, i1
+  %0 = arith.index_cast %arg0 : index to i1
+  return
+}
+
 // CHECK-LABEL: @bit_cast
 func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
   // CHECK: spirv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>

``````````
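To make the emitted sequence concrete, here is a scalar C++ model of what the ops produced by `IndexCastIndexI1Pattern` compute, assuming `index` is converted to `i32` as in the test's CHECK lines. The function names are hypothetical. For comparison, `truncatedIndexToI1` models truncation, which is how narrowing `arith.index_cast` is documented in the `arith` dialect: only the least significant bit survives.

```cpp
#include <cstdint>

// Hypothetical scalar model of the ops emitted by IndexCastIndexI1Pattern,
// assuming index -> i32:
//   %is_zero = spirv.IEqual %zero, %x : i32
//   %r       = spirv.Select %is_zero, %false, %true : i1, i1
constexpr bool loweredIndexToI1(std::int32_t x) {
  const bool isZero = (0 == x);  // spirv.IEqual produces a bool
  return isZero ? false : true;  // equivalent to x != 0
}

// Hypothetical model of a truncating cast to i1: keep only the low bit.
constexpr bool truncatedIndexToI1(std::int32_t x) {
  return (static_cast<std::uint32_t>(x) & 1u) != 0;
}

// Usage: both models map 0 to false and odd values to true,
// but they disagree on nonzero even values such as 2.
static_assert(!loweredIndexToI1(0) && !truncatedIndexToI1(0), "zero");
static_assert(loweredIndexToI1(1) && truncatedIndexToI1(1), "odd");
static_assert(loweredIndexToI1(2) != truncatedIndexToI1(2), "even");
```

They agree on zero and on every odd value, and differ only on nonzero even values.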

</details>


https://github.com/llvm/llvm-project/pull/156031

