[Mlir-commits] [mlir] [mlirbc] Add AffineMap serialization support (PR #191970)

Jacques Pienaar llvmlistbot at llvm.org
Tue Apr 14 02:17:29 PDT 2026


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/191970

>From 0c06aa5f26892005307f31004c63883c805a6b15 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Thu, 19 Feb 2026 22:08:41 +0200
Subject: [PATCH 1/4] [mlirbc] Add missing encoding for float types

Add the reader side first so that consumers can update before the writer
starts emitting these encodings.
---
 .../include/mlir/IR/BuiltinDialectBytecode.td | 46 ++++++++++++++++++-
 mlir/include/mlir/IR/BytecodeBase.td          |  4 ++
 mlir/test/Dialect/Builtin/Bytecode/types.mlir | 28 ++++++++++-
 3 files changed, 75 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index c97d093c84e51..0f593b21614b4 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -231,6 +231,16 @@ def BFloat16Type : DialectType<(type)>;
 
 def Float16Type : DialectType<(type)>;
 
+// Stage the addition of new floating point types so that readers can be updated
+// first.
+#ifdef MLIRBC_DISABLE_FLOAT_ADD_JUNE
+class EnableFloatPrintingJune2026<dag d> : DialectTypeNoPrint<d>;
+#else
+class EnableFloatPrintingJune2026<dag d> : DialectType<d>;
+#endif
+
+def FloatTF32Type : EnableFloatPrintingJune2026<(type)>;
+
 def Float32Type : DialectType<(type)>;
 
 def Float64Type : DialectType<(type)>;
@@ -239,6 +249,28 @@ def Float80Type : DialectType<(type)>;
 
 def Float128Type : DialectType<(type)>;
 
+def Float8E5M2Type : EnableFloatPrintingJune2026<(type)>;
+
+def Float8E4M3Type : EnableFloatPrintingJune2026<(type)>;
+
+def Float8E4M3FNType : EnableFloatPrintingJune2026<(type)>;
+
+def Float8E5M2FNUZType : EnableFloatPrintingJune2026<(type)>;
+
+def Float8E4M3FNUZType : EnableFloatPrintingJune2026<(type)>;
+
+def Float8E4M3B11FNUZType : EnableFloatPrintingJune2026<(type)>;
+
+def Float8E3M4Type : EnableFloatPrintingJune2026<(type)>;
+
+def Float4E2M1FNType : EnableFloatPrintingJune2026<(type)>;
+
+def Float6E2M3FNType : EnableFloatPrintingJune2026<(type)>;
+
+def Float6E3M2FNType : EnableFloatPrintingJune2026<(type)>;
+
+def Float8E8M0FNUType : EnableFloatPrintingJune2026<(type)>;
+
 def ComplexType : DialectType<(type
   Type:$elementType
 )>;
@@ -371,7 +403,19 @@ def BuiltinDialectTypes : DialectTypes<"Builtin"> {
     UnrankedMemRefTypeWithMemSpace,
     UnrankedTensorType,
     VectorType,
-    VectorTypeWithScalableDims
+    VectorTypeWithScalableDims,
+    FloatTF32Type,
+    Float8E5M2Type,
+    Float8E4M3Type,
+    Float8E4M3FNType,
+    Float8E5M2FNUZType,
+    Float8E4M3FNUZType,
+    Float8E4M3B11FNUZType,
+    Float8E3M4Type,
+    Float4E2M1FNType,
+    Float6E2M3FNType,
+    Float6E3M2FNType,
+    Float8E8M0FNUType
   ];
 }
 
diff --git a/mlir/include/mlir/IR/BytecodeBase.td b/mlir/include/mlir/IR/BytecodeBase.td
index 184c81e6a5f7d..df60800ef639b 100644
--- a/mlir/include/mlir/IR/BytecodeBase.td
+++ b/mlir/include/mlir/IR/BytecodeBase.td
@@ -153,6 +153,10 @@ class DialectType<dag d> : DialectAttrOrType<d>, TypeKind {
   let cParser = "succeeded($_reader.readType<$_resultType>($_var))";
   let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, $_args)";
 }
+// Variant of DialectType that is never written; used to stage reader support.
+class DialectTypeNoPrint<dag d> : DialectType<d> {
+  let printerPredicate = "false";
+}
 
 class DialectAttributes<string d> {
   string dialect = d;
diff --git a/mlir/test/Dialect/Builtin/Bytecode/types.mlir b/mlir/test/Dialect/Builtin/Bytecode/types.mlir
index bcfbf64c833dd..5e421e2bf75bf 100644
--- a/mlir/test/Dialect/Builtin/Bytecode/types.mlir
+++ b/mlir/test/Dialect/Builtin/Bytecode/types.mlir
@@ -18,16 +18,40 @@ module @TestComplex attributes {
 module @TestFloat attributes {
   // CHECK: bytecode.test = bf16,
   // CHECK: bytecode.test1 = f16,
+  // CHECK: bytecode.test10 = f8E4M3FNUZ,
+  // CHECK: bytecode.test11 = f8E4M3B11FNUZ,
+  // CHECK: bytecode.test12 = f8E3M4,
+  // CHECK: bytecode.test13 = f4E2M1FN,
+  // CHECK: bytecode.test14 = f6E2M3FN,
+  // CHECK: bytecode.test15 = f6E3M2FN,
+  // CHECK: bytecode.test16 = f8E8M0FNU,
+  // CHECK: bytecode.test17 = tf32,
   // CHECK: bytecode.test2 = f32,
   // CHECK: bytecode.test3 = f64,
   // CHECK: bytecode.test4 = f80,
-  // CHECK: bytecode.test5 = f128
+  // CHECK: bytecode.test5 = f128,
+  // CHECK: bytecode.test6 = f8E5M2,
+  // CHECK: bytecode.test7 = f8E4M3,
+  // CHECK: bytecode.test8 = f8E4M3FN,
+  // CHECK: bytecode.test9 = f8E5M2FNUZ
   bytecode.test = bf16,
   bytecode.test1 = f16,
   bytecode.test2 = f32,
   bytecode.test3 = f64,
   bytecode.test4 = f80,
-  bytecode.test5 = f128
+  bytecode.test5 = f128,
+  bytecode.test6 = f8E5M2,
+  bytecode.test7 = f8E4M3,
+  bytecode.test8 = f8E4M3FN,
+  bytecode.test9 = f8E5M2FNUZ,
+  bytecode.test10 = f8E4M3FNUZ,
+  bytecode.test11 = f8E4M3B11FNUZ,
+  bytecode.test12 = f8E3M4,
+  bytecode.test13 = f4E2M1FN,
+  bytecode.test14 = f6E2M3FN,
+  bytecode.test15 = f6E3M2FN,
+  bytecode.test16 = f8E8M0FNU,
+  bytecode.test17 = tf32
 } {}
 
 //===----------------------------------------------------------------------===//

>From f1fcaa2ccbdc21c7777ca2f41ca1e2fcdef71a88 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Tue, 14 Apr 2026 08:24:09 +0200
Subject: [PATCH 2/4] Remove ifdefs; keep the new float encodings reader-only for now

---
 mlir/include/mlir/IR/BuiltinDialectBytecode.td | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0f593b21614b4..351ffee8af7be 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -233,11 +233,7 @@ def Float16Type : DialectType<(type)>;
 
 // Stage the addition of new floating point types so that readers can be updated
 // first.
-#ifdef MLIRBC_DISABLE_FLOAT_ADD_JUNE
 class EnableFloatPrintingJune2026<dag d> : DialectTypeNoPrint<d>;
-#else
-class EnableFloatPrintingJune2026<dag d> : DialectType<d>;
-#endif
 
 def FloatTF32Type : EnableFloatPrintingJune2026<(type)>;
 

>From 0206ae0f4a0c5050aacd6f00ad755055690e843f Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Fri, 20 Feb 2026 09:31:21 +0200
Subject: [PATCH 3/4] [mlirbc] Add AffineMap serialization support

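Add a binary bytecode encoding for AffineMapAttr. Only the reader is enabled
for now: the printer is disabled (printerPredicate = "false") so that readers
can update before writers stop using the textual fallback.

An AffineMap is encoded as numDims, numSymbols and numResults, followed by the
result expressions. Each AffineExpr is encoded recursively as a VarInt kind tag
followed by kind-specific data: a VarInt position for dims and symbols, a
signed VarInt for constants, and the LHS then RHS subtrees for binary ops.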
---
 .../include/mlir/IR/BuiltinDialectBytecode.td |  10 ++
 mlir/lib/IR/BuiltinDialectBytecode.cpp        | 167 +++++++++++++++++-
 2 files changed, 174 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 351ffee8af7be..963a876fe1f30 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -205,6 +205,18 @@ def DistinctAttr : DialectAttribute<(attr
   Attribute:$referencedAttr
 )>;
 
+// Stage the new encoding: readers understand it now, while writing stays
+// disabled (printerPredicate is "false") until readers have been updated.
+class EnableAffineMapPrintingJune2026<dag d> : DialectAttribute<d> {
+  let printerPredicate = "false";
+}
+
+def AffineMapAttr : EnableAffineMapPrintingJune2026<(attr
+  WithParser<"succeeded(readAffineMap($_reader, context, $_var))",
+    WithPrinter<"writeAffineMap($_writer, $_name)",
+    WithType<"AffineMap">>>:$value
+)>;
+
 // Types
 // -----
 
@@ -374,6 +383,7 @@ def BuiltinDialectAttributes : DialectAttributes<"Builtin"> {
     SparseElementsAttr,
     DistinctAttr,
     FileLineColRange,
+    AffineMapAttr,
   ];
 }
 
diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp
index 14dc665184099..7225abe82ce92 100644
--- a/mlir/lib/IR/BuiltinDialectBytecode.cpp
+++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp
@@ -9,6 +9,8 @@
 #include "BuiltinDialectBytecode.h"
 #include "AttributeDetail.h"
 #include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -33,9 +35,189 @@ namespace {
 
 // TODO: Move these to separate file.
 
-// Returns the bitwidth if known, else return std::nullopt.
-static std::optional<unsigned> getIntegerBitWidth(DialectBytecodeReader &reader,
-                                                  Type type) {
+//===----------------------------------------------------------------------===//
+// AffineExpr / AffineMap bytecode helpers
+//===----------------------------------------------------------------------===//
+
+// AffineExpr kind encoding:
+// Extra kinds may be appended here but the existing ones and their
+// ordering should not be changed.
+enum class AffineExprBytecodeKind : uint64_t {
+  DimId = 0,
+  SymbolId = 1,
+  Constant = 2,
+  Add = 3,
+  Mul = 4,
+  Mod = 5,
+  FloorDiv = 6,
+  CeilDiv = 7
+};
+
+static FailureOr<AffineExpr> readAffineExpr(DialectBytecodeReader &reader,
+                                             MLIRContext *context) {
+  uint64_t kind;
+  if (failed(reader.readVarInt(kind)))
+    return failure();
+
+  switch (static_cast<AffineExprBytecodeKind>(kind)) {
+  case AffineExprBytecodeKind::DimId: {
+    uint64_t position;
+    if (failed(reader.readVarInt(position)))
+      return failure();
+    return getAffineDimExpr(position, context);
+  }
+  case AffineExprBytecodeKind::SymbolId: {
+    uint64_t position;
+    if (failed(reader.readVarInt(position)))
+      return failure();
+    return getAffineSymbolExpr(position, context);
+  }
+  case AffineExprBytecodeKind::Constant: {
+    int64_t value;
+    if (failed(reader.readSignedVarInt(value)))
+      return failure();
+    return getAffineConstantExpr(value, context);
+  }
+  case AffineExprBytecodeKind::Add:
+  case AffineExprBytecodeKind::Mul:
+  case AffineExprBytecodeKind::Mod:
+  case AffineExprBytecodeKind::FloorDiv:
+  case AffineExprBytecodeKind::CeilDiv: { // Binary ops
+    auto lhs = readAffineExpr(reader, context);
+    if (failed(lhs))
+      return failure();
+    auto rhs = readAffineExpr(reader, context);
+    if (failed(rhs))
+      return failure();
+    AffineExprKind exprKind;
+    switch (static_cast<AffineExprBytecodeKind>(kind)) {
+    case AffineExprBytecodeKind::Add:
+      exprKind = AffineExprKind::Add;
+      break;
+    case AffineExprBytecodeKind::Mul:
+      exprKind = AffineExprKind::Mul;
+      break;
+    case AffineExprBytecodeKind::Mod:
+      exprKind = AffineExprKind::Mod;
+      break;
+    case AffineExprBytecodeKind::FloorDiv:
+      exprKind = AffineExprKind::FloorDiv;
+      break;
+    case AffineExprBytecodeKind::CeilDiv:
+      exprKind = AffineExprKind::CeilDiv;
+      break;
+    default:
+      llvm_unreachable("unhandled affine expr kind");
+    }
+    return getAffineBinaryOpExpr(exprKind, *lhs, *rhs);
+  }
+  default:
+    reader.emitError() << "unknown AffineExpr kind: " << kind;
+    return failure();
+  }
+}
+
+static void writeAffineExpr(DialectBytecodeWriter &writer, AffineExpr expr) {
+  switch (expr.getKind()) {
+  case AffineExprKind::DimId:
+    writer.writeVarInt(static_cast<uint64_t>(AffineExprBytecodeKind::DimId));
+    writer.writeVarInt(cast<AffineDimExpr>(expr).getPosition());
+    break;
+  case AffineExprKind::SymbolId:
+    writer.writeVarInt(static_cast<uint64_t>(AffineExprBytecodeKind::SymbolId));
+    writer.writeVarInt(cast<AffineSymbolExpr>(expr).getPosition());
+    break;
+  case AffineExprKind::Constant:
+    writer.writeVarInt(static_cast<uint64_t>(AffineExprBytecodeKind::Constant));
+    writer.writeSignedVarInt(cast<AffineConstantExpr>(expr).getValue());
+    break;
+  case AffineExprKind::Add:
+    writer.writeVarInt(static_cast<uint64_t>(AffineExprBytecodeKind::Add));
+    writeAffineExpr(writer, cast<AffineBinaryOpExpr>(expr).getLHS());
+    writeAffineExpr(writer, cast<AffineBinaryOpExpr>(expr).getRHS());
+    break;
+  case AffineExprKind::Mul:
+    writer.writeVarInt(static_cast<uint64_t>(AffineExprBytecodeKind::Mul));
+    writeAffineExpr(writer, cast<AffineBinaryOpExpr>(expr).getLHS());
+    writeAffineExpr(writer, cast<AffineBinaryOpExpr>(expr).getRHS());
+    break;
+  case AffineExprKind::Mod:
+    writer.writeVarInt(static_cast<uint64_t>(AffineExprBytecodeKind::Mod));
+    writeAffineExpr(writer, cast<AffineBinaryOpExpr>(expr).getLHS());
+    writeAffineExpr(writer, cast<AffineBinaryOpExpr>(expr).getRHS());
+    break;
+  case AffineExprKind::FloorDiv:
+    writer.writeVarInt(static_cast<uint64_t>(AffineExprBytecodeKind::FloorDiv));
+    writeAffineExpr(writer, cast<AffineBinaryOpExpr>(expr).getLHS());
+    writeAffineExpr(writer, cast<AffineBinaryOpExpr>(expr).getRHS());
+    break;
+  case AffineExprKind::CeilDiv:
+    writer.writeVarInt(static_cast<uint64_t>(AffineExprBytecodeKind::CeilDiv));
+    writeAffineExpr(writer, cast<AffineBinaryOpExpr>(expr).getLHS());
+    writeAffineExpr(writer, cast<AffineBinaryOpExpr>(expr).getRHS());
+    break;
+  }
+}
+
+static LogicalResult readAffineMap(DialectBytecodeReader &reader,
+                                   MLIRContext *context,
+                                   AffineMap &map) {
+  uint64_t numDims, numSymbols, numResults;
+  if (failed(reader.readVarInt(numDims)) ||
+      failed(reader.readVarInt(numSymbols)) ||
+      failed(reader.readVarInt(numResults)))
+    return failure();
+  // AffineMap::get takes `unsigned` counts; reject values that would be
+  // silently truncated.
+  if (numDims != static_cast<unsigned>(numDims) ||
+      numSymbols != static_cast<unsigned>(numSymbols)) {
+    reader.emitError() << "AffineMap dim/symbol count out of range";
+    return failure();
+  }
+
+  // `numResults` comes straight from the input, so results are appended as
+  // they decode instead of being reserved up front.
+  SmallVector<AffineExpr> results;
+  for (uint64_t i = 0; i < numResults; ++i) {
+    auto expr = readAffineExpr(reader, context);
+    if (failed(expr))
+      return failure();
+    // AffineMap::get has no way to report failure, so reject dim/symbol
+    // positions that are out of range for this map here.
+    bool inRange = true;
+    expr->walk([&](AffineExpr e) {
+      if (auto dim = dyn_cast<AffineDimExpr>(e))
+        inRange = inRange && dim.getPosition() < numDims;
+      else if (auto sym = dyn_cast<AffineSymbolExpr>(e))
+        inRange = inRange && sym.getPosition() < numSymbols;
+    });
+    if (!inRange) {
+      reader.emitError() << "AffineMap result uses out-of-range dim or symbol";
+      return failure();
+    }
+    results.push_back(*expr);
+  }
+  map = AffineMap::get(numDims, numSymbols, results, context);
+  return success();
+}
+
+static void writeAffineMap(DialectBytecodeWriter &writer,
+                           AffineMapAttr attr) {
+  AffineMap map = attr.getValue();
+  writer.writeVarInt(map.getNumDims());
+  writer.writeVarInt(map.getNumSymbols());
+  writer.writeVarInt(map.getNumResults());
+  for (AffineExpr expr : map.getResults())
+    writeAffineExpr(writer, expr);
+}
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+// Returns the bitwidth if known, else return 0.
+static std::optional<unsigned> getIntegerBitWidth(
+    DialectBytecodeReader &reader, Type type) {
   if (auto intType = dyn_cast<IntegerType>(type))
     return intType.getWidth();
   if (llvm::isa<IndexType>(type))

>From 1a475510ec7a584761476e4b9d4e4018be419a93 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Tue, 14 Apr 2026 11:16:42 +0200
Subject: [PATCH 4/4] clang-format and minor cleanups

---
 mlir/lib/IR/BuiltinDialectBytecode.cpp | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp
index 7225abe82ce92..b4e2c8b588a74 100644
--- a/mlir/lib/IR/BuiltinDialectBytecode.cpp
+++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp
@@ -54,7 +54,7 @@ enum class AffineExprBytecodeKind : uint64_t {
 };
 
 static FailureOr<AffineExpr> readAffineExpr(DialectBytecodeReader &reader,
-                                             MLIRContext *context) {
+                                            MLIRContext *context) {
   uint64_t kind;
   if (failed(reader.readVarInt(kind)))
     return failure();
@@ -111,9 +111,9 @@ static FailureOr<AffineExpr> readAffineExpr(DialectBytecodeReader &reader,
     }
     return getAffineBinaryOpExpr(exprKind, *lhs, *rhs);
   }
-  default:
-    reader.emitError() << "unknown AffineExpr kind: " << kind;
-    return failure();
   }
+  // Kinds outside the enum (corrupt or newer bytecode) fall out of the switch.
+  reader.emitError() << "unknown AffineExpr kind: " << kind;
+  return failure();
 }
 
@@ -160,8 +157,7 @@ static void writeAffineExpr(DialectBytecodeWriter &writer, AffineExpr expr) {
 }
 
 static LogicalResult readAffineMap(DialectBytecodeReader &reader,
-                                   MLIRContext *context,
-                                   AffineMap &map) {
+                                   MLIRContext *context, AffineMap &map) {
   uint64_t numDims, numSymbols, numResults;
   if (failed(reader.readVarInt(numDims)) ||
       failed(reader.readVarInt(numSymbols)) ||
@@ -180,8 +176,7 @@ static LogicalResult readAffineMap(DialectBytecodeReader &reader,
   return success();
 }
 
-static void writeAffineMap(DialectBytecodeWriter &writer,
-                           AffineMapAttr attr) {
+static void writeAffineMap(DialectBytecodeWriter &writer, AffineMapAttr attr) {
   AffineMap map = attr.getValue();
   writer.writeVarInt(map.getNumDims());
   writer.writeVarInt(map.getNumSymbols());
@@ -195,8 +190,8 @@ static void writeAffineMap(DialectBytecodeWriter &writer,
 //===----------------------------------------------------------------------===//
 
-// Returns the bitwidth if known, else return 0.
-static std::optional<unsigned> getIntegerBitWidth(
-    DialectBytecodeReader &reader, Type type) {
+// Returns the bitwidth if known, else return std::nullopt.
+static std::optional<unsigned> getIntegerBitWidth(DialectBytecodeReader &reader,
+                                                  Type type) {
   if (auto intType = dyn_cast<IntegerType>(type))
     return intType.getWidth();
   if (llvm::isa<IndexType>(type))



More information about the Mlir-commits mailing list