[Mlir-commits] [mlir] [mlir][spirv] Add pattern matching for arith.index_cast index to i1 for ArithToSPIRV (PR #156031)
Ian Li
llvmlistbot at llvm.org
Fri Aug 29 09:21:26 PDT 2025
https://github.com/ianayl updated https://github.com/llvm/llvm-project/pull/156031
>From 6e1a6ab301f923ea52d3fcb61ccbdbf34f5f3935 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Thu, 28 Aug 2025 20:49:45 -0700
Subject: [PATCH 1/5] Add conversion from arith.index_cast index->i1 to SPIRV
---
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 32 ++++++++++++++++++-
.../ArithToSPIRV/arith-to-spirv.mlir | 11 +++++++
2 files changed, 42 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 265293b83f84c..de43b5e7fb176 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -607,6 +607,36 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
}
};
+//===----------------------------------------------------------------------===//
+// IndexCastOp
+//===----------------------------------------------------------------------===//
+
+/// Converts arith.index_cast to spirv.Select if the type of source is index.
+struct IndexCastIndexI1Pattern final : public OpConversionPattern<arith::IndexCastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type srcType = adaptor.getOperands().front().getType();
+ // Indexes have already been converted to its respective spirv type:
+ Type indexType = getTypeConverter<SPIRVTypeConverter>()->getIndexType();
+ if (srcType != indexType || !op.getType().isInteger(1))
+ return failure();
+
+ Type dstType = rewriter.getI1Type();
+ Location loc = op.getLoc();
+ Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+ Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+ Value zeroIdx = spirv::ConstantOp::getZero(srcType, loc, rewriter);
+ auto isZero = spirv::IEqualOp::create(
+ rewriter, loc, dstType, zeroIdx, adaptor.getOperands().front());
+ // spriv.IEqual outputs i32, spirv.Select is used to truncate to i1:
+ rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isZero, zero, one);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ExtSIOp
//===----------------------------------------------------------------------===//
@@ -1328,7 +1358,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
- TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
+ TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, IndexCastIndexI1Pattern,
TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
CmpIOpBooleanPattern, CmpIOpPattern,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 6e2352e706acc..3109edf5d87d6 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -734,6 +734,17 @@ func.func @index_castui4(%arg0: index) {
return
}
+// CHECK-LABEL: index_castindexi1
+func.func @index_castindexi1(%arg0 : index) {
+ // CHECK: %[[FALSE:.+]] = spirv.Constant false
+ // CHECK: %[[TRUE:.+]] = spirv.Constant true
+ // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+ // CHECK: %[[IS_ZERO:.+]] = spirv.IEqual %[[ZERO]], %{{.+}} : i32
+ // CHECK: spirv.Select %[[IS_ZERO]], %[[FALSE]], %[[TRUE]] : i1, i1
+ %0 = arith.index_cast %arg0 : index to i1
+ return
+}
+
// CHECK-LABEL: @bit_cast
func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
// CHECK: spirv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>
>From c03e924e4823255b5e0d4c378035f9d70f7d5788 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Fri, 29 Aug 2025 08:16:01 -0700
Subject: [PATCH 2/5] clang-format
---
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index de43b5e7fb176..41ed211ba3731 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -612,7 +612,8 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
//===----------------------------------------------------------------------===//
/// Converts arith.index_cast to spirv.Select if the type of source is index.
-struct IndexCastIndexI1Pattern final : public OpConversionPattern<arith::IndexCastOp> {
+struct IndexCastIndexI1Pattern final
+ : public OpConversionPattern<arith::IndexCastOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -629,10 +630,11 @@ struct IndexCastIndexI1Pattern final : public OpConversionPattern<arith::IndexCa
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
Value zeroIdx = spirv::ConstantOp::getZero(srcType, loc, rewriter);
- auto isZero = spirv::IEqualOp::create(
- rewriter, loc, dstType, zeroIdx, adaptor.getOperands().front());
+ auto isZero = spirv::IEqualOp::create(rewriter, loc, dstType, zeroIdx,
+ adaptor.getOperands().front());
// spriv.IEqual outputs i32, spirv.Select is used to truncate to i1:
- rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isZero, zero, one);
+ rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isZero, zero,
+ one);
return success();
}
};
>From fa8db0aade8b2f71abd912c8b529bcb616fd57ff Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Fri, 29 Aug 2025 09:17:36 -0700
Subject: [PATCH 3/5] remove redundant truncate
---
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 9 +--------
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir | 5 +----
2 files changed, 2 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 41ed211ba3731..a5469f506fda8 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -625,16 +625,9 @@ struct IndexCastIndexI1Pattern final
if (srcType != indexType || !op.getType().isInteger(1))
return failure();
- Type dstType = rewriter.getI1Type();
Location loc = op.getLoc();
- Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
- Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
Value zeroIdx = spirv::ConstantOp::getZero(srcType, loc, rewriter);
- auto isZero = spirv::IEqualOp::create(rewriter, loc, dstType, zeroIdx,
- adaptor.getOperands().front());
- // spriv.IEqual outputs i32, spirv.Select is used to truncate to i1:
- rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isZero, zero,
- one);
+ rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, rewriter.getI1Type(), zeroIdx, adaptor.getOperands().front());
return success();
}
};
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 3109edf5d87d6..f3f5a5fadc0b6 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -736,11 +736,8 @@ func.func @index_castui4(%arg0: index) {
// CHECK-LABEL: index_castindexi1
func.func @index_castindexi1(%arg0 : index) {
- // CHECK: %[[FALSE:.+]] = spirv.Constant false
- // CHECK: %[[TRUE:.+]] = spirv.Constant true
// CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[IS_ZERO:.+]] = spirv.IEqual %[[ZERO]], %{{.+}} : i32
- // CHECK: spirv.Select %[[IS_ZERO]], %[[FALSE]], %[[TRUE]] : i1, i1
+ // CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : i32
%0 = arith.index_cast %arg0 : index to i1
return
}
>From ca2e36a3992f2a0bbee4e4efd83b20c5a7438b18 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Fri, 29 Aug 2025 09:20:08 -0700
Subject: [PATCH 4/5] clang-format
---
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index a5469f506fda8..9ed7602cd9789 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -627,7 +627,8 @@ struct IndexCastIndexI1Pattern final
Location loc = op.getLoc();
Value zeroIdx = spirv::ConstantOp::getZero(srcType, loc, rewriter);
- rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, rewriter.getI1Type(), zeroIdx, adaptor.getOperands().front());
+ rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(
+ op, rewriter.getI1Type(), zeroIdx, adaptor.getOperands().front());
return success();
}
};
>From fa93f78d0eebf39de5d1dff0e88d7eff0d7450d0 Mon Sep 17 00:00:00 2001
From: Ian Li <ianayl.work at gmail.com>
Date: Fri, 29 Aug 2025 12:21:17 -0400
Subject: [PATCH 5/5] Fix comment
Co-authored-by: Md Abdullah Shahneous Bari <98356296+mshahneo at users.noreply.github.com>
---
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9ed7602cd9789..c53a3c8b10098 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -611,7 +611,7 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
// IndexCastOp
//===----------------------------------------------------------------------===//
-/// Converts arith.index_cast to spirv.Select if the type of source is index.
+/// Converts arith.index_cast to spirv.INotEqual if the target type is i1.
struct IndexCastIndexI1Pattern final
: public OpConversionPattern<arith::IndexCastOp> {
using OpConversionPattern::OpConversionPattern;
More information about the Mlir-commits
mailing list