[Mlir-commits] [mlir] [mlir][vector] Fix RewriteExtOfBitCast dropping the sign extension (PR #199243)
Benoit Jacob
llvmlistbot at llvm.org
Fri May 22 10:43:26 PDT 2026
https://github.com/bjacob created https://github.com/llvm/llvm-project/pull/199243
RewriteExtOfBitCast rewrites `ext{s,u}i(vector.bitcast(...))` into a sequence of shuffles and bitwise ops. `BitCastRewriter::genericRewriteStep` assembles each bitcast-result element with a mask and *logical* shifts, so every element ends up zero-extended in a lane as wide as the bitcast source element type.
That is correct for `arith.extui`. For `arith.extsi` it is not: the bitcast-result elements must be sign-extended, but they are left zero-extended in the lane, and the trailing `extsi`/`trunci` operates on the wide lane, whose top bits are already zero, so it cannot recover the sign. A signed extension was silently turned into a zero extension.
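A small scalar C++ model makes the failure concrete for the reproducer below. It is an illustration only, not the pattern's implementation. The function names and the sample input are hypothetical. The shuffle indices, masks and logical shift amounts are taken from the `@ext_of_bitcast_i16_to_i8_signed` regression test in this patch. Because the lane is already i16, nothing re-extends the assembled value, so each byte comes back zero-extended.

```cpp
#include <array>
#include <cstdint>

// Hypothetical scalar model of the pre-patch rewrite of
//   %0 = vector.bitcast %a : vector<4xi16> to vector<8xi8>
//   %1 = arith.extsi %0 : vector<8xi8> to vector<8xi16>
// Constants follow the regression test: shuffle [0, 0, 1, 1, 2, 2, 3, 3],
// mask [0x00FF, 0xFF00, ...] and logical shift amounts [0, 8, ...].
constexpr uint16_t assembleLane(const std::array<int16_t, 4>& a, int i) {
  // vector.shuffle: result lanes 2k and 2k + 1 both read a[k].
  uint16_t lane = static_cast<uint16_t>(a[i / 2]);
  // arith.andi: keep only the byte that belongs to result element i.
  uint16_t mask = (i % 2 == 0) ? 0x00FF : 0xFF00;
  // arith.shrui: a *logical* shift, so bits 8..15 of the result are zero.
  int shift = (i % 2 == 0) ? 0 : 8;
  return static_cast<uint16_t>((lane & mask) >> shift);
}

// The lane is already i16, which is the extsi result width, so the
// zero-extended lane is returned as is: the i8 sign bit is never propagated.
constexpr int16_t prePatchExtSI(const std::array<int16_t, 4>& a, int i) {
  return static_cast<int16_t>(assembleLane(a, i));
}

// Sample input: a[0] = -1 (0xFFFF) holds two i8 values of -1, and the low
// byte of a[2] = 0x7F80 is 0x80 (i8 -128).
constexpr std::array<int16_t, 4> kSampleInput = {-1, 0x0102, 0x7F80, 0};
```

For this input, `prePatchExtSI(kSampleInput, 0)` is 255 and `prePatchExtSI(kSampleInput, 4)` is 128. Sign extension would give -1 and -128.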
Minimal reproducer -- each i8 below must be sign-extended:
```mlir
func.func @repro(%a: vector<4xi16>) -> vector<8xi16> {
%0 = vector.bitcast %a : vector<4xi16> to vector<8xi8>
%1 = arith.extsi %0 : vector<8xi8> to vector<8xi16>
return %1 : vector<8xi16>
}
```
Before this patch, an input byte 0xFF (i8 -1) came back as +255 instead of -1.
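For comparison, the expected values can be computed directly. This is again only a sketch with hypothetical names. It assumes the element order implied by the rewrite's masks: result element 2k is the low byte of `a[k]` and element 2k + 1 is its high byte. It contrasts sign extension with the zero extension that the pre-patch rewrite effectively performed.

```cpp
#include <array>
#include <cstdint>

// Result element i of vector.bitcast vector<4xi16> to vector<8xi8>,
// assuming element 2k is the low byte of a[k] and 2k + 1 its high byte.
constexpr uint8_t bitcastElement(const std::array<int16_t, 4>& a, int i) {
  uint16_t word = static_cast<uint16_t>(a[i / 2]);
  return static_cast<uint8_t>(i % 2 == 0 ? (word & 0xFF) : (word >> 8));
}

// arith.extsi semantics: reinterpret the byte as i8, then widen to i16.
constexpr int16_t expectedExtSI(const std::array<int16_t, 4>& a, int i) {
  return static_cast<int8_t>(bitcastElement(a, i));
}

// arith.extui semantics: widen the raw byte, filling with zeros. This is
// what the pre-patch rewrite effectively computed for extsi.
constexpr int16_t zeroExtended(const std::array<int16_t, 4>& a, int i) {
  return bitcastElement(a, i);
}

// Hypothetical sample input containing the bytes 0xFF and 0x80.
constexpr std::array<int16_t, 4> kSampleInput = {-1, 0x0102, 0x7F80, 0};
```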
Fix: for the signed variant, sign-extend each assembled element within its lane with an `arith.shli`/`arith.shrsi` pair (from the bitcast-result element width up to the lane width) before the final extend/truncate.
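The fix relies on the usual shift-left / arithmetic-shift-right idiom. Below is a hedged C++ sketch of that idiom for a single lane. The function name and template are hypothetical, and C++ shifts stand in for `arith.shli` and `arith.shrsi`. It also assumes that right-shifting a negative value is arithmetic, which C++20 guarantees and mainstream compilers already provide. The shift amount is `laneWidth - elemWidth`: 8 for the i8-in-i16 regression test and 3 for the i5-in-i8 case updated in `@f1ext`.

```cpp
#include <cstdint>
#include <type_traits>

// Sign-extend an `elemWidth`-bit value that sits zero-extended in the low
// bits of a `Lane`-typed lane, using a shli/shrsi pair.
template <typename Lane>
constexpr Lane signExtendInLane(Lane lane, int elemWidth) {
  static_assert(std::is_signed<Lane>::value, "shrsi needs a signed lane");
  using ULane = std::make_unsigned_t<Lane>;
  constexpr int laneWidth = static_cast<int>(sizeof(Lane) * 8);
  // Like the patch, do nothing when the element already fills the lane.
  if (elemWidth >= laneWidth) return lane;
  const int shift = laneWidth - elemWidth;
  // arith.shli: move the element's sign bit to the top of the lane. The
  // shift is done on the unsigned type to avoid signed-overflow issues.
  const Lane shiftedLeft =
      static_cast<Lane>(static_cast<ULane>(static_cast<ULane>(lane) << shift));
  // arith.shrsi: shift back, replicating the sign bit into the high bits.
  return static_cast<Lane>(shiftedLeft >> shift);
}
```

Any trailing `arith.extsi` then widens a value that is already sign-extended within its lane. For example, `static_cast<int16_t>(signExtendInLane<int8_t>(0x1F, 5))` stays -1.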
From a4888aaed60ef592794687e1988c4a81e01eba44 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Fri, 22 May 2026 13:30:43 -0400
Subject: [PATCH] [mlir][vector] Fix RewriteExtOfBitCast dropping the sign extension
---
.../Transforms/VectorEmulateNarrowType.cpp | 25 ++++++++++++++++
.../Vector/vector-rewrite-narrow-types.mlir | 29 ++++++++++++++++++-
2 files changed, 53 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 583cda7ac2810..23e0118e42488 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -2076,6 +2076,31 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
}
+ // `genericRewriteStep` assembles each bitcast-result element with bitwise
+ // ops (mask + *logical* shifts), so every element lands zero-extended in a
+ // `shuffledElementType`-wide lane. That is what an unsigned extension
+ // wants, but for a signed extension the bitcast-result elements must be
+ // sign-extended. Recover the sign here with an arithmetic shift pair that
+ // sign-extends from the bitcast-result element width up to the lane width;
+ // without this the `extsi` is silently turned into a zero extension.
+ if (std::is_same<ExtOpType, arith::ExtSIOp>::value) {
+ auto runningVecTy = cast<VectorType>(runningResult.getType());
+ int64_t laneWidth = runningVecTy.getElementTypeBitWidth();
+ int64_t elemWidth =
+ bitCastOp.getResultVectorType().getElementTypeBitWidth();
+ if (elemWidth < laneWidth) {
+ Value shiftAmount = arith::ConstantOp::create(
+ rewriter, bitCastOp->getLoc(),
+ DenseElementsAttr::get(
+ runningVecTy, IntegerAttr::get(runningVecTy.getElementType(),
+ laneWidth - elemWidth)));
+ Value shiftedLeft = arith::ShLIOp::create(rewriter, bitCastOp->getLoc(),
+ runningResult, shiftAmount);
+ runningResult = arith::ShRSIOp::create(rewriter, bitCastOp->getLoc(),
+ shiftedLeft, shiftAmount);
+ }
+ }
+
// Finalize the rewrite.
bool narrowing =
cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index a4af307b15da4..f669cb375457c 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -153,6 +153,7 @@ func.func @f1ext(%a: vector<5xi8>) -> vector<8xi16> {
// CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<[0, 3, 0, 15, 1, 0, 7, 0]> : vector<8xi8>
// CHECK-DAG: %[[SHR0_CST:.*]] = arith.constant dense<[0, 5, 2, 7, 4, 1, 6, 3]> : vector<8xi8>
// CHECK-DAG: %[[SHL1_CST:.*]] = arith.constant dense<[5, 3, 5, 1, 4, 5, 2, 5]> : vector<8xi8>
+ // CHECK-DAG: %[[SIGN_CST:.*]] = arith.constant dense<3> : vector<8xi8>
// CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 0, 1, 1, 2, 3, 3, 4] : vector<5xi8>, vector<5xi8>
// CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK0]] : vector<8xi8>
// CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR0_CST]] : vector<8xi8>
@@ -160,7 +161,11 @@ func.func @f1ext(%a: vector<5xi8>) -> vector<8xi16> {
// CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK1]] : vector<8xi8>
// CHECK: %[[SHL1:.*]] = arith.shli %[[A1]], %[[SHL1_CST]] : vector<8xi8>
// CHECK: %[[O1:.*]] = arith.ori %[[SHR0]], %[[SHL1]] : vector<8xi8>
- // CHECK: %[[RES:.*]] = arith.extsi %[[O1]] : vector<8xi8> to vector<8xi16>
+ // The bitwise assembly above leaves each i5 zero-extended in an i8 lane; for
+ // a *signed* extension the i5 sign bit must be propagated first.
+ // CHECK: %[[SHLS:.*]] = arith.shli %[[O1]], %[[SIGN_CST]] : vector<8xi8>
+ // CHECK: %[[SHRS:.*]] = arith.shrsi %[[SHLS]], %[[SIGN_CST]] : vector<8xi8>
+ // CHECK: %[[RES:.*]] = arith.extsi %[[SHRS]] : vector<8xi8> to vector<8xi16>
// return %[[RES]] : vector<8xi16>
%0 = vector.bitcast %a : vector<5xi8> to vector<8xi5>
@@ -213,6 +218,28 @@ func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
return %0 : vector<16x8xi7>
}
+// Signed extension of a bitcast that splits each i16 into two i8 lanes. The
+// byte assembly is purely bitwise (mask + *logical* shift), so each byte ends
+// up zero-extended in its i16 lane; the signed extension must then be
+// completed by the arith.shli/arith.shrsi pair. Regression test for the sign
+// extension being silently dropped (turning extsi into a zero extension).
+// CHECK-LABEL: func.func @ext_of_bitcast_i16_to_i8_signed(
+// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<4xi16>) -> vector<8xi16> {
+func.func @ext_of_bitcast_i16_to_i8_signed(%a: vector<4xi16>) -> vector<8xi16> {
+ // CHECK-DAG: %[[SIGN_CST:.*]] = arith.constant dense<8> : vector<8xi16>
+ // CHECK-DAG: %[[SHR_CST:.*]] = arith.constant dense<[0, 8, 0, 8, 0, 8, 0, 8]> : vector<8xi16>
+ // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<[255, -256, 255, -256, 255, -256, 255, -256]> : vector<8xi16>
+ // CHECK: %[[SHUF:.*]] = vector.shuffle %[[A]], %[[A]] [0, 0, 1, 1, 2, 2, 3, 3] : vector<4xi16>, vector<4xi16>
+ // CHECK: %[[AND:.*]] = arith.andi %[[SHUF]], %[[MASK]] : vector<8xi16>
+ // CHECK: %[[SHRU:.*]] = arith.shrui %[[AND]], %[[SHR_CST]] : vector<8xi16>
+ // CHECK: %[[SHL:.*]] = arith.shli %[[SHRU]], %[[SIGN_CST]] : vector<8xi16>
+ // CHECK: %[[RES:.*]] = arith.shrsi %[[SHL]], %[[SIGN_CST]] : vector<8xi16>
+ // CHECK: return %[[RES]] : vector<8xi16>
+ %0 = vector.bitcast %a : vector<4xi16> to vector<8xi8>
+ %1 = arith.extsi %0 : vector<8xi8> to vector<8xi16>
+ return %1 : vector<8xi16>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op