[Mlir-commits] [mlir] [mlir][SCF] Add scf.index_switch support for populateSCFStructuralTypeConversionsAndLegality (PR #160344)

Artemy Skrebkov llvmlistbot at llvm.org
Tue Sep 23 09:50:46 PDT 2025


https://github.com/ArtemySkrebkov created https://github.com/llvm/llvm-project/pull/160344

In a downstream project, there is a need for a type conversion pattern for the scf.index_switch operation. A test is added to `mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir` (I am not sure this functionality is really required for sparse tensors, but the test shows that the new conversion pattern works).

From a16df3bd80bc07657fa1bfc5ee05438f22cd6107 Mon Sep 17 00:00:00 2001
From: "Skrebkov, Artemy" <artemy.skrebkov at intel.com>
Date: Tue, 23 Sep 2025 16:48:05 +0000
Subject: [PATCH] [mlir][SCF] Add scf.index_switch support for
 populateSCFStructuralTypeConversionsAndLegality

---
 .../Transforms/StructuralTypeConversions.cpp  | 34 +++++++++++++--
 .../SparseTensor/scf_1_N_conversion.mlir      | 43 +++++++++++++++++++
 2 files changed, 73 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index b0c781c7aff11..c9ff7885f3a49 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -185,6 +185,38 @@ class ConvertWhileOpTypes
 };
 } // namespace
 
+namespace {
+// Converts the result types of an scf.index_switch. Its default and case
+// regions take no block arguments, so they are moved into the new op as-is;
+// their scf.yield terminators are handled by the scf.yield conversion pattern.
+class ConvertIndexSwitchOpTypes
+    : public Structural1ToNConversionPattern<IndexSwitchOp,
+                                             ConvertIndexSwitchOpTypes> {
+public:
+  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
+
+  std::optional<IndexSwitchOp>
+  convertSourceOp(IndexSwitchOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter,
+                  TypeRange dstTypes) const {
+    // Build the new op from the remapped switch operand (the adaptor value),
+    // not from the original op's operand. It must map to exactly one value.
+    Value arg = llvm::getSingleElement(adaptor.getArg());
+    auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), dstTypes, arg,
+                                       op.getCases(), op.getNumCases());
+    // Copy all attributes so discardable ones are not dropped.
+    newOp->setAttrs(op->getAttrs());
+
+    // Nothing below can fail, so the pattern never reports failure after it
+    // has already created IR.
+    for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i)
+      rewriter.inlineRegionBefore(op.getRegion(i), newOp.getRegion(i),
+                                  newOp.getRegion(i).end());
+    return newOp;
+  }
+};
+} // namespace
+
 namespace {
 // When the result types of a ForOp/IfOp get changed, the operand types of the
 // corresponding yield op need to be changed. In order to trigger the
@@ -220,18 +252,19 @@ void mlir::scf::populateSCFStructuralTypeConversions(
     const TypeConverter &typeConverter, RewritePatternSet &patterns,
     PatternBenefit benefit) {
   patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
-               ConvertWhileOpTypes, ConvertConditionOpTypes>(
-      typeConverter, patterns.getContext(), benefit);
+               ConvertWhileOpTypes, ConvertConditionOpTypes,
+               ConvertIndexSwitchOpTypes>(typeConverter, patterns.getContext(),
+                                          benefit);
 }
 
 void mlir::scf::populateSCFStructuralTypeConversionTarget(
     const TypeConverter &typeConverter, ConversionTarget &target) {
-  target.addDynamicallyLegalOp<ForOp, IfOp>(
+  target.addDynamicallyLegalOp<ForOp, IfOp, IndexSwitchOp>(
       [&](Operation *op) { return typeConverter.isLegal(op->getResults()); });
   target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
     // We only have conversions for a subset of ops that use scf.yield
     // terminators.
-    if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
+    if (!isa<ForOp, IfOp, WhileOp, IndexSwitchOp>(op->getParentOp()))
       return true;
     return typeConverter.isLegal(op.getOperands());
   });
diff --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
index f5d6a08b7de31..00f13ed7c8149 100644
--- a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
@@ -86,3 +86,46 @@ func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024x
   }
   return %0: tensor<1024xf32, #SparseVector>
 }
+
+// CHECK-LABEL:   func.func @index_switch(
+// CHECK-SAME:                            %[[PRED:.*0]]: index,
+// CHECK-SAME:                            %[[VAL_A_1:.*1]]: memref<?xindex>,
+// CHECK-SAME:                            %[[VAL_A_2:.*2]]: memref<?xindex>,
+// CHECK-SAME:                            %[[VAL_A_3:.*3]]: memref<?xf32>,
+// CHECK-SAME:                            %[[VAL_A_4:.*4]]: !sparse_tensor.storage_specifier
+// CHECK-SAME:                            %[[VAL_B_1:.*5]]: memref<?xindex>,
+// CHECK-SAME:                            %[[VAL_B_2:.*6]]: memref<?xindex>,
+// CHECK-SAME:                            %[[VAL_B_3:.*7]]: memref<?xf32>,
+// CHECK-SAME:                            %[[VAL_B_4:.*8]]: !sparse_tensor.storage_specifier
+// CHECK-SAME:                            %[[VAL_C_1:.*9]]: memref<?xindex>,
+// CHECK-SAME:                            %[[VAL_C_2:.*10]]: memref<?xindex>,
+// CHECK-SAME:                            %[[VAL_C_3:.*11]]: memref<?xf32>,
+// CHECK-SAME:                            %[[VAL_C_4:.*12]]: !sparse_tensor.storage_specifier
+
+// CHECK:           %[[RES:.*]]:4 = scf.index_switch %[[PRED]]
+// CHECK:           case 1 {
+// CHECK:             scf.yield %[[VAL_A_1]], %[[VAL_A_2]], %[[VAL_A_3]], %[[VAL_A_4]]
+// CHECK:           case 2 {
+// CHECK:             scf.yield %[[VAL_B_1]], %[[VAL_B_2]], %[[VAL_B_3]], %[[VAL_B_4]]
+// CHECK:           default {
+// CHECK:             scf.yield %[[VAL_C_1]], %[[VAL_C_2]], %[[VAL_C_3]], %[[VAL_C_4]]
+
+// CHECK:           return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3 :
+// CHECK-SAME:        memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier
+// Tests 1:N conversion of scf.index_switch results and its scf.yield operands.
+func.func @index_switch(%pred: index, %a: tensor<5xf32, #SparseVector>,
+                        %b: tensor<5xf32, #SparseVector>,
+                        %c: tensor<5xf32, #SparseVector>) -> tensor<5xf32, #SparseVector> {
+  %0 = scf.index_switch %pred -> tensor<5xf32, #SparseVector>
+  case 1 {
+    scf.yield %a : tensor<5xf32, #SparseVector>
+  }
+  case 2 {
+    scf.yield %b : tensor<5xf32, #SparseVector>
+  }
+  default {
+    scf.yield %c : tensor<5xf32, #SparseVector>
+  }
+
+  return %0 : tensor<5xf32, #SparseVector>
+}



More information about the Mlir-commits mailing list