[Mlir-commits] [mlir] 72b8073 - [mlir][SCF] Add scf.index_switch support for populateSCFStructuralTypeConversionsAndLegality (#160344)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 24 04:50:10 PDT 2025
Author: Artemy Skrebkov
Date: 2025-09-24T13:50:05+02:00
New Revision: 72b8073d05e35baa274e51a8d6a870bc4f0ad29e
URL: https://github.com/llvm/llvm-project/commit/72b8073d05e35baa274e51a8d6a870bc4f0ad29e
DIFF: https://github.com/llvm/llvm-project/commit/72b8073d05e35baa274e51a8d6a870bc4f0ad29e.diff
LOG: [mlir][SCF] Add scf.index_switch support for populateSCFStructuralTypeConversionsAndLegality (#160344)
In a downstream project, there is a need for a type conversion pattern
for the scf.index_switch operation. A test is added to
`mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir` (it is unclear whether
this functionality is actually required for sparse tensors, but the test
shows that the new conversion pattern works).
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index b0c781c7aff11..9468927021495 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -185,6 +185,30 @@ class ConvertWhileOpTypes
};
} // namespace
+namespace {
+/// 1:N structural type conversion for `scf.index_switch`.
+///
+/// A new `scf.index_switch` is created with the converted result types
+/// (`dstTypes`), the same switch argument and the same case values. Each
+/// region of the original op is then moved, in order, into the region at the
+/// same position of the new op. The `scf.yield` terminators inside those
+/// regions are not touched here; their operand types are handled by the
+/// separately registered ConvertYieldOpTypes pattern.
+class ConvertIndexSwitchOpTypes
+    : public Structural1ToNConversionPattern<IndexSwitchOp,
+                                             ConvertIndexSwitchOpTypes> {
+public:
+  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
+
+  /// Builds the replacement op for `op` with result types `dstTypes`.
+  ///
+  /// \param op        the original `scf.index_switch`; its regions are moved
+  ///                  out (left empty) by this function.
+  /// \param adaptor   unused; the switch argument is taken from `op` as-is.
+  /// \param rewriter  the conversion rewriter used for all IR mutations.
+  /// \param dstTypes  the converted (possibly 1:N-expanded) result types.
+  /// \returns the newly created `scf.index_switch`.
+  std::optional<IndexSwitchOp>
+  convertSourceOp(IndexSwitchOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter,
+                  TypeRange dstTypes) const {
+    auto newOp =
+        IndexSwitchOp::create(rewriter, op.getLoc(), dstTypes, op.getArg(),
+                              op.getCases(), op.getNumCases());
+    // The builder only receives the inherent `cases` attribute; carry over
+    // any discardable attributes so they are not silently lost.
+    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
+
+    // Move each original region into the positionally matching region of the
+    // new op. zip_equal asserts that both ops have the same region count.
+    for (auto [srcRegion, dstRegion] :
+         llvm::zip_equal(op->getRegions(), newOp->getRegions()))
+      rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.end());
+    return newOp;
+  }
+};
+} // namespace
+
namespace {
// When the result types of a ForOp/IfOp get changed, the operand types of the
// corresponding yield op need to be changed. In order to trigger the
@@ -220,18 +244,19 @@ void mlir::scf::populateSCFStructuralTypeConversions(
const TypeConverter &typeConverter, RewritePatternSet &patterns,
PatternBenefit benefit) {
+  // Registers every structural 1:N conversion pattern listed below with
+  // `benefit`; ConvertIndexSwitchOpTypes (new) covers scf.index_switch.
patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
- ConvertWhileOpTypes, ConvertConditionOpTypes>(
- typeConverter, patterns.getContext(), benefit);
+ ConvertWhileOpTypes, ConvertConditionOpTypes,
+ ConvertIndexSwitchOpTypes>(typeConverter, patterns.getContext(),
+ benefit);
}
void mlir::scf::populateSCFStructuralTypeConversionTarget(
const TypeConverter &typeConverter, ConversionTarget &target) {
- target.addDynamicallyLegalOp<ForOp, IfOp>(
+ target.addDynamicallyLegalOp<ForOp, IfOp, IndexSwitchOp>(
[&](Operation *op) { return typeConverter.isLegal(op->getResults()); });
target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
// We only have conversions for a subset of ops that use scf.yield
// terminators.
- if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
+ if (!isa<ForOp, IfOp, WhileOp, IndexSwitchOp>(op->getParentOp()))
return true;
return typeConverter.isLegal(op.getOperands());
});
diff --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
index f5d6a08b7de31..515de5502f322 100644
--- a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
@@ -86,3 +86,47 @@ func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024x
}
return %0: tensor<1024xf32, #SparseVector>
}
+
+// Verifies that an scf.index_switch producing a sparse tensor is rewritten
+// 1:N: the single sparse result expands into four values (three memrefs and a
+// storage specifier), and each case/default yield forwards the four expanded
+// values of the corresponding sparse function argument.
+// CHECK-LABEL: func.func @index_switch(
+// CHECK-SAME: %[[PRED:.*0]]: index,
+// CHECK-SAME: %[[VAL_A_1:.*1]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_A_2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_A_3:.*3]]: memref<?xf32>,
+// CHECK-SAME: %[[VAL_A_4:.*4]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[VAL_B_1:.*5]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_B_2:.*6]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_B_3:.*7]]: memref<?xf32>,
+// CHECK-SAME: %[[VAL_B_4:.*8]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[VAL_C_1:.*9]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_C_2:.*10]]: memref<?xindex>,
+// CHECK-SAME: %[[VAL_C_3:.*11]]: memref<?xf32>,
+// CHECK-SAME: %[[VAL_C_4:.*12]]: !sparse_tensor.storage_specifier
+
+// CHECK: %[[RES:.*]]:4 = scf.index_switch %[[PRED]]
+// CHECK-SAME: -> memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier
+// CHECK: case 1 {
+// CHECK: scf.yield %[[VAL_A_1]], %[[VAL_A_2]], %[[VAL_A_3]], %[[VAL_A_4]]
+// CHECK: case 2 {
+// CHECK: scf.yield %[[VAL_B_1]], %[[VAL_B_2]], %[[VAL_B_3]], %[[VAL_B_4]]
+// CHECK: default {
+// CHECK: scf.yield %[[VAL_C_1]], %[[VAL_C_2]], %[[VAL_C_3]], %[[VAL_C_4]]
+
+// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3 :
+// CHECK-SAME: memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier
+
+func.func @index_switch(%pred: index, %a: tensor<5xf32, #SparseVector>,
+ %b: tensor<5xf32, #SparseVector>,
+ %c: tensor<5xf32, #SparseVector>) -> tensor<5xf32, #SparseVector> {
+ %0 = scf.index_switch %pred -> tensor<5xf32, #SparseVector>
+ case 1 {
+ scf.yield %a : tensor<5xf32, #SparseVector>
+ }
+ case 2 {
+ scf.yield %b : tensor<5xf32, #SparseVector>
+ }
+ default {
+ scf.yield %c : tensor<5xf32, #SparseVector>
+ }
+
+ return %0 : tensor<5xf32, #SparseVector>
+}
More information about the Mlir-commits
mailing list