[Mlir-commits] [mlir] [mlir][Complex] Fix bug in `MergeComplexBitcast` (PR #74271)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Dec 3 18:56:44 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

When two `complex.bitcast` ops are folded and the resulting bitcast is a non-complex -> non-complex bitcast, an `arith.bitcast` should be generated. Otherwise, the generated `complex.bitcast` op is invalid.

Also remove a pattern that converts non-complex -> non-complex `complex.bitcast` ops to `arith.bitcast`. Such `complex.bitcast` ops are invalid and should not appear in the input.

Note: This bug can only be triggered by running with `-debug` (which will show intermediate IR that does not verify) or with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` (#<!-- -->74270).

---
Full diff: https://github.com/llvm/llvm-project/pull/74271.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Complex/IR/ComplexOps.cpp (+12-19) 
- (modified) mlir/test/Dialect/Complex/invalid.mlir (+1-1) 


``````````diff
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 8fd914dd107ff..6d8706775758e 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -100,7 +100,8 @@ LogicalResult BitcastOp::verify() {
   }
 
   if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
-    return emitOpError("requires input or output is a complex type");
+    return emitOpError(
+        "requires that either input or output has a complex type");
   }
 
   if (isa<ComplexType>(resultType))
@@ -125,8 +126,15 @@ struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
   LogicalResult matchAndRewrite(BitcastOp op,
                                 PatternRewriter &rewriter) const override {
     if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
-      rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
-                                             defining.getOperand());
+      if (isa<ComplexType>(op.getType()) ||
+          isa<ComplexType>(defining.getOperand().getType())) {
+        // complex.bitcast requires that input or output is complex.
+        rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
+                                               defining.getOperand());
+      } else {
+        rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
+                                                      defining.getOperand());
+      }
       return success();
     }
 
@@ -155,24 +163,9 @@ struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
   }
 };
 
-struct ArithBitcast final : OpRewritePattern<BitcastOp> {
-  using OpRewritePattern<complex::BitcastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(BitcastOp op,
-                                PatternRewriter &rewriter) const override {
-    if (isa<ComplexType>(op.getType()) ||
-        isa<ComplexType>(op.getOperand().getType()))
-      return failure();
-
-    rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
-                                                  op.getOperand());
-    return success();
-  }
-};
-
 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ArithBitcast, MergeComplexBitcast, MergeArithBitcast>(context);
+  results.add<MergeComplexBitcast, MergeArithBitcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/invalid.mlir b/mlir/test/Dialect/Complex/invalid.mlir
index 51b1b0fda202a..ba6995b727bc2 100644
--- a/mlir/test/Dialect/Complex/invalid.mlir
+++ b/mlir/test/Dialect/Complex/invalid.mlir
@@ -25,7 +25,7 @@ func.func @complex_constant_two_different_element_types() {
 // -----
 
 func.func @complex_bitcast_i64(%arg0 : i64) {
-  // expected-error @+1 {{op requires input or output is a complex type}}
+  // expected-error @+1 {{op requires that either input or output has a complex type}}
   %0 = complex.bitcast %arg0: i64 to f64
   return
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/74271


More information about the Mlir-commits mailing list