[Mlir-commits] [mlir] 192439d - [mlir][Complex] Fix bug in `MergeComplexBitcast` (#74271)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 4 17:25:23 PST 2023


Author: Matthias Springer
Date: 2023-12-05T10:25:19+09:00
New Revision: 192439db6e3fcccf98c850bda1b970a11c590bbb

URL: https://github.com/llvm/llvm-project/commit/192439db6e3fcccf98c850bda1b970a11c590bbb
DIFF: https://github.com/llvm/llvm-project/commit/192439db6e3fcccf98c850bda1b970a11c590bbb.diff

LOG: [mlir][Complex] Fix bug in `MergeComplexBitcast` (#74271)

When two `complex.bitcast` ops are folded and the resulting bitcast is a
non-complex -> non-complex bitcast, an `arith.bitcast` should be
generated. Otherwise, the generated `complex.bitcast` op is invalid.

Also remove a pattern that converts non-complex -> non-complex
`complex.bitcast` ops to `arith.bitcast`. Such `complex.bitcast` ops are
invalid and should not appear in the input.

Note: This bug can only be triggered by running with `-debug` (which
will show intermediate IR that does not verify) or with
`MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` (#74270).

Added: 
    

Modified: 
    mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
    mlir/test/Dialect/Complex/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 8fd914dd107ff..6d8706775758e 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -100,7 +100,8 @@ LogicalResult BitcastOp::verify() {
   }
 
   if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
-    return emitOpError("requires input or output is a complex type");
+    return emitOpError(
+        "requires that either input or output has a complex type");
   }
 
   if (isa<ComplexType>(resultType))
@@ -125,8 +126,15 @@ struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
   LogicalResult matchAndRewrite(BitcastOp op,
                                 PatternRewriter &rewriter) const override {
     if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
-      rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
-                                             defining.getOperand());
+      if (isa<ComplexType>(op.getType()) ||
+          isa<ComplexType>(defining.getOperand().getType())) {
+        // complex.bitcast requires that input or output is complex.
+        rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
+                                               defining.getOperand());
+      } else {
+        rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
+                                                      defining.getOperand());
+      }
       return success();
     }
 
@@ -155,24 +163,9 @@ struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
   }
 };
 
-struct ArithBitcast final : OpRewritePattern<BitcastOp> {
-  using OpRewritePattern<complex::BitcastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(BitcastOp op,
-                                PatternRewriter &rewriter) const override {
-    if (isa<ComplexType>(op.getType()) ||
-        isa<ComplexType>(op.getOperand().getType()))
-      return failure();
-
-    rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
-                                                  op.getOperand());
-    return success();
-  }
-};
-
 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ArithBitcast, MergeComplexBitcast, MergeArithBitcast>(context);
+  results.add<MergeComplexBitcast, MergeArithBitcast>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Complex/invalid.mlir b/mlir/test/Dialect/Complex/invalid.mlir
index 51b1b0fda202a..ba6995b727bc2 100644
--- a/mlir/test/Dialect/Complex/invalid.mlir
+++ b/mlir/test/Dialect/Complex/invalid.mlir
@@ -25,7 +25,7 @@ func.func @complex_constant_two_different_element_types() {
 // -----
 
 func.func @complex_bitcast_i64(%arg0 : i64) {
-  // expected-error @+1 {{op requires input or output is a complex type}}
+  // expected-error @+1 {{op requires that either input or output has a complex type}}
   %0 = complex.bitcast %arg0: i64 to f64
   return
 }


        


More information about the Mlir-commits mailing list