[Mlir-commits] [mlir] [mlir][Complex] Fix bug in `MergeComplexBitcast` (PR #74271)

Matthias Springer llvmlistbot at llvm.org
Sun Dec 3 18:56:15 PST 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/74271

When two `complex.bitcast` ops are folded and the resulting bitcast is a non-complex -> non-complex bitcast, an `arith.bitcast` should be generated. Otherwise, the generated `complex.bitcast` op is invalid.

Also remove a pattern that converts non-complex -> non-complex `complex.bitcast` ops to `arith.bitcast`. Such `complex.bitcast` ops are invalid and should not appear in the input.

Note: This bug can only be triggered by running with `-debug` (which will show intermediate IR that does not verify) or with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` (#74270).

>From f59d8ac712ae521b3f5bbf92bd27bfeb3361f5e8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 4 Dec 2023 11:54:02 +0900
Subject: [PATCH] [mlir][Complex] Fix bug in `MergeComplexBitcast`

When two `complex.bitcast` ops are folded and the resulting bitcast is a non-complex -> non-complex bitcast, an `arith.bitcast` should be generated. Otherwise, the generated `complex.bitcast` op is invalid.

Also remove a pattern that converts non-complex -> non-complex `complex.bitcast` ops to `arith.bitcast`. Such `complex.bitcast` ops are invalid and should not appear in the input.

Note: This bug can only be triggered by running with `-debug` (which will show intermediate IR that does not verify) or with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` (#74270).
---
 mlir/lib/Dialect/Complex/IR/ComplexOps.cpp | 31 +++++++++-------------
 mlir/test/Dialect/Complex/invalid.mlir     |  2 +-
 2 files changed, 13 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 8fd914dd107ff..6d8706775758e 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -100,7 +100,8 @@ LogicalResult BitcastOp::verify() {
   }
 
   if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
-    return emitOpError("requires input or output is a complex type");
+    return emitOpError(
+        "requires that either input or output has a complex type");
   }
 
   if (isa<ComplexType>(resultType))
@@ -125,8 +126,15 @@ struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
   LogicalResult matchAndRewrite(BitcastOp op,
                                 PatternRewriter &rewriter) const override {
     if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
-      rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
-                                             defining.getOperand());
+      if (isa<ComplexType>(op.getType()) ||
+          isa<ComplexType>(defining.getOperand().getType())) {
+        // complex.bitcast requires that input or output is complex.
+        rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
+                                               defining.getOperand());
+      } else {
+        rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
+                                                      defining.getOperand());
+      }
       return success();
     }
 
@@ -155,24 +163,9 @@ struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
   }
 };
 
-struct ArithBitcast final : OpRewritePattern<BitcastOp> {
-  using OpRewritePattern<complex::BitcastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(BitcastOp op,
-                                PatternRewriter &rewriter) const override {
-    if (isa<ComplexType>(op.getType()) ||
-        isa<ComplexType>(op.getOperand().getType()))
-      return failure();
-
-    rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
-                                                  op.getOperand());
-    return success();
-  }
-};
-
 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ArithBitcast, MergeComplexBitcast, MergeArithBitcast>(context);
+  results.add<MergeComplexBitcast, MergeArithBitcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/invalid.mlir b/mlir/test/Dialect/Complex/invalid.mlir
index 51b1b0fda202a..ba6995b727bc2 100644
--- a/mlir/test/Dialect/Complex/invalid.mlir
+++ b/mlir/test/Dialect/Complex/invalid.mlir
@@ -25,7 +25,7 @@ func.func @complex_constant_two_different_element_types() {
 // -----
 
 func.func @complex_bitcast_i64(%arg0 : i64) {
-  // expected-error @+1 {{op requires input or output is a complex type}}
+  // expected-error @+1 {{op requires that either input or output has a complex type}}
   %0 = complex.bitcast %arg0: i64 to f64
   return
 }



More information about the Mlir-commits mailing list