[Mlir-commits] [mlir] b1ef5a8 - [mlir][MemRef] Add support for emulating narrow floats (#148036)
llvmlistbot at llvm.org
Mon Jul 14 08:18:54 PDT 2025
Author: Quinn Dawkins
Date: 2025-07-14T11:18:51-04:00
New Revision: b1ef5a8890f26b6430d95696d8ff4e99e270a80e
URL: https://github.com/llvm/llvm-project/commit/b1ef5a8890f26b6430d95696d8ff4e99e270a80e
DIFF: https://github.com/llvm/llvm-project/commit/b1ef5a8890f26b6430d95696d8ff4e99e270a80e.diff
LOG: [mlir][MemRef] Add support for emulating narrow floats (#148036)
This enables memref.load/store + vector.load/store support for sub-byte
float types. Since only the bit width of the element type matters for
loads/stores, we keep emulating with integer memref types of equivalent
width; a few extra bitcasts are needed around certain operations.
There is no direct change needed for vector.load/store support. The
tests added for them verify that float types are supported as well.
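Concretely, a sketch mirroring the tests below (the SSA names and the
8-bit load/store width here are illustrative, not part of the patch):

    // Before: a narrow-float load.
    %1 = memref.load %0[%arg0] : memref<5xf4E2M1FN>

    // After: the existing integer emulation path, plus one trailing bitcast.
    %byte    = memref.load %alloc[%idx] : memref<3xi8>  // byte holding the element
    %shifted = arith.shrsi %byte, %offset : i8          // shift element to the LSBs
    %bits    = arith.trunci %shifted : i8 to i4         // keep the 4 payload bits
    %f       = arith.bitcast %bits : i4 to f4E2M1FN     // reinterpret as float

    // Stores go the other way: bitcast the float to its integer bits
    // up front, then zero-extend into the usual masked/RMW write.
    %bits2 = arith.bitcast %arg0 : f4E2M1FN to i4
    %ext   = arith.extui %bits2 : i4 to i8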
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index d2a032688fb6d..ec2bc95291455 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -323,19 +323,28 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
// It is not clear if this case actually happens in practice, but we keep
// the operations just in case. Otherwise, if the arith computation bitwidth
// is different from the emulated bitwidth we truncate the result.
- Operation *result;
+ Value result;
auto resultTy = getTypeConverter()->convertType(oldElementType);
- if (resultTy == convertedElementType) {
+ auto conversionTy =
+ resultTy.isInteger()
+ ? resultTy
+ : IntegerType::get(rewriter.getContext(),
+ resultTy.getIntOrFloatBitWidth());
+ if (conversionTy == convertedElementType) {
auto mask = rewriter.create<arith::ConstantOp>(
loc, convertedElementType,
rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
} else {
- result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
+ result = rewriter.create<arith::TruncIOp>(loc, conversionTy, bitsLoad);
}
- rewriter.replaceOp(op, result->getResult(0));
+ if (conversionTy != resultTy) {
+ result = rewriter.create<arith::BitcastOp>(loc, resultTy, result);
+ }
+
+ rewriter.replaceOp(op, result);
return success();
}
};
@@ -415,8 +424,18 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
}
Location loc = op.getLoc();
- Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
- adaptor.getValue());
+
+ // Pad the input value with 0s on the left.
+ Value input = adaptor.getValue();
+ if (!input.getType().isInteger()) {
+ input = rewriter.create<arith::BitcastOp>(
+ loc,
+ IntegerType::get(rewriter.getContext(),
+ input.getType().getIntOrFloatBitWidth()),
+ input);
+ }
+ Value extendedInput =
+ rewriter.create<arith::ExtUIOp>(loc, dstIntegerType, input);
// Special case 0-rank memref stores. No need for masking.
if (convertedType.getRank() == 0) {
@@ -619,11 +638,11 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
arith::NarrowTypeEmulationConverter &typeConverter) {
typeConverter.addConversion(
[&typeConverter](MemRefType ty) -> std::optional<Type> {
- auto intTy = dyn_cast<IntegerType>(ty.getElementType());
- if (!intTy)
+ Type elementType = ty.getElementType();
+ if (!elementType.isIntOrFloat())
return ty;
- unsigned width = intTy.getWidth();
+ unsigned width = elementType.getIntOrFloatBitWidth();
unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
if (width >= loadStoreWidth)
return ty;
@@ -636,8 +655,11 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
if (!strides.empty() && strides.back() != 1)
return nullptr;
- auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
- intTy.getSignedness());
+ auto newElemTy = IntegerType::get(
+ ty.getContext(), loadStoreWidth,
+ elementType.isInteger()
+ ? cast<IntegerType>(elementType).getSignedness()
+ : IntegerType::SignednessSemantics::Signless);
if (!newElemTy)
return nullptr;
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 004beadc9ec7d..0fe08417f818f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1268,8 +1268,18 @@ struct ConvertVectorTransferRead final
bool isDivisibleInSize =
fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
- auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
- adaptor.getPadding());
+ // Pad the padding value with 0s on the left. These bits are discarded and
+ // thus their values don't matter.
+ Value padding = adaptor.getPadding();
+ if (!padding.getType().isInteger()) {
+ padding = rewriter.create<arith::BitcastOp>(
+ loc,
+ IntegerType::get(rewriter.getContext(),
+ padding.getType().getIntOrFloatBitWidth()),
+ padding);
+ }
+ auto newPadding =
+ rewriter.create<arith::ExtUIOp>(loc, containerElemTy, padding);
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 3378d329e8205..0cce8c18a40bc 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -61,6 +61,41 @@ func.func @memref_load_i4(%arg0: index) -> i4 {
// -----
+func.func @memref_load_f4(%arg0: index) -> f4E2M1FN {
+ %0 = memref.alloc() : memref<5xf4E2M1FN>
+ %1 = memref.load %0[%arg0] : memref<5xf4E2M1FN>
+ return %1 : f4E2M1FN
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
+// CHECK: func @memref_load_f4(
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK: %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+// CHECK: %[[BC:.+]] = arith.bitcast %[[TRUNC]] : i4 to f4E2M1FN
+// CHECK: return %[[BC]]
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
+// CHECK32: func @memref_load_f4(
+// CHECK32-SAME: %[[ARG0:.+]]: index
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32: %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+// CHECK32: %[[BC:.+]] = arith.bitcast %[[TRUNC]] : i4 to f4E2M1FN
+// CHECK32: return %[[BC]]
+
+// -----
+
func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
%0 = memref.alloc() : memref<3x125xi4>
%align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
@@ -470,6 +505,29 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () {
// -----
+func.func @rank_zero_memref_store_f4(%arg0: f4E2M1FN) -> () {
+ %0 = memref.alloc() : memref<f4E2M1FN>
+ memref.store %arg0, %0[] : memref<f4E2M1FN>
+ return
+}
+// CHECK-LABEL: func @rank_zero_memref
+// CHECK-SAME: %[[ARG0:.+]]: f4E2M1FN
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+// CHECK: %[[BC:.+]] = arith.bitcast %[[ARG0]] : f4E2M1FN to i4
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BC]] : i4 to i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+// CHECK: return
+
+// CHECK32-LABEL: func @rank_zero_memref
+// CHECK32-SAME: %[[ARG0:.+]]: f4E2M1FN
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<i32>
+// CHECK32: %[[BC:.+]] = arith.bitcast %[[ARG0]] : f4E2M1FN to i4
+// CHECK32: %[[EXTUI:.+]] = arith.extui %[[BC]] : i4 to i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+// CHECK32: return
+
+// -----
+
func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
%arr = memref.alloc() : memref<32x8x128xi4>
%collapse = memref.collapse_shape %arr[[0, 1], [2]] : memref<32x8x128xi4> into memref<256x128xi4>
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 6c924492b513e..98b1f07ef5fb0 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -53,6 +53,31 @@ func.func @vector_load_i4(%arg1: index, %arg2: index) -> vector<3x8xi4> {
// -----
+func.func @vector_load_f4(%arg1: index, %arg2: index) -> vector<3x8xf4E2M1FN> {
+ %0 = memref.alloc() : memref<3x8xf4E2M1FN>
+ %cst = arith.constant dense<0.0> : vector<3x8xf4E2M1FN>
+ %1 = vector.load %0[%arg1, %arg2] : memref<3x8xf4E2M1FN>, vector<8xf4E2M1FN>
+ %2 = vector.insert %1, %cst [0] : vector<8xf4E2M1FN> into vector<3x8xf4E2M1FN>
+ return %2 : vector<3x8xf4E2M1FN>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_load_f4
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<12xi8>, vector<4xi8>
+// CHECK: %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xf4E2M1FN>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_load_f4
+// CHECK32-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi32>, vector<1xi32>
+// CHECK32: %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xf4E2M1FN>
+
+// -----
+
func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) -> vector<8xi4> {
%0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
%1 = vector.load %0[%arg2, %arg3] : memref<?x?xi4>, vector<8xi4>
@@ -119,6 +144,37 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
// -----
+func.func @vector_transfer_read_f4(%arg1: index, %arg2: index) -> vector<8xf4E2M1FN> {
+ %c0 = arith.constant 0.0 : f4E2M1FN
+ %0 = memref.alloc() : memref<3x8xf4E2M1FN>
+ %1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true]} :
+ memref<3x8xf4E2M1FN>, vector<8xf4E2M1FN>
+ return %1 : vector<8xf4E2M1FN>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_transfer_read_f4
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[CONST:.+]] = arith.constant 0.{{0+}}e+00 : f4E2M1FN
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[BC:.+]] = arith.bitcast %[[CONST]] : f4E2M1FN to i4
+// CHECK: %[[PAD:.+]] = arith.extui %[[BC]] : i4 to i8
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<12xi8>, vector<4xi8>
+// CHECK: %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xf4E2M1FN>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_transfer_read_f4
+// CHECK32-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK32: %[[CONST:.+]] = arith.constant 0.{{0+}}e+00 : f4E2M1FN
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[BC:.+]] = arith.bitcast %[[CONST]] : f4E2M1FN to i4
+// CHECK32: %[[PAD:.+]] = arith.extui %[[BC]] : i4 to i32
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<3xi32>, vector<1xi32>
+// CHECK32: %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xf4E2M1FN>
+
+// -----
+
///----------------------------------------------------------------------------------------
/// vector.maskedload
///----------------------------------------------------------------------------------------
@@ -439,6 +495,28 @@ func.func @vector_store_i4(%arg0: vector<8xi4>, %arg1: index, %arg2: index) {
// -----
+func.func @vector_store_f4(%arg0: vector<8xf4E2M1FN>, %arg1: index, %arg2: index) {
+ %0 = memref.alloc() : memref<4x8xf4E2M1FN>
+ vector.store %arg0, %0[%arg1, %arg2] : memref<4x8xf4E2M1FN>, vector<8xf4E2M1FN>
+ return
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_store_f4
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<16xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xf4E2M1FN> to vector<4xi8>
+// CHECK: vector.store %[[VEC_I8]], %[[ALLOC]][%[[INDEX]]] : memref<16xi8>, vector<4xi8>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_store_f4
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<4xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK32: %[[VEC_I32:.+]] = vector.bitcast %[[ARG0]] : vector<8xf4E2M1FN> to vector<1xi32>
+// CHECK32: vector.store %[[VEC_I32]], %[[ALLOC]][%[[INDEX]]] : memref<4xi32>, vector<1xi32>
+
+// -----
+
// FIXME: This example assumes that the store happens at a byte boundary, but
// that's not guaranteed. Below is a counter-example with specific dimensions:
// vector.store %arg0, %0[0, 3] : memref<2x13xi4>, vector<8xi4>