[Mlir-commits] [mlir] [mlir][spirv] Add pattern matching for arith.index_cast i1 to index for ArithToSPIRV (PR #155729)

Ian Li llvmlistbot at llvm.org
Tue Sep 2 10:50:11 PDT 2025


https://github.com/ianayl updated https://github.com/llvm/llvm-project/pull/155729

>From f446699bc016509b7a7c6c0a2170b61d0b8709c8 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Wed, 27 Aug 2025 16:51:17 -0700
Subject: [PATCH 1/6] [mlir][spirv] Add pattern matching for arith.index_cast
 i1 to index

---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  | 37 ++++++++++++++++++-
 .../ArithToSPIRV/arith-to-spirv.mlir          |  7 ++++
 2 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 265293b83f84c..172f322a12fd8 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -607,6 +607,41 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// IndexCastOp
+//===----------------------------------------------------------------------===//
+
+/// Converts arith.index_cast to spirv.Select if the type of source is i1 or
+/// vector of i1.
+struct IndexCastI1Pattern final : public OpConversionPattern<arith::IndexCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = adaptor.getOperands().front().getType();
+    if (!srcType.isInteger(1))
+      return failure();
+
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+    // if (!dstType.isIndex()) {
+    //   llvm::errs() << "why doesnt this work?\n";
+    //   return failure();
+    // }
+
+    auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
+    Location loc = op.getLoc();
+    Type spirvI32T = converter->getIndexType();
+    Value zero = spirv::ConstantOp::getZero(spirvI32T, loc, rewriter);
+    Value one = spirv::ConstantOp::getOne(spirvI32T, loc, rewriter);
+    auto newOp = rewriter.replaceOpWithNewOp<spirv::SelectOp>(
+        op, dstType, adaptor.getOperands().front(), one, zero);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ExtSIOp
 //===----------------------------------------------------------------------===//
@@ -1328,7 +1363,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
     TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
     TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
     TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
-    TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
+    TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, IndexCastI1Pattern,
     TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
     CmpIOpBooleanPattern, CmpIOpPattern,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 6e2352e706acc..8bb63fff861ce 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -734,6 +734,13 @@ func.func @index_castui4(%arg0: index) {
   return
 }
 
+// CHECK-LABEL: index_casti1_1
+func.func @index_casti1_1(%arg0 : i1) -> index {
+  // CHECK: spirv.Select %{{.+}}, %{{.+}}, %{{.+}} : i1, i32
+  %0 = arith.index_cast %arg0 : i1 to index
+  return %0 : index
+}
+
 // CHECK-LABEL: @bit_cast
 func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
   // CHECK: spirv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>

>From 1ba2dcf5c81b98b67ebf95c9052c28119f230e99 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Thu, 28 Aug 2025 13:37:27 -0700
Subject: [PATCH 2/6] Remove redundancy, add missing lit checks

---
 .../lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 16 +++++-----------
 .../Conversion/ArithToSPIRV/arith-to-spirv.mlir  |  8 +++++---
 2 files changed, 10 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 172f322a12fd8..b9e04e456ff72 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -613,7 +613,7 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
 
 /// Converts arith.index_cast to spirv.Select if the type of source is i1 or
 /// vector of i1.
-struct IndexCastI1Pattern final : public OpConversionPattern<arith::IndexCastOp> {
+struct IndexCastI1IndexPattern final : public OpConversionPattern<arith::IndexCastOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
@@ -626,17 +626,11 @@ struct IndexCastI1Pattern final : public OpConversionPattern<arith::IndexCastOp>
     Type dstType = getTypeConverter()->convertType(op.getType());
     if (!dstType)
       return getTypeConversionFailure(rewriter, op);
-    // if (!dstType.isIndex()) {
-    //   llvm::errs() << "why doesnt this work?\n";
-    //   return failure();
-    // }
 
-    auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
     Location loc = op.getLoc();
-    Type spirvI32T = converter->getIndexType();
-    Value zero = spirv::ConstantOp::getZero(spirvI32T, loc, rewriter);
-    Value one = spirv::ConstantOp::getOne(spirvI32T, loc, rewriter);
-    auto newOp = rewriter.replaceOpWithNewOp<spirv::SelectOp>(
+    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(
         op, dstType, adaptor.getOperands().front(), one, zero);
     return success();
   }
@@ -1363,7 +1357,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
     TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
     TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
     TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
-    TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, IndexCastI1Pattern,
+    TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, IndexCastI1IndexPattern,
     TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
     CmpIOpBooleanPattern, CmpIOpPattern,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 8bb63fff861ce..938a5ccfed542 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -734,9 +734,11 @@ func.func @index_castui4(%arg0: index) {
   return
 }
 
-// CHECK-LABEL: index_casti1_1
-func.func @index_casti1_1(%arg0 : i1) -> index {
-  // CHECK: spirv.Select %{{.+}}, %{{.+}}, %{{.+}} : i1, i32
+// CHECK-LABEL: index_casti1index_1
+func.func @index_casti1index_1(%arg0 : i1) -> index {
+  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+  // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
+  // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : i1, i32
   %0 = arith.index_cast %arg0 : i1 to index
   return %0 : index
 }

>From ac32e57f1fdfc358e39e9e01a5787f38d0d3c513 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Thu, 28 Aug 2025 13:51:03 -0700
Subject: [PATCH 3/6] remove redundant return value

---
 mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 938a5ccfed542..e86b04527383d 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -735,12 +735,12 @@ func.func @index_castui4(%arg0: index) {
 }
 
 // CHECK-LABEL: index_casti1index_1
-func.func @index_casti1index_1(%arg0 : i1) -> index {
+func.func @index_casti1index_1(%arg0 : i1) {
   // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
   // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
   // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : i1, i32
   %0 = arith.index_cast %arg0 : i1 to index
-  return %0 : index
+  return
 }
 
 // CHECK-LABEL: @bit_cast

>From 13f3d477dd241f50d478e6536b8d2f57547e5fd4 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Fri, 29 Aug 2025 08:35:59 -0700
Subject: [PATCH 4/6] clang-format

---
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index b9e04e456ff72..b55322816fd31 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -613,7 +613,8 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
 
 /// Converts arith.index_cast to spirv.Select if the type of source is i1 or
 /// vector of i1.
-struct IndexCastI1IndexPattern final : public OpConversionPattern<arith::IndexCastOp> {
+struct IndexCastI1IndexPattern final
+    : public OpConversionPattern<arith::IndexCastOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult

>From e61ff6a7774e437a0319f84e5aa9c363bd9f2ff7 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Fri, 29 Aug 2025 13:07:37 -0700
Subject: [PATCH 5/6] rewrite comments to conform with sister PR

---
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp     | 3 +--
 mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir | 4 ++--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index b55322816fd31..09d2a1bbf9d45 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -611,8 +611,7 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
 // IndexCastOp
 //===----------------------------------------------------------------------===//
 
-/// Converts arith.index_cast to spirv.Select if the type of source is i1 or
-/// vector of i1.
+/// Converts arith.index_cast to spirv.Select if the source type is i1
 struct IndexCastI1IndexPattern final
     : public OpConversionPattern<arith::IndexCastOp> {
   using OpConversionPattern::OpConversionPattern;
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index e86b04527383d..7968fce644e4b 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -734,8 +734,8 @@ func.func @index_castui4(%arg0: index) {
   return
 }
 
-// CHECK-LABEL: index_casti1index_1
-func.func @index_casti1index_1(%arg0 : i1) {
+// CHECK-LABEL: index_casti1index
+func.func @index_casti1index(%arg0 : i1) {
   // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
   // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
   // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : i1, i32

>From ac383ab7e3362c2e0f5412294b32e989acf3f506 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Tue, 2 Sep 2025 10:49:53 -0700
Subject: [PATCH 6/6] Address reviewer comments, add vector support

---
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp   |  8 ++++----
 .../Conversion/ArithToSPIRV/arith-to-spirv.mlir     | 13 +++++++++++--
 2 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 09d2a1bbf9d45..af297b1c918bf 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -611,7 +611,7 @@ struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
 // IndexCastOp
 //===----------------------------------------------------------------------===//
 
-/// Converts arith.index_cast to spirv.Select if the source type is i1
+/// Converts i1 (or vector of i1) arith.index_cast to spirv.Select(x, 1, 0).
 struct IndexCastI1IndexPattern final
     : public OpConversionPattern<arith::IndexCastOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -619,8 +619,8 @@ struct IndexCastI1IndexPattern final
   LogicalResult
   matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type srcType = adaptor.getOperands().front().getType();
-    if (!srcType.isInteger(1))
+    Type srcType = adaptor.getIn().getType();
+    if (!isBoolScalarOrVector(srcType))
       return failure();
 
     Type dstType = getTypeConverter()->convertType(op.getType());
@@ -631,7 +631,7 @@ struct IndexCastI1IndexPattern final
     Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
     Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
     rewriter.replaceOpWithNewOp<spirv::SelectOp>(
-        op, dstType, adaptor.getOperands().front(), one, zero);
+        op, dstType, adaptor.getIn(), one, zero);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 7968fce644e4b..9f575250aab2e 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -734,8 +734,8 @@ func.func @index_castui4(%arg0: index) {
   return
 }
 
-// CHECK-LABEL: index_casti1index
-func.func @index_casti1index(%arg0 : i1) {
+// CHECK-LABEL: index_casti1index_1
+func.func @index_casti1index_1(%arg0 : i1) {
   // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
   // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
   // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : i1, i32
@@ -743,6 +743,15 @@ func.func @index_casti1index(%arg0 : i1) {
   return
 }
 
+// CHECK-LABEL: index_casti1index_2
+func.func @index_casti1index_2(%arg0 : vector<3xi1>) {
+  // CHECK: %[[ZERO:.+]] = spirv.Constant dense<0> : vector<3xi32>
+  // CHECK: %[[ONE:.+]] = spirv.Constant dense<1> : vector<3xi32>
+  // CHECK: spirv.Select %{{.+}}, %[[ONE]], %[[ZERO]] : vector<3xi1>, vector<3xi32>
+  %0 = arith.index_cast %arg0 : vector<3xi1> to vector<3xindex>
+  return
+}
+
 // CHECK-LABEL: @bit_cast
 func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
   // CHECK: spirv.Bitcast %{{.+}} : vector<2xf32> to vector<2xi32>



More information about the Mlir-commits mailing list