[Mlir-commits] [mlir] [mlir][vector] Fix invalid IR in `RewriteBitCastOfTruncI` (PR #78146)

Matthias Springer llvmlistbot at llvm.org
Mon Jan 15 04:47:29 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/78146

>From 645cbb502d78b63b74ba28f59b6dcf8b4c73322d Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 15 Jan 2024 11:28:02 +0000
Subject: [PATCH] [mlir][vector] Fix invalid IR in `RewriteBitCastOfTruncI`

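This commit fixes a failure in the test `Dialect/Vector/vector-rewrite-narrow-types.mlir` that appears when `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` is enabled. `RewriteBitCastOfTruncI` could create an `arith.trunci` whose operand and result types are identical, which does not verify. The pattern now replaces the `vector.bitcast` directly with the computed value when its type already matches the result type, and only creates the `arith.trunci`/`arith.extui` otherwise. The error was: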

```
within split at llvm-project/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir:1 offset :118:8: error: 'arith.trunci' op operand type 'vector<3xi16>' and result type 'vector<3xi16>' are cast incompatible
  %1 = vector.bitcast %0 : vector<16xi3> to vector<3xi16>
       ^
within split at llvm-project/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir:1 offset :118:8: note: see current operation: %48 = "arith.trunci"(%47) : (vector<3xi16>) -> vector<3xi16>
LLVM ERROR: IR failed to verify after pattern application
```
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 20 +++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index ead7d645cb5bb3..a4a72754ccc250 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -664,8 +664,8 @@ struct BitCastRewriter {
 
 } // namespace
 
-[[maybe_unused]] static raw_ostream &operator<<(raw_ostream &os,
-                               const SmallVector<SourceElementRangeList> &vec) {
+[[maybe_unused]] static raw_ostream &
+operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
   for (const auto &l : vec) {
     for (auto it : llvm::enumerate(l)) {
       os << "{ " << it.value().sourceElementIdx << ": b@["
@@ -847,11 +847,19 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     bool narrowing = targetVectorType.getElementTypeBitWidth() <=
                      shuffledElementType.getIntOrFloatBitWidth();
     if (narrowing) {
-      rewriter.replaceOpWithNewOp<arith::TruncIOp>(
-          bitCastOp, bitCastOp.getResultVectorType(), runningResult);
+      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
+        rewriter.replaceOp(bitCastOp, runningResult);
+      } else {
+        rewriter.replaceOpWithNewOp<arith::TruncIOp>(
+            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
+      }
     } else {
-      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
-          bitCastOp, bitCastOp.getResultVectorType(), runningResult);
+      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
+        rewriter.replaceOp(bitCastOp, runningResult);
+      } else {
+        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
+            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
+      }
     }
 
     return success();



More information about the Mlir-commits mailing list