[Mlir-commits] [mlir] [mlir][vector] Fix RewriteExtOfBitCast dropping the sign extension (PR #199243)

Benoit Jacob llvmlistbot at llvm.org
Fri May 22 10:43:26 PDT 2026


https://github.com/bjacob created https://github.com/llvm/llvm-project/pull/199243

RewriteExtOfBitCast rewrites `ext{s,u}i(vector.bitcast(...))` into a sequence of shuffles and bitwise ops. `BitCastRewriter::genericRewriteStep` assembles each bitcast-result element with a mask and *logical* shifts, so every element ends up zero-extended in a lane as wide as the bitcast source element type.
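
This is a scalar C++ model of the shuffle, mask and logical-shift sequence for the `vector<4xi16>` to `vector<8xi8>` case. It shows what "zero-extended in a lane" means in practice. It is not MLIR code and not the actual `genericRewriteStep` implementation, and the function name `assembleLanes` is hypothetical. The mask and shift constants are taken from the new regression test, which assumes little-endian element order: result element `2*i` is the low byte of source element `i`.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical scalar model (not MLIR code) of the bitwise assembly emitted for
// `vector.bitcast vector<4xi16> to vector<8xi8>`:
//   vector.shuffle [0,0,1,1,2,2,3,3]
//   arith.andi  with [0x00FF, 0xFF00, ...]
//   arith.shrui by   [0, 8, ...]
// Each i8 result element ends up in its own 16-bit lane.
std::array<uint16_t, 8> assembleLanes(const std::array<uint16_t, 4>& src) {
  std::array<uint16_t, 8> lanes{};
  for (int i = 0; i < 8; ++i) {
    uint16_t shuffled = src[i / 2];                  // vector.shuffle
    uint16_t mask = (i % 2 == 0) ? 0x00FF : 0xFF00;  // arith.andi constant
    unsigned shift = (i % 2 == 0) ? 0 : 8;           // arith.shrui amount
    // Logical right shift of a non-negative value: the vacated high bits are
    // zeros, so the byte is zero-extended into the lane.
    lanes[i] = static_cast<uint16_t>((shuffled & mask) >> shift);
    assert(lanes[i] <= 0xFF);  // the top 8 bits of every lane are zero
  }
  return lanes;
}
```

For a source element `0xFF01`, this yields lanes `0x0001` and `0x00FF`. The second byte, i8 -1, is held in its lane as 255.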

That is correct for `arith.extui`. For `arith.extsi` it is not. The bitcast-result elements must be sign-extended, but they are left zero-extended in the lane. The trailing `extsi`/`trunci` then operates on the whole lane, whose top bits are already zero, so it cannot recover the sign. A signed extension was silently turned into a zero extension.
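
The same failure appears when the lane is narrower than the result, as in the existing `@f1ext` test, where i5 elements sit in i8 lanes and are extended to i16. The sketch below uses two hypothetical helpers. The first models the unpatched path, in which `arith.extsi` widens the zero-extended lane. The second models what `extsi` of the i5 value itself must return.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model of the unpatched signed path for an i5 element held in an
// 8-bit lane: the lane is zero-extended, then arith.extsi widens the lane
// (i8 -> i16). Bits 5..7 of the lane are zero, so its i8 sign bit is never
// set and the result is never negative.
int16_t extsiOfZeroExtendedLane(uint8_t i5Bits) {
  assert(i5Bits < 32 && "only the low 5 bits hold the i5 element");
  uint8_t lane = i5Bits;                        // zero-extended i5 in an i8 lane
  int8_t laneAsI8 = static_cast<int8_t>(lane);  // lane < 32, so no wraparound
  return static_cast<int16_t>(laneAsI8);        // arith.extsi i8 -> i16
}

// What arith.extsi of the i5 value itself must return (two's complement).
int16_t expectedExtsiOfI5(uint8_t i5Bits) {
  assert(i5Bits < 32);
  int v = i5Bits;
  if (v & 0x10) v -= 32;  // i5 sign bit (bit 4) set: subtract 2^5
  return static_cast<int16_t>(v);
}
```

For the i5 bit pattern `0b11111`, the lane-based path returns 31, while a correct `extsi` returns -1.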

Minimal reproducer -- each i8 below must be sign-extended:

```mlir
  func.func @repro(%a: vector<4xi16>) -> vector<8xi16> {
    %0 = vector.bitcast %a : vector<4xi16> to vector<8xi8>
    %1 = arith.extsi %0 : vector<8xi8> to vector<8xi16>
    return %1 : vector<8xi16>
  }
```

Before this patch, an input byte 0xFF (i8 -1) came back as +255 instead of -1.
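
For reference, the semantics the reproducer must have can be written out directly. `reproReference` below is a hypothetical C++ model, not part of MLIR. It assumes little-endian element order for `vector.bitcast`, which is consistent with the masks in the new test, and it sign-extends each byte to 16 bits.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Sign-extend an 8-bit pattern to 16 bits without relying on
// implementation-defined narrowing conversions.
int16_t sextByteToI16(uint8_t b) {
  int v = b;               // 0..255
  if (v >= 128) v -= 256;  // i8 sign bit set: the value is negative
  assert(v >= -128 && v <= 127);
  return static_cast<int16_t>(v);
}

// Hypothetical reference model of @repro (not MLIR code): bitcast
// vector<4xi16> to vector<8xi8>, then arith.extsi to vector<8xi16>.
// Assumes little-endian element order: result element 2*i is the low byte of
// source element i, and element 2*i+1 is its high byte.
std::array<int16_t, 8> reproReference(const std::array<uint16_t, 4>& a) {
  std::array<int16_t, 8> out{};
  for (int i = 0; i < 4; ++i) {
    out[2 * i] = sextByteToI16(static_cast<uint8_t>(a[i] & 0xFF));
    out[2 * i + 1] = sextByteToI16(static_cast<uint8_t>(a[i] >> 8));
  }
  return out;
}
```

With input `{0xFF01, 0x80FE, 0x0000, 0x7F7F}`, this returns `{1, -1, -2, -128, 0, 0, 127, 127}`. The unpatched rewrite left the negative bytes zero-extended, so they came back as 255, 254 and 128.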

Fix: for the signed variant, sign-extend each assembled element within its lane before the final extend/truncate. This uses an `arith.shli`/`arith.shrsi` pair whose shift amount is the lane width minus the bitcast-result element width (3 for i5 elements in i8 lanes, 8 for i8 elements in i16 lanes).
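
The shift-pair idiom can be sketched in scalar C++ as follows. `signExtendInLane` is a hypothetical helper, not code from the patch. It mirrors the emitted `arith.shli`/`arith.shrsi` pair with a shift amount of lane width minus element width. It relies on C++20 integer semantics: unsigned-to-signed conversion wraps modulo 2^N, and `>>` on a negative signed value is an arithmetic shift.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical scalar model of the fix: sign-extend the low `elemWidth` bits
// of a lane up to the full lane width with a shift pair, as the patch does
// with arith.shli / arith.shrsi. The shift amount is laneWidth - elemWidth.
template <typename LaneU>
std::make_signed_t<LaneU> signExtendInLane(LaneU lane, unsigned elemWidth) {
  using LaneS = std::make_signed_t<LaneU>;
  constexpr unsigned laneWidth = sizeof(LaneU) * 8;
  assert(elemWidth > 0 && elemWidth <= laneWidth);
  unsigned shift = laneWidth - elemWidth;
  // arith.shli: move the element's sign bit into the lane's sign bit.
  LaneS shiftedLeft = static_cast<LaneS>(static_cast<LaneU>(lane << shift));
  // arith.shrsi: shift back, replicating the sign bit into the top bits.
  return static_cast<LaneS>(shiftedLeft >> shift);
}
```

Both shapes from the tests are covered. `signExtendInLane<uint16_t>(0x00FF, 8)` gives -1 (i8 in an i16 lane, shift 8), and `signExtendInLane<uint8_t>(0x1F, 5)` gives -1 (i5 in an i8 lane, shift 3). When the element already fills the lane, the shift is 0 and the pair is a no-op. The patch skips emitting it in that case.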

From a4888aaed60ef592794687e1988c4a81e01eba44 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Fri, 22 May 2026 13:30:43 -0400
Subject: [PATCH] [mlir][vector] Fix RewriteExtOfBitCast dropping the sign
 extension

RewriteExtOfBitCast rewrites `ext{s,u}i(vector.bitcast(...))` into a
sequence of shuffles and bitwise ops. `BitCastRewriter::genericRewriteStep`
assembles each bitcast-result element with a mask and *logical* shifts,
so every element ends up zero-extended in a lane as wide as the bitcast
source element type.

That is correct for `arith.extui`. For `arith.extsi` it is not: the
bitcast-result elements must be sign-extended, but they are left
zero-extended in the lane and the trailing `extsi`/`trunci` (which
extends from the wide lane, whose top bits are already zero) cannot
recover the sign. A signed extension was silently turned into a zero
extension.

Minimal reproducer -- each i8 below must be sign-extended:

  func.func @repro(%a: vector<4xi16>) -> vector<8xi16> {
    %0 = vector.bitcast %a : vector<4xi16> to vector<8xi8>
    %1 = arith.extsi %0 : vector<8xi8> to vector<8xi16>
    return %1 : vector<8xi16>
  }

Before this patch an input byte 0xFF (i8 -1) came back as +255.

Fix: for the signed variant, sign-extend each assembled element within
its lane with an `arith.shli`/`arith.shrsi` pair (from the bitcast-result
element width up to the lane width) before the final extend/truncate.
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 25 ++++++++++++++++
 .../Vector/vector-rewrite-narrow-types.mlir   | 29 ++++++++++++++++++-
 2 files changed, 53 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 583cda7ac2810..23e0118e42488 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -2076,6 +2076,31 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
           rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
     }
 
+    // `genericRewriteStep` assembles each bitcast-result element with bitwise
+    // ops (mask + *logical* shifts), so every element lands zero-extended in a
+    // `shuffledElementType`-wide lane. That is what an unsigned extension
+    // wants, but for a signed extension the bitcast-result elements must be
+    // sign-extended. Recover the sign here with an arithmetic shift pair that
+    // sign-extends from the bitcast-result element width up to the lane width;
+    // without this the `extsi` is silently turned into a zero extension.
+    if (std::is_same<ExtOpType, arith::ExtSIOp>::value) {
+      auto runningVecTy = cast<VectorType>(runningResult.getType());
+      int64_t laneWidth = runningVecTy.getElementTypeBitWidth();
+      int64_t elemWidth =
+          bitCastOp.getResultVectorType().getElementTypeBitWidth();
+      if (elemWidth < laneWidth) {
+        Value shiftAmount = arith::ConstantOp::create(
+            rewriter, bitCastOp->getLoc(),
+            DenseElementsAttr::get(
+                runningVecTy, IntegerAttr::get(runningVecTy.getElementType(),
+                                               laneWidth - elemWidth)));
+        Value shiftedLeft = arith::ShLIOp::create(rewriter, bitCastOp->getLoc(),
+                                                  runningResult, shiftAmount);
+        runningResult = arith::ShRSIOp::create(rewriter, bitCastOp->getLoc(),
+                                               shiftedLeft, shiftAmount);
+      }
+    }
+
     // Finalize the rewrite.
     bool narrowing =
         cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index a4af307b15da4..f669cb375457c 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -153,6 +153,7 @@ func.func @f1ext(%a: vector<5xi8>) -> vector<8xi16> {
   // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<[0, 3, 0, 15, 1, 0, 7, 0]> : vector<8xi8>
   // CHECK-DAG: %[[SHR0_CST:.*]] = arith.constant dense<[0, 5, 2, 7, 4, 1, 6, 3]> : vector<8xi8>
   // CHECK-DAG: %[[SHL1_CST:.*]] = arith.constant dense<[5, 3, 5, 1, 4, 5, 2, 5]> : vector<8xi8>
+  // CHECK-DAG: %[[SIGN_CST:.*]] = arith.constant dense<3> : vector<8xi8>
   // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 0, 1, 1, 2, 3, 3, 4] : vector<5xi8>, vector<5xi8>
   // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK0]] : vector<8xi8>
   // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR0_CST]] : vector<8xi8>
@@ -160,7 +161,11 @@ func.func @f1ext(%a: vector<5xi8>) -> vector<8xi16> {
   // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK1]] : vector<8xi8>
   // CHECK: %[[SHL1:.*]] = arith.shli %[[A1]], %[[SHL1_CST]] : vector<8xi8>
   // CHECK: %[[O1:.*]] = arith.ori %[[SHR0]], %[[SHL1]] : vector<8xi8>
-  // CHECK: %[[RES:.*]] = arith.extsi %[[O1]] : vector<8xi8> to vector<8xi16>
+  // The bitwise assembly above leaves each i5 zero-extended in an i8 lane; for
+  // a *signed* extension the i5 sign bit must be propagated first.
+  // CHECK: %[[SHLS:.*]] = arith.shli %[[O1]], %[[SIGN_CST]] : vector<8xi8>
+  // CHECK: %[[SHRS:.*]] = arith.shrsi %[[SHLS]], %[[SIGN_CST]] : vector<8xi8>
+  // CHECK: %[[RES:.*]] = arith.extsi %[[SHRS]] : vector<8xi8> to vector<8xi16>
   // return %[[RES]] : vector<8xi16>
 
   %0 = vector.bitcast %a : vector<5xi8> to vector<8xi5>
@@ -213,6 +218,28 @@ func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
   return %0 : vector<16x8xi7>
 }
 
+// Signed extension of a bitcast that splits each i16 into two i8 lanes. The
+// byte assembly is purely bitwise (mask + *logical* shift), so each byte ends
+// up zero-extended in its i16 lane; the signed extension must then be
+// completed by the arith.shli/arith.shrsi pair. Regression test for the sign
+// extension being silently dropped (turning extsi into a zero extension).
+// CHECK-LABEL: func.func @ext_of_bitcast_i16_to_i8_signed(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]: vector<4xi16>) -> vector<8xi16> {
+func.func @ext_of_bitcast_i16_to_i8_signed(%a: vector<4xi16>) -> vector<8xi16> {
+  // CHECK-DAG: %[[SIGN_CST:.*]] = arith.constant dense<8> : vector<8xi16>
+  // CHECK-DAG: %[[SHR_CST:.*]] = arith.constant dense<[0, 8, 0, 8, 0, 8, 0, 8]> : vector<8xi16>
+  // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<[255, -256, 255, -256, 255, -256, 255, -256]> : vector<8xi16>
+  // CHECK: %[[SHUF:.*]] = vector.shuffle %[[A]], %[[A]] [0, 0, 1, 1, 2, 2, 3, 3] : vector<4xi16>, vector<4xi16>
+  // CHECK: %[[AND:.*]] = arith.andi %[[SHUF]], %[[MASK]] : vector<8xi16>
+  // CHECK: %[[SHRU:.*]] = arith.shrui %[[AND]], %[[SHR_CST]] : vector<8xi16>
+  // CHECK: %[[SHL:.*]] = arith.shli %[[SHRU]], %[[SIGN_CST]] : vector<8xi16>
+  // CHECK: %[[RES:.*]] = arith.shrsi %[[SHL]], %[[SIGN_CST]] : vector<8xi16>
+  // CHECK: return %[[RES]] : vector<8xi16>
+  %0 = vector.bitcast %a : vector<4xi16> to vector<8xi8>
+  %1 = arith.extsi %0 : vector<8xi8> to vector<8xi16>
+  return %1 : vector<8xi16>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op