[Mlir-commits] [mlir] [MLIR][MemRef] Fix memref.reshape lowering to avoid unresolvable unrealized casts for i32 shape (PR #189238)
Mehdi Amini
llvmlistbot at llvm.org
Sun Mar 29 06:11:26 PDT 2026
https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/189238
When lowering memref.reshape with a shape memref whose element type is an integer type narrower than the LLVM index type (e.g. memref<Nxi32> on a 64-bit target, where the index type is i64), the MemRefReshapeOpLowering pattern called typeConverter->materializeTargetConversion to convert the loaded i32 value to i64. This produced a builtin.unrealized_conversion_cast from i32 to i64 for which no conversion was registered, so the cast was never eliminated and mlir-translate could not process the resulting module.
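Fix this by emitting llvm.zext when the loaded shape element type is an integer type narrower than the LLVM index type; zero-extension is used because shape entries are expected to be non-negative dimension sizes. For all other element types (e.g. the MLIR index type, or integer types at least as wide as the index type), fall back to materializeTargetConversion as before.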
Fixes #77430
Assisted-by: Claude Code
>From bce50d4c000346baeda58713b84639eacdfb2020 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Mar 2026 15:28:13 -0700
Subject: [PATCH] [MLIR][MemRef] Fix memref.reshape lowering to avoid
unresolvable unrealized casts for i32 shape
When lowering memref.reshape with a shape memref whose element type is an
integer type narrower than the LLVM index type (e.g. memref<Nxi32> on a
64-bit target, where the index type is i64), the MemRefReshapeOpLowering
pattern called typeConverter->materializeTargetConversion to convert the
loaded i32 value to i64. This produced a builtin.unrealized_conversion_cast
from i32 to i64 for which no conversion was registered, so the cast was
never eliminated and mlir-translate could not process the resulting module.
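Fix this by emitting llvm.zext when the loaded shape element type is an
integer type narrower than the LLVM index type; zero-extension is used
because shape entries are expected to be non-negative dimension sizes.
For all other element types (e.g. the MLIR index type, or integer types
at least as wide as the index type), fall back to
materializeTargetConversion as before.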
Fixes #77430
Assisted-by: Claude Code
---
.../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 27 +++++++++++++++----
.../convert-static-memref-ops.mlir | 15 +++++++++++
2 files changed, 37 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index c42a85fa375ba..ca6e3e676be58 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1547,11 +1547,28 @@ struct MemRefReshapeOpLowering
Value shapeOp = reshapeOp.getShape();
Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
dimSize = memref::LoadOp::create(rewriter, loc, shapeOp, index);
- Type indexType = getIndexType();
- if (dimSize.getType() != indexType)
- dimSize = typeConverter->materializeTargetConversion(
- rewriter, loc, indexType, dimSize);
- assert(dimSize && "Invalid memref element type");
+ if (dimSize.getType() != indexType) {
+ // The shape memref element type may differ from the LLVM index
+ // type (i64 on 64-bit targets). For integer types narrower than
+ // the index type, emit llvm.zext to widen without producing an
+ // unresolvable unrealized_conversion_cast. For other types (e.g.
+ // the MLIR index type), fall back to materializeTargetConversion
+ // which inserts an unrealized cast that the type converter or
+ // reconcile-unrealized-casts can later resolve.
+ if (auto intType = dyn_cast<IntegerType>(dimSize.getType())) {
+ auto indexIntType = cast<IntegerType>(indexType);
+ if (intType.getWidth() < indexIntType.getWidth())
+ dimSize =
+ LLVM::ZExtOp::create(rewriter, loc, indexType, dimSize);
+ else
+ dimSize = typeConverter->materializeTargetConversion(
+ rewriter, loc, indexType, dimSize);
+ } else {
+ dimSize = typeConverter->materializeTargetConversion(
+ rewriter, loc, indexType, dimSize);
+ }
+ assert(dimSize && "Invalid memref element type");
+ }
}
desc.setSize(rewriter, loc, i, dimSize);
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
index 040a27e160557..686cb3f256409 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -359,6 +359,21 @@ func.func @memref.reshape_index(%arg0: memref<?x?xi32>, %shape: memref<1xindex>)
return %1 : memref<?xi32>
}
+// CHECK-LABEL: func @memref.reshape_i32_shape
+// CHECK-SAME: %[[arg:.*]]: memref<1x?x?x1xf32>, %[[shape:.*]]: memref<2xi32>
+func.func @memref.reshape_i32_shape(%arg: memref<1x?x?x1xf32>, %shape: memref<2xi32>) -> memref<?x?xf32> {
+ // CHECK-DAG: %[[arg_cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : memref<1x?x?x1xf32> to !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
+ // CHECK-DAG: %[[shape_cast:.*]] = builtin.unrealized_conversion_cast %[[shape]] : memref<2xi32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // When the shape memref element type (i32) differs from the LLVM index type (i64),
+ // the lowering must emit llvm.zext rather than an unresolvable unrealized_conversion_cast.
+ // CHECK: %[[shape_load:.*]] = llvm.load {{.*}} : !llvm.ptr -> i32
+ // CHECK-NEXT: %[[zext:.*]] = llvm.zext %[[shape_load]] : i32 to i64
+ %0 = memref.reshape %arg(%shape) : (memref<1x?x?x1xf32>, memref<2xi32>) -> memref<?x?xf32>
+ return %0 : memref<?x?xf32>
+}
+
+// -----
+
// CHECK-LABEL: @memref_memory_space_cast
func.func @memref_memory_space_cast(%input : memref<?xf32>) -> memref<?xf32, 1> {
%cast = memref.memory_space_cast %input : memref<?xf32> to memref<?xf32, 1>
More information about the Mlir-commits
mailing list