[Mlir-commits] [mlir] [MLIR][MemRef] Add cf-conversion helper for narrow-type emulation (PR #198053)

Alan Li llvmlistbot at llvm.org
Fri May 15 17:58:45 PDT 2026


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/198053

>From fcf5ca163490eb3234735663a8a96b9776ee7c0f Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Fri, 15 May 2026 16:48:19 -0700
Subject: [PATCH] [MLIR][MemRef] Add cf-conversion helper for narrow-type
 emulation

Add memref::populateMemRefNarrowTypeEmulationCFPatterns, a thin wrapper
over cf::populateCFStructuralTypeConversionsAndLegality, so callers can
opt into rewriting cf.br / cf.cond_br / cf.switch operand and successor
block-argument types when emulating sub-byte memref element types. Wire
the helper unconditionally into the in-tree `test-emulate-narrow-int`
test pass (no new pass option is added), and add lit coverage that
exercises cf.br, cf.cond_br, and cf.switch carrying sub-byte memref
values across block-argument boundaries.
---
 .../Dialect/MemRef/Transforms/Transforms.h    | 10 +++
 .../Dialect/MemRef/Transforms/CMakeLists.txt  |  1 +
 .../MemRef/Transforms/EmulateNarrowType.cpp   |  8 ++
 .../MemRef/emulate-narrow-type-cf.mlir        | 78 +++++++++++++++++++
 .../Dialect/MemRef/TestEmulateNarrowType.cpp  | 10 ++-
 5 files changed, 104 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Dialect/MemRef/emulate-narrow-type-cf.mlir

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 4d6c54d74d2a9..0df84bad2dd65 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -18,6 +18,7 @@
 #include "llvm/ADT/STLFunctionalExtras.h"
 
 namespace mlir {
+class ConversionTarget;
 class OpBuilder;
 class RewritePatternSet;
 class RewriterBase;
@@ -104,6 +105,15 @@ void populateMemRefNarrowTypeEmulationPatterns(
 void populateMemRefNarrowTypeEmulationConversions(
     arith::NarrowTypeEmulationConverter &typeConverter);
 
+/// Register patterns + dynamic legality so that cf branch ops carrying
+/// memref values whose element type is being emulated have both their
+/// operand types and their successor block-argument types rewritten to the
+/// container element type. Thin wrapper over
+/// cf::populateCFStructuralTypeConversionsAndLegality.
+void populateMemRefNarrowTypeEmulationCFPatterns(
+    const arith::NarrowTypeEmulationConverter &typeConverter,
+    RewritePatternSet &patterns, ConversionTarget &target);
+
 /// Transformation to do multi-buffering/array expansion to remove dependencies
 /// on the temporary allocation between consecutive loop iterations.
 /// It returns the new allocation if the original allocation was multi-buffered
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 1c5e07f89b338..e5e90e3f5ff88 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -29,6 +29,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   MLIRAffineUtils
   MLIRArithDialect
   MLIRArithTransforms
+  MLIRControlFlowTransforms
   MLIRDialectUtils
   MLIRFuncDialect
   MLIRGPUDialect
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index a11e14faa5475..51a147f5fa79f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
@@ -813,3 +814,10 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
                                newElemTy, layoutAttr, ty.getMemorySpace());
       });
 }
+
+void memref::populateMemRefNarrowTypeEmulationCFPatterns(
+    const arith::NarrowTypeEmulationConverter &typeConverter,
+    RewritePatternSet &patterns, ConversionTarget &target) {
+  cf::populateCFStructuralTypeConversionsAndLegality(typeConverter, patterns,
+                                                     target);
+}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-cf.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-cf.mlir
new file mode 100644
index 0000000000000..dc67b776553a0
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type-cf.mlir
@@ -0,0 +1,78 @@
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8 arith-compute-bitwidth=1" --cse --verify-diagnostics --split-input-file %s | FileCheck %s
+
+// Sub-byte memref type carried through cf.br block args. The cf branch
+// pattern (registered by cf::populateCFStructuralTypeConversionsAndLegality)
+// must rewrite both the cf.br operand type and the successor block-arg type
+// to the i8 container, so the downstream uses in the successor block see an
+// i8 source.
+
+// CHECK-LABEL: func.func @cf_br_block_arg_narrow_type
+// CHECK-SAME:    %[[ARG:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>
+// CHECK:         cf.br ^[[BB1:.+]](%[[ARG]] : memref<{{[0-9]+}}xi8>)
+// CHECK:       ^[[BB1]](%[[BARG:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>):
+// CHECK:         return %[[BARG]]
+// CHECK-NOT:     memref<{{[0-9]+}}xi4>
+func.func @cf_br_block_arg_narrow_type(%arg: memref<8xi4>) -> memref<8xi4> {
+  cf.br ^bb1(%arg : memref<8xi4>)
+^bb1(%a: memref<8xi4>):
+  return %a : memref<8xi4>
+}
+
+// -----
+
+// Sub-byte memref carried through both successors of a cf.cond_br. Both
+// branch operand types and both successor block-arg types must be rewritten
+// to the i8 container.
+
+// CHECK-LABEL: func.func @cf_cond_br_block_arg_narrow_type
+// CHECK-SAME:    %[[COND:[A-Za-z0-9_]+]]: i1
+// CHECK-SAME:    %[[A:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>
+// CHECK-SAME:    %[[B:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>
+// CHECK:         cf.cond_br %[[COND]], ^[[BBT:.+]](%[[A]] : memref<{{[0-9]+}}xi8>), ^[[BBF:.+]](%[[B]] : memref<{{[0-9]+}}xi8>)
+// CHECK:       ^[[BBT]](%[[XT:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>):
+// CHECK:         return %[[XT]]
+// CHECK:       ^[[BBF]](%[[XF:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>):
+// CHECK:         return %[[XF]]
+// CHECK-NOT:     memref<{{[0-9]+}}xi4>
+func.func @cf_cond_br_block_arg_narrow_type(%cond: i1, %a: memref<8xi4>, %b: memref<8xi4>) -> memref<8xi4> {
+  cf.cond_br %cond, ^bb1(%a : memref<8xi4>), ^bb2(%b : memref<8xi4>)
+^bb1(%x: memref<8xi4>):
+  return %x : memref<8xi4>
+^bb2(%y: memref<8xi4>):
+  return %y : memref<8xi4>
+}
+
+// -----
+
+// Sub-byte memref carried through the default and case successors of a
+// cf.switch. The branch pattern must rewrite the operand type at every
+// successor edge and the matching block-arg type at every successor.
+
+// CHECK-LABEL: func.func @cf_switch_block_arg_narrow_type
+// CHECK-SAME:    %[[FLAG:[A-Za-z0-9_]+]]: i32
+// CHECK-SAME:    %[[ARG:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>
+// CHECK:         cf.switch %[[FLAG]] : i32, [
+// CHECK:           default: ^[[BBD:.+]](%[[ARG]] : memref<{{[0-9]+}}xi8>)
+// CHECK:           0: ^[[BB0:.+]](%[[ARG]] : memref<{{[0-9]+}}xi8>)
+// CHECK:           1: ^[[BB1:.+]](%[[ARG]] : memref<{{[0-9]+}}xi8>)
+// CHECK:         ]
+// CHECK:       ^[[BBD]](%[[XD:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>):
+// CHECK:         return %[[XD]]
+// CHECK:       ^[[BB0]](%[[X0:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>):
+// CHECK:         return %[[X0]]
+// CHECK:       ^[[BB1]](%[[X1:[A-Za-z0-9_]+]]: memref<{{[0-9]+}}xi8>):
+// CHECK:         return %[[X1]]
+// CHECK-NOT:     memref<{{[0-9]+}}xi4>
+func.func @cf_switch_block_arg_narrow_type(%flag: i32, %arg: memref<8xi4>) -> memref<8xi4> {
+  cf.switch %flag : i32, [
+    default: ^bb1(%arg : memref<8xi4>),
+    0: ^bb2(%arg : memref<8xi4>),
+    1: ^bb3(%arg : memref<8xi4>)
+  ]
+^bb1(%x: memref<8xi4>):
+  return %x : memref<8xi4>
+^bb2(%y: memref<8xi4>):
+  return %y : memref<8xi4>
+^bb3(%z: memref<8xi4>):
+  return %z : memref<8xi4>
+}
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index bec83a8dcbef9..a962539d7f003 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
@@ -34,9 +35,9 @@ struct TestEmulateNarrowTypePass
       : PassWrapper(pass) {}
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
-                vector::VectorDialect, affine::AffineDialect>();
+    registry.insert<arith::ArithDialect, cf::ControlFlowDialect,
+                    func::FuncDialect, memref::MemRefDialect,
+                    vector::VectorDialect, affine::AffineDialect>();
   }
   StringRef getArgument() const final { return "test-emulate-narrow-int"; }
   StringRef getDescription() const final {
@@ -104,6 +105,9 @@ struct TestEmulateNarrowTypePass
     vector::populateVectorNarrowTypeEmulationPatterns(
         typeConverter, patterns, disableAtomicRMW, assumeAligned);
 
+    memref::populateMemRefNarrowTypeEmulationCFPatterns(typeConverter, patterns,
+                                                        target);
+
     if (failed(applyPartialConversion(op, target, std::move(patterns))))
       signalPassFailure();
   }



More information about the Mlir-commits mailing list