[Mlir-commits] [mlir] c5a141b - [mlir][spirv] Add pattern matching for arith.index_cast index to i1 for ArithToSPIRV (#156031)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 3 09:48:57 PDT 2025
Author: Ian Li
Date: 2025-09-03T12:48:53-04:00
New Revision: c5a141bb8b5dfc4be06bb165d88e724e66bf5c4c
URL: https://github.com/llvm/llvm-project/commit/c5a141bb8b5dfc4be06bb165d88e724e66bf5c4c
DIFF: https://github.com/llvm/llvm-project/commit/c5a141bb8b5dfc4be06bb165d88e724e66bf5c4c.diff
LOG: [mlir][spirv] Add pattern matching for arith.index_cast index to i1 for ArithToSPIRV (#156031)
Currently, `arith.index_cast` is converted to `OpSConvert`:
https://github.com/llvm/llvm-project/blob/9bf5bf3baf3c7aec82cdd235c6a2fd57b4dd55ab/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp#L1331
[OpSConvert requires its operands to be of integer
type](https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSConvert),
which is a problem for `i1` because SPIR-V distinguishes between booleans
and integers. As a result, the following example is not converted and
illegal ops are left behind:
```
%0 = arith.index_cast %arg0 : index to i1
```
This PR adds a dedicated pattern that converts `arith.index_cast` ops to the
SPIR-V dialect when casting from `index` to `i1`. Conversion in the opposite
direction, from `i1` to `index`, is handled separately in
https://github.com/llvm/llvm-project/pull/155729.
---------
Co-authored-by: Md Abdullah Shahneous Bari <98356296+mshahneo at users.noreply.github.com>
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
Added:
Modified:
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 265293b83f84c..026fd0f0c4774 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -607,6 +607,34 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
}
};
+//===----------------------------------------------------------------------===//
+// IndexCastOp
+//===----------------------------------------------------------------------===//
+
+// Converts arith.index_cast from index (or a vector of index) to i1 (or a
+// vector of i1). SPIR-V booleans are not integers, so spirv.SConvert cannot
+// produce them; the result is instead computed with integer ops followed by
+// a spirv.INotEqual comparison.
+//
+// arith.index_cast truncates when the destination is narrower than index, so
+// the i1 result is the least-significant bit of the source: `(x & 1) != 0`.
+// Comparing the whole value against zero would wrongly map even non-zero
+// indices (2, 4, ...) to true.
+struct IndexCastIndexI1Pattern final
+    : public OpConversionPattern<arith::IndexCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only handle casts whose result is i1 / vector of i1; other index casts
+    // are left to the generic spirv.SConvert pattern.
+    if (!isBoolScalarOrVector(op.getType()))
+      return failure();
+
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    Location loc = op.getLoc();
+    // The adaptor operand carries the type-converted index type
+    // (e.g. i32 / vector<Nxi32> in the tests).
+    Value src = adaptor.getIn();
+    Type srcType = src.getType();
+
+    // Truncation to i1 keeps only bit 0, so isolate it before comparing.
+    Value one = spirv::ConstantOp::getOne(srcType, loc, rewriter);
+    Value lowBit = rewriter.create<spirv::BitwiseAndOp>(loc, srcType, src, one);
+
+    Value zero = spirv::ConstantOp::getZero(srcType, loc, rewriter);
+    rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zero, lowBit);
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// ExtSIOp
//===----------------------------------------------------------------------===//
@@ -1328,7 +1356,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
- TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
+ TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, IndexCastIndexI1Pattern,
TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
CmpIOpBooleanPattern, CmpIOpPattern,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 6e2352e706acc..0d2bda9c74807 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -734,6 +734,31 @@ func.func @index_castui4(%arg0: index) {
return
}
+// index -> i1 casts are lowered to a spirv.INotEqual against a zero constant
+// of the converted index type, because SPIR-V booleans are not integers and
+// spirv.SConvert cannot produce them.
+// NOTE(review): arith.index_cast truncates when narrowing, so the expected i1
+// is the low bit of the index; the FileCheck patterns below only pin the
+// compare against zero. Consider also matching how the low bit is isolated.
+// CHECK-LABEL: index_castindexi1_1
+func.func @index_castindexi1_1(%arg0: index) {
+ // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+ // CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : i32
+ %0 = arith.index_cast %arg0 : index to i1
+ return
+}
+
+// CHECK-LABEL: index_castindexi1_2
+func.func @index_castindexi1_2(%arg0: vector<1xindex>) -> vector<1xi1> {
+ // Single-element vectors do not exist in SPIR-V, so vector<1xindex> is
+ // converted to a scalar i32 and the cast is lowered as a scalar compare.
+ // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+ // CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : i32
+ %0 = arith.index_cast %arg0 : vector<1xindex> to vector<1xi1>
+ return %0 : vector<1xi1>
+}
+
+// Multi-element vector case: the zero constant is a splat of the converted
+// vector type and the comparison is element-wise.
+// CHECK-LABEL: index_castindexi1_3
+func.func @index_castindexi1_3(%arg0: vector<3xindex>) {
+ // CHECK: %[[ZERO:.+]] = spirv.Constant dense<0> : vector<3xi32>
+ // CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : vector<3xi32>
+ %0 = arith.index_cast %arg0 : vector<3xindex> to vector<3xi1>
+ return
+}
+
// CHECK-LABEL: @bit_cast
func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
// CHECK: spirv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>
More information about the Mlir-commits
mailing list