[Mlir-commits] [mlir] b1ef5a8 - [mlir][MemRef] Add support for emulating narrow floats (#148036)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 14 08:18:54 PDT 2025


Author: Quinn Dawkins
Date: 2025-07-14T11:18:51-04:00
New Revision: b1ef5a8890f26b6430d95696d8ff4e99e270a80e

URL: https://github.com/llvm/llvm-project/commit/b1ef5a8890f26b6430d95696d8ff4e99e270a80e
DIFF: https://github.com/llvm/llvm-project/commit/b1ef5a8890f26b6430d95696d8ff4e99e270a80e.diff

LOG: [mlir][MemRef] Add support for emulating narrow floats (#148036)

This enables memref.load/store + vector.load/store support for sub-byte
float types. Since the memref types don't matter for loads/stores, we
still use the same types as integers with equivalent widths, with a few
extra bitcasts needed around certain operations.

There is no direct change needed for vector.load/store support. The
tests added for them are to verify that float types are
supported as well.

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
    mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
    mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index d2a032688fb6d..ec2bc95291455 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -323,19 +323,28 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
     // It is not clear if this case actually happens in practice, but we keep
     // the operations just in case. Otherwise, if the arith computation bitwidth
     // is different from the emulated bitwidth we truncate the result.
-    Operation *result;
+    Value result;
     auto resultTy = getTypeConverter()->convertType(oldElementType);
-    if (resultTy == convertedElementType) {
+    auto conversionTy =
+        resultTy.isInteger()
+            ? resultTy
+            : IntegerType::get(rewriter.getContext(),
+                               resultTy.getIntOrFloatBitWidth());
+    if (conversionTy == convertedElementType) {
       auto mask = rewriter.create<arith::ConstantOp>(
           loc, convertedElementType,
           rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
 
       result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
     } else {
-      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
+      result = rewriter.create<arith::TruncIOp>(loc, conversionTy, bitsLoad);
     }
 
-    rewriter.replaceOp(op, result->getResult(0));
+    if (conversionTy != resultTy) {
+      result = rewriter.create<arith::BitcastOp>(loc, resultTy, result);
+    }
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -415,8 +424,18 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
     }
 
     Location loc = op.getLoc();
-    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
-                                                          adaptor.getValue());
+
+    // Pad the input value with 0s on the left.
+    Value input = adaptor.getValue();
+    if (!input.getType().isInteger()) {
+      input = rewriter.create<arith::BitcastOp>(
+          loc,
+          IntegerType::get(rewriter.getContext(),
+                           input.getType().getIntOrFloatBitWidth()),
+          input);
+    }
+    Value extendedInput =
+        rewriter.create<arith::ExtUIOp>(loc, dstIntegerType, input);
 
     // Special case 0-rank memref stores. No need for masking.
     if (convertedType.getRank() == 0) {
@@ -619,11 +638,11 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
     arith::NarrowTypeEmulationConverter &typeConverter) {
   typeConverter.addConversion(
       [&typeConverter](MemRefType ty) -> std::optional<Type> {
-        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
-        if (!intTy)
+        Type elementType = ty.getElementType();
+        if (!elementType.isIntOrFloat())
           return ty;
 
-        unsigned width = intTy.getWidth();
+        unsigned width = elementType.getIntOrFloatBitWidth();
         unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
         if (width >= loadStoreWidth)
           return ty;
@@ -636,8 +655,11 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
         if (!strides.empty() && strides.back() != 1)
           return nullptr;
 
-        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
-                                          intTy.getSignedness());
+        auto newElemTy = IntegerType::get(
+            ty.getContext(), loadStoreWidth,
+            elementType.isInteger()
+                ? cast<IntegerType>(elementType).getSignedness()
+                : IntegerType::SignednessSemantics::Signless);
         if (!newElemTy)
           return nullptr;
 

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 004beadc9ec7d..0fe08417f818f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1268,8 +1268,18 @@ struct ConvertVectorTransferRead final
     bool isDivisibleInSize =
         fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
 
-    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
-                                                      adaptor.getPadding());
+    // Pad the padding value with 0s on the left. These bits are discarded and
+    // thus their values don't matter.
+    Value padding = adaptor.getPadding();
+    if (!padding.getType().isInteger()) {
+      padding = rewriter.create<arith::BitcastOp>(
+          loc,
+          IntegerType::get(rewriter.getContext(),
+                           padding.getType().getIntOrFloatBitWidth()),
+          padding);
+    }
+    auto newPadding =
+        rewriter.create<arith::ExtUIOp>(loc, containerElemTy, padding);
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

diff  --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 3378d329e8205..0cce8c18a40bc 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -61,6 +61,41 @@ func.func @memref_load_i4(%arg0: index) -> i4 {
 
 // -----
 
+func.func @memref_load_f4(%arg0: index) -> f4E2M1FN {
+    %0 = memref.alloc() : memref<5xf4E2M1FN>
+    %1 = memref.load %0[%arg0] : memref<5xf4E2M1FN>
+    return %1 : f4E2M1FN
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
+//      CHECK: func @memref_load_f4(
+// CHECK-SAME:     %[[ARG0:.+]]: index
+//      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//      CHECK:   %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+//      CHECK:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//      CHECK:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//      CHECK:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+//      CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+//      CHECK:   %[[BC:.+]] = arith.bitcast %[[TRUNC]] : i4 to f4E2M1FN
+//      CHECK:   return %[[BC]]
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
+//      CHECK32: func @memref_load_f4(
+// CHECK32-SAME:     %[[ARG0:.+]]: index
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//      CHECK32:   %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+//      CHECK32:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//      CHECK32:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//      CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+//      CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+//      CHECK32:   %[[BC:.+]] = arith.bitcast %[[TRUNC]] : i4 to f4E2M1FN
+//      CHECK32:   return %[[BC]]
+
+// -----
+
 func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
     %0 = memref.alloc() : memref<3x125xi4>
     %align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
@@ -470,6 +505,29 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () {
 
 // -----
 
+func.func @rank_zero_memref_store_f4(%arg0: f4E2M1FN) -> () {
+  %0 = memref.alloc() : memref<f4E2M1FN>
+  memref.store %arg0, %0[] : memref<f4E2M1FN>
+  return
+}
+// CHECK-LABEL: func @rank_zero_memref
+//  CHECK-SAME:     %[[ARG0:.+]]: f4E2M1FN
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+//       CHECK:   %[[BC:.+]] = arith.bitcast %[[ARG0]] : f4E2M1FN to i4
+//       CHECK:   %[[EXTUI:.+]] = arith.extui %[[BC]] : i4 to i8
+//       CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+//       CHECK:   return
+
+// CHECK32-LABEL: func @rank_zero_memref
+//  CHECK32-SAME:     %[[ARG0:.+]]: f4E2M1FN
+//       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<i32>
+//       CHECK32:   %[[BC:.+]] = arith.bitcast %[[ARG0]] : f4E2M1FN to i4
+//       CHECK32:   %[[EXTUI:.+]] = arith.extui %[[BC]] : i4 to i32
+//       CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+//       CHECK32:   return
+
+// -----
+
 func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
   %arr = memref.alloc() : memref<32x8x128xi4>
   %collapse = memref.collapse_shape %arr[[0, 1], [2]] : memref<32x8x128xi4> into memref<256x128xi4>

diff  --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 6c924492b513e..98b1f07ef5fb0 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -53,6 +53,31 @@ func.func @vector_load_i4(%arg1: index, %arg2: index) -> vector<3x8xi4> {
 
 // -----
 
+func.func @vector_load_f4(%arg1: index, %arg2: index) -> vector<3x8xf4E2M1FN> {
+    %0 = memref.alloc() : memref<3x8xf4E2M1FN>
+    %cst = arith.constant dense<0.0> : vector<3x8xf4E2M1FN>
+    %1 = vector.load %0[%arg1, %arg2] : memref<3x8xf4E2M1FN>, vector<8xf4E2M1FN>
+    %2 = vector.insert %1, %cst [0] : vector<8xf4E2M1FN> into vector<3x8xf4E2M1FN>
+    return %2 : vector<3x8xf4E2M1FN>
+}
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+//      CHECK: func @vector_load_f4
+// CHECK-SAME:     (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+//      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK:   %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<12xi8>, vector<4xi8>
+//      CHECK:   %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xf4E2M1FN>
+
+//  CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+//      CHECK32: func @vector_load_f4
+// CHECK32-SAME:     (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi32>, vector<1xi32>
+//      CHECK32:   %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xf4E2M1FN>
+
+// -----
+
 func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) -> vector<8xi4> {
   %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
   %1 = vector.load %0[%arg2, %arg3] : memref<?x?xi4>, vector<8xi4>
@@ -119,6 +144,37 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
 
 // -----
 
+func.func @vector_transfer_read_f4(%arg1: index, %arg2: index) -> vector<8xf4E2M1FN> {
+    %c0 = arith.constant 0.0 : f4E2M1FN
+    %0 = memref.alloc() : memref<3x8xf4E2M1FN>
+    %1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true]} :
+      memref<3x8xf4E2M1FN>, vector<8xf4E2M1FN>
+    return %1 : vector<8xf4E2M1FN>
+}
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+//      CHECK: func @vector_transfer_read_f4
+// CHECK-SAME:     (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+//      CHECK:   %[[CONST:.+]] = arith.constant 0.{{0+}}e+00 : f4E2M1FN
+//      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+//      CHECK:   %[[BC:.+]] = arith.bitcast %[[CONST]] : f4E2M1FN to i4
+//      CHECK:   %[[PAD:.+]] = arith.extui %[[BC]] : i4 to i8
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK:   %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<12xi8>, vector<4xi8>
+//      CHECK:   %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xf4E2M1FN>
+
+//  CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+//      CHECK32: func @vector_transfer_read_f4
+// CHECK32-SAME:     (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+//      CHECK32:   %[[CONST:.+]] = arith.constant 0.{{0+}}e+00 : f4E2M1FN
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+//      CHECK32:   %[[BC:.+]] = arith.bitcast %[[CONST]] : f4E2M1FN to i4
+//      CHECK32:   %[[PAD:.+]] = arith.extui %[[BC]] : i4 to i32
+//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<3xi32>, vector<1xi32>
+//      CHECK32:   %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xf4E2M1FN>
+
+// -----
+
 ///----------------------------------------------------------------------------------------
 /// vector.maskedload
 ///----------------------------------------------------------------------------------------
@@ -439,6 +495,28 @@ func.func @vector_store_i4(%arg0: vector<8xi4>, %arg1: index, %arg2: index) {
 
 // -----
 
+func.func @vector_store_f4(%arg0: vector<8xf4E2M1FN>, %arg1: index, %arg2: index) {
+    %0 = memref.alloc() : memref<4x8xf4E2M1FN>
+    vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xf4E2M1FN>, vector<8xf4E2M1FN>
+    return
+}
+
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+//      CHECK: func @vector_store_f4
+//      CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<16xi8>
+//      CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+//      CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xf4E2M1FN> to vector<4xi8>
+//      CHECK: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<16xi8>, vector<4xi8>
+
+//  CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+//      CHECK32: func @vector_store_f4
+//      CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<4xi32>
+//      CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+//      CHECK32: %[[VEC_I32:.+]] = vector.bitcast %[[ARG0]] : vector<8xf4E2M1FN> to vector<1xi32>
+//      CHECK32: vector.store %[[VEC_I32:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<4xi32>, vector<1xi32>
+
+// -----
+
 // FIXME: This example assumes that the store happens at a byte boundary, but
 // that's not guaranteed. Below is a counter-example with specific dimensions:
 //    vector.store %arg0, %0[0, 3] : memref<2x13xi4>, vector<8xi4>


        


More information about the Mlir-commits mailing list