[Mlir-commits] [mlir] [mlir][spirv] Add pattern matching for arith.index_cast i1 to index for ArithToSPIRV (PR #155729)
Ian Li
llvmlistbot at llvm.org
Fri Aug 29 08:36:10 PDT 2025
https://github.com/ianayl updated https://github.com/llvm/llvm-project/pull/155729
>From f446699bc016509b7a7c6c0a2170b61d0b8709c8 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Wed, 27 Aug 2025 16:51:17 -0700
Subject: [PATCH 1/4] [mlir][spirv] Add pattern matching for arith.index_cast
i1 to index
---
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 37 ++++++++++++++++++-
.../ArithToSPIRV/arith-to-spirv.mlir | 7 ++++
2 files changed, 43 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 265293b83f84c..172f322a12fd8 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -607,6 +607,41 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
}
};
+//===----------------------------------------------------------------------===//
+// IndexCastOp
+//===----------------------------------------------------------------------===//
+
+/// Converts arith.index_cast to spirv.Select if the type of source is i1 or
+/// vector of i1.
+struct IndexCastI1Pattern final : public OpConversionPattern<arith::IndexCastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type srcType = adaptor.getOperands().front().getType();
+ if (!isBoolScalarOrVector(srcType))
+ return failure();
+
+ Type dstType = getTypeConverter()->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
+ // if (!dstType.isIndex()) {
+ // llvm::errs() << "why doesnt this work?\n";
+ // return failure();
+ // }
+
+ auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
+ Location loc = op.getLoc();
+ Type spirvI32T = converter->getIndexType();
+ Value zero = spirv::ConstantOp::getZero(spirvI32T, loc, rewriter);
+ Value one = spirv::ConstantOp::getOne(spirvI32T, loc, rewriter);
+ auto newOp = rewriter.replaceOpWithNewOp<spirv::SelectOp>(
+ op, dstType, adaptor.getOperands().front(), one, zero);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ExtSIOp
//===----------------------------------------------------------------------===//
@@ -1328,7 +1363,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
- TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
+ TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, IndexCastI1Pattern,
TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
CmpIOpBooleanPattern, CmpIOpPattern,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 6e2352e706acc..8bb63fff861ce 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -734,6 +734,13 @@ func.func @index_castui4(%arg0: index) {
return
}
+// CHECK-LABEL: index_casti1_1
+func.func @index_casti1_1(%arg0 : i1) -> index {
+ // CHECK: spirv.Select %{{.+}}, %{{.+}}, %{{.+}} : i1, i32
+ %0 = arith.index_cast %arg0 : i1 to index
+ return %0 : index
+}
+
// CHECK-LABEL: @bit_cast
func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
// CHECK: spirv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>
>From 1ba2dcf5c81b98b67ebf95c9052c28119f230e99 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Thu, 28 Aug 2025 13:37:27 -0700
Subject: [PATCH 2/4] Remove redundancy, add missing lit checks
---
.../lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 16 +++++-----------
.../Conversion/ArithToSPIRV/arith-to-spirv.mlir | 8 +++++---
2 files changed, 10 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 172f322a12fd8..b9e04e456ff72 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -613,7 +613,7 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
/// Converts arith.index_cast to spirv.Select if the type of source is i1 or
/// vector of i1.
-struct IndexCastI1Pattern final : public OpConversionPattern<arith::IndexCastOp> {
+struct IndexCastI1IndexPattern final : public OpConversionPattern<arith::IndexCastOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -626,17 +626,11 @@ struct IndexCastI1Pattern final : public OpConversionPattern<arith::IndexCastOp>
Type dstType = getTypeConverter()->convertType(op.getType());
if (!dstType)
return getTypeConversionFailure(rewriter, op);
- // if (!dstType.isIndex()) {
- // llvm::errs() << "why doesnt this work?\n";
- // return failure();
- // }
- auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
Location loc = op.getLoc();
- Type spirvI32T = converter->getIndexType();
- Value zero = spirv::ConstantOp::getZero(spirvI32T, loc, rewriter);
- Value one = spirv::ConstantOp::getOne(spirvI32T, loc, rewriter);
- auto newOp = rewriter.replaceOpWithNewOp<spirv::SelectOp>(
+ Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+ Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+ rewriter.replaceOpWithNewOp<spirv::SelectOp>(
op, dstType, adaptor.getOperands().front(), one, zero);
return success();
}
@@ -1363,7 +1357,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
- TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, IndexCastI1Pattern,
+ TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, IndexCastI1IndexPattern,
TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
CmpIOpBooleanPattern, CmpIOpPattern,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 8bb63fff861ce..938a5ccfed542 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -734,9 +734,11 @@ func.func @index_castui4(%arg0: index) {
return
}
-// CHECK-LABEL: index_casti1_1
-func.func @index_casti1_1(%arg0 : i1) -> index {
- // CHECK: spirv.Select %{{.+}}, %{{.+}}, %{{.+}} : i1, i32
+// CHECK-LABEL: index_casti1index_1
+func.func @index_casti1index_1(%arg0 : i1) -> index {
+ // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+ // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
+ // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : i1, i32
%0 = arith.index_cast %arg0 : i1 to index
return %0 : index
}
>From ac32e57f1fdfc358e39e9e01a5787f38d0d3c513 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Thu, 28 Aug 2025 13:51:03 -0700
Subject: [PATCH 3/4] remove redundant return
---
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 938a5ccfed542..e86b04527383d 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -735,12 +735,12 @@ func.func @index_castui4(%arg0: index) {
}
// CHECK-LABEL: index_casti1index_1
-func.func @index_casti1index_1(%arg0 : i1) -> index {
+func.func @index_casti1index_1(%arg0 : i1) {
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
// CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
// CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : i1, i32
%0 = arith.index_cast %arg0 : i1 to index
- return %0 : index
+ return
}
// CHECK-LABEL: @bit_cast
>From 13f3d477dd241f50d478e6536b8d2f57547e5fd4 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Fri, 29 Aug 2025 08:35:59 -0700
Subject: [PATCH 4/4] clang-format
---
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index b9e04e456ff72..b55322816fd31 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -613,7 +613,8 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
/// Converts arith.index_cast to spirv.Select if the type of source is i1 or
/// vector of i1.
-struct IndexCastI1IndexPattern final : public OpConversionPattern<arith::IndexCastOp> {
+struct IndexCastI1IndexPattern final
+ : public OpConversionPattern<arith::IndexCastOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
More information about the Mlir-commits
mailing list