[Mlir-commits] [mlir] [mlir][spirv] Add pattern matching for arith.index_cast index to i1 for ArithToSPIRV (PR #156031)

Ian Li llvmlistbot at llvm.org
Tue Sep 2 10:33:29 PDT 2025


https://github.com/ianayl updated https://github.com/llvm/llvm-project/pull/156031

>From 6e1a6ab301f923ea52d3fcb61ccbdbf34f5f3935 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Thu, 28 Aug 2025 20:49:45 -0700
Subject: [PATCH 01/10] Add conversion from arith.index_cast index->i1 to SPIRV

---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  | 32 ++++++++++++++++++-
 .../ArithToSPIRV/arith-to-spirv.mlir          | 11 +++++++
 2 files changed, 42 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 265293b83f84c..de43b5e7fb176 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -607,6 +607,36 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// IndexCastOp
+//===----------------------------------------------------------------------===//
+
+/// Converts arith.index_cast to spirv.Select if the type of source is index.
+struct IndexCastIndexI1Pattern final : public OpConversionPattern<arith::IndexCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = adaptor.getOperands().front().getType();
+    // Indexes have already been converted to its respective spirv type:
+    Type indexType = getTypeConverter<SPIRVTypeConverter>()->getIndexType();
+    if (srcType != indexType || !op.getType().isInteger(1))
+      return failure();
+
+    Type dstType = rewriter.getI1Type();
+    Location loc = op.getLoc();
+    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+    Value zeroIdx = spirv::ConstantOp::getZero(srcType, loc, rewriter);
+    auto isZero = spirv::IEqualOp::create(
+        rewriter, loc, dstType, zeroIdx, adaptor.getOperands().front());
+    // spriv.IEqual outputs i32, spirv.Select is used to truncate to i1:
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isZero, zero, one);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ExtSIOp
 //===----------------------------------------------------------------------===//
@@ -1328,7 +1358,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
     TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
     TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
     TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
-    TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
+    TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, IndexCastIndexI1Pattern,
     TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
     CmpIOpBooleanPattern, CmpIOpPattern,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 6e2352e706acc..3109edf5d87d6 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -734,6 +734,17 @@ func.func @index_castui4(%arg0: index) {
   return
 }
 
+// CHECK-LABEL: index_castindexi1
+func.func @index_castindexi1(%arg0 : index) {
+  // CHECK: %[[FALSE:.+]] = spirv.Constant false
+  // CHECK: %[[TRUE:.+]] = spirv.Constant true
+  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+  // CHECK: %[[IS_ZERO:.+]] = spirv.IEqual %[[ZERO]], %{{.+}} : i32
+  // CHECK: spirv.Select %[[IS_ZERO]], %[[FALSE]], %[[TRUE]] : i1, i1
+  %0 = arith.index_cast %arg0 : index to i1
+  return
+}
+
 // CHECK-LABEL: @bit_cast
 func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
   // CHECK: spirv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>

>From c03e924e4823255b5e0d4c378035f9d70f7d5788 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Fri, 29 Aug 2025 08:16:01 -0700
Subject: [PATCH 02/10] clang-format

---
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index de43b5e7fb176..41ed211ba3731 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -612,7 +612,8 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
 //===----------------------------------------------------------------------===//
 
 /// Converts arith.index_cast to spirv.Select if the type of source is index.
-struct IndexCastIndexI1Pattern final : public OpConversionPattern<arith::IndexCastOp> {
+struct IndexCastIndexI1Pattern final
+    : public OpConversionPattern<arith::IndexCastOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -629,10 +630,11 @@ struct IndexCastIndexI1Pattern final : public OpConversionPattern<arith::IndexCa
     Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
     Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
     Value zeroIdx = spirv::ConstantOp::getZero(srcType, loc, rewriter);
-    auto isZero = spirv::IEqualOp::create(
-        rewriter, loc, dstType, zeroIdx, adaptor.getOperands().front());
+    auto isZero = spirv::IEqualOp::create(rewriter, loc, dstType, zeroIdx,
+                                          adaptor.getOperands().front());
     // spriv.IEqual outputs i32, spirv.Select is used to truncate to i1:
-    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isZero, zero, one);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isZero, zero,
+                                                 one);
     return success();
   }
 };

>From fa8db0aade8b2f71abd912c8b529bcb616fd57ff Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Fri, 29 Aug 2025 09:17:36 -0700
Subject: [PATCH 03/10] Replace IEqual + Select with a single INotEqual

---
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp     | 9 +--------
 mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir | 5 +----
 2 files changed, 2 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 41ed211ba3731..a5469f506fda8 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -625,16 +625,9 @@ struct IndexCastIndexI1Pattern final
     if (srcType != indexType || !op.getType().isInteger(1))
       return failure();
 
-    Type dstType = rewriter.getI1Type();
     Location loc = op.getLoc();
-    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
-    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
     Value zeroIdx = spirv::ConstantOp::getZero(srcType, loc, rewriter);
-    auto isZero = spirv::IEqualOp::create(rewriter, loc, dstType, zeroIdx,
-                                          adaptor.getOperands().front());
-    // spriv.IEqual outputs i32, spirv.Select is used to truncate to i1:
-    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isZero, zero,
-                                                 one);
+    rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, rewriter.getI1Type(), zeroIdx, adaptor.getOperands().front());
     return success();
   }
 };
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 3109edf5d87d6..f3f5a5fadc0b6 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -736,11 +736,8 @@ func.func @index_castui4(%arg0: index) {
 
 // CHECK-LABEL: index_castindexi1
 func.func @index_castindexi1(%arg0 : index) {
-  // CHECK: %[[FALSE:.+]] = spirv.Constant false
-  // CHECK: %[[TRUE:.+]] = spirv.Constant true
   // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
-  // CHECK: %[[IS_ZERO:.+]] = spirv.IEqual %[[ZERO]], %{{.+}} : i32
-  // CHECK: spirv.Select %[[IS_ZERO]], %[[FALSE]], %[[TRUE]] : i1, i1
+  // CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : i32
   %0 = arith.index_cast %arg0 : index to i1
   return
 }

>From ca2e36a3992f2a0bbee4e4efd83b20c5a7438b18 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Fri, 29 Aug 2025 09:20:08 -0700
Subject: [PATCH 04/10] clang-format

---
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index a5469f506fda8..9ed7602cd9789 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -627,7 +627,8 @@ struct IndexCastIndexI1Pattern final
 
     Location loc = op.getLoc();
     Value zeroIdx = spirv::ConstantOp::getZero(srcType, loc, rewriter);
-    rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, rewriter.getI1Type(), zeroIdx, adaptor.getOperands().front());
+    rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(
+        op, rewriter.getI1Type(), zeroIdx, adaptor.getOperands().front());
     return success();
   }
 };

>From fa93f78d0eebf39de5d1dff0e88d7eff0d7450d0 Mon Sep 17 00:00:00 2001
From: Ian Li <ianayl.work at gmail.com>
Date: Fri, 29 Aug 2025 12:21:17 -0400
Subject: [PATCH 05/10] Fix comment

Co-authored-by: Md Abdullah Shahneous Bari <98356296+mshahneo at users.noreply.github.com>
---
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9ed7602cd9789..c53a3c8b10098 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -611,7 +611,7 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
 // IndexCastOp
 //===----------------------------------------------------------------------===//
 
-/// Converts arith.index_cast to spirv.Select if the type of source is index.
+// Converts arith.index_cast to spirv.Select if the target type is i1.
 struct IndexCastIndexI1Pattern final
     : public OpConversionPattern<arith::IndexCastOp> {
   using OpConversionPattern::OpConversionPattern;

>From 1ae814c13627bee363b72aa6b021f9eb9fbb4ef2 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Fri, 29 Aug 2025 13:00:31 -0700
Subject: [PATCH 06/10] Amend comment

---
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index c53a3c8b10098..0137f37f97364 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -611,7 +611,7 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
 // IndexCastOp
 //===----------------------------------------------------------------------===//
 
-// Converts arith.index_cast to spirv.Select if the target type is i1.
+// Converts arith.index_cast to spirv.INotEqual if the target type is i1.
 struct IndexCastIndexI1Pattern final
     : public OpConversionPattern<arith::IndexCastOp> {
   using OpConversionPattern::OpConversionPattern;

>From b5583a0c4764524f91377be4970c2550a3aeff08 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Tue, 2 Sep 2025 10:09:45 -0700
Subject: [PATCH 07/10] Apply reviewer comments

---
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 0137f37f97364..c51fbab4d3b1a 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -619,16 +619,14 @@ struct IndexCastIndexI1Pattern final
   LogicalResult
   matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type srcType = adaptor.getOperands().front().getType();
-    // Indexes have already been converted to its respective spirv type:
-    Type indexType = getTypeConverter<SPIRVTypeConverter>()->getIndexType();
-    if (srcType != indexType || !op.getType().isInteger(1))
+    Type srcType = adaptor.getIn().getType();
+    if (!op.getType().isInteger(1))
       return failure();
 
     Location loc = op.getLoc();
     Value zeroIdx = spirv::ConstantOp::getZero(srcType, loc, rewriter);
     rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(
-        op, rewriter.getI1Type(), zeroIdx, adaptor.getOperands().front());
+        op, op.getType(), zeroIdx, adaptor.getIn());
     return success();
   }
 };

>From 2815dbdeca95783ce6df8622987b8895b01729ef Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Tue, 2 Sep 2025 10:16:07 -0700
Subject: [PATCH 08/10] clang-format

---
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index c51fbab4d3b1a..37dda67e27a19 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -625,8 +625,8 @@ struct IndexCastIndexI1Pattern final
 
     Location loc = op.getLoc();
     Value zeroIdx = spirv::ConstantOp::getZero(srcType, loc, rewriter);
-    rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(
-        op, op.getType(), zeroIdx, adaptor.getIn());
+    rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, op.getType(), zeroIdx,
+                                                    adaptor.getIn());
     return success();
   }
 };

>From cf86eb619c18e7b277307d66ca582a2d6e2adb23 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Tue, 2 Sep 2025 10:32:37 -0700
Subject: [PATCH 09/10] Fix compatibility with vector<Nxi1> results

---
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp    |  5 ++---
 .../test/Conversion/ArithToSPIRV/arith-to-spirv.mlir | 12 ++++++++++--
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 37dda67e27a19..3c3b3e455658f 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -619,12 +619,11 @@ struct IndexCastIndexI1Pattern final
   LogicalResult
   matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type srcType = adaptor.getIn().getType();
-    if (!op.getType().isInteger(1))
+    if (!isBoolScalarOrVector(op.getType()))
       return failure();
 
     Location loc = op.getLoc();
-    Value zeroIdx = spirv::ConstantOp::getZero(srcType, loc, rewriter);
+    Value zeroIdx = spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
     rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, op.getType(), zeroIdx,
                                                     adaptor.getIn());
     return success();
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index f3f5a5fadc0b6..8caaf06236b4f 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -734,14 +734,22 @@ func.func @index_castui4(%arg0: index) {
   return
 }
 
-// CHECK-LABEL: index_castindexi1
-func.func @index_castindexi1(%arg0 : index) {
+// CHECK-LABEL: index_castindexi1_1
+func.func @index_castindexi1_1(%arg0 : index) {
   // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
   // CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : i32
   %0 = arith.index_cast %arg0 : index to i1
   return
 }
 
+// CHECK-LABEL: index_castindexi1_2
+func.func @index_castindexi1_2(%arg0 : vector<3xindex>) {
+  // CHECK: %[[ZERO:.+]] = spirv.Constant dense<0> : vector<3xi32>
+  // CHECK: spirv.INotEqual %[[ZERO]], %{{.+}} : vector<3xi32>
+  %0 = arith.index_cast %arg0 : vector<3xindex> to vector<3xi1>
+  return
+}
+
 // CHECK-LABEL: @bit_cast
 func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
   // CHECK: spirv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>

>From 68e6fb46485d002bc950889555dfa0e6e696443a Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Tue, 2 Sep 2025 10:33:08 -0700
Subject: [PATCH 10/10] clang-format

---
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 3c3b3e455658f..e74bc4cdab91a 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -623,7 +623,8 @@ struct IndexCastIndexI1Pattern final
       return failure();
 
     Location loc = op.getLoc();
-    Value zeroIdx = spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
+    Value zeroIdx =
+        spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
     rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, op.getType(), zeroIdx,
                                                     adaptor.getIn());
     return success();



More information about the Mlir-commits mailing list