[Mlir-commits] [mlir] [mlir][Complex] Fix bug in `MergeComplexBitcast` (PR #74271)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Dec 3 18:56:44 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
When two `complex.bitcast` ops are folded and the resulting bitcast is a non-complex -> non-complex bitcast, an `arith.bitcast` should be generated. Otherwise, the generated `complex.bitcast` op is invalid.
Also remove a pattern that converts non-complex -> non-complex `complex.bitcast` ops to `arith.bitcast`. Such `complex.bitcast` ops are rejected by the verifier and therefore should never appear in the input.
Note: This bug can only be observed when running with `-debug` (which will show intermediate IR that does not verify) or with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` (#74270).
---
Full diff: https://github.com/llvm/llvm-project/pull/74271.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Complex/IR/ComplexOps.cpp (+12-19)
- (modified) mlir/test/Dialect/Complex/invalid.mlir (+1-1)
``````````diff
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 8fd914dd107ff..6d8706775758e 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -100,7 +100,8 @@ LogicalResult BitcastOp::verify() {
}
if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
- return emitOpError("requires input or output is a complex type");
+ return emitOpError(
+ "requires that either input or output has a complex type");
}
if (isa<ComplexType>(resultType))
@@ -125,8 +126,15 @@ struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
LogicalResult matchAndRewrite(BitcastOp op,
PatternRewriter &rewriter) const override {
if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
- rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
- defining.getOperand());
+ if (isa<ComplexType>(op.getType()) ||
+ isa<ComplexType>(defining.getOperand().getType())) {
+ // complex.bitcast requires that input or output is complex.
+ rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
+ defining.getOperand());
+ } else {
+ rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
+ defining.getOperand());
+ }
return success();
}
@@ -155,24 +163,9 @@ struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
}
};
-struct ArithBitcast final : OpRewritePattern<BitcastOp> {
- using OpRewritePattern<complex::BitcastOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(BitcastOp op,
- PatternRewriter &rewriter) const override {
- if (isa<ComplexType>(op.getType()) ||
- isa<ComplexType>(op.getOperand().getType()))
- return failure();
-
- rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
- op.getOperand());
- return success();
- }
-};
-
void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ArithBitcast, MergeComplexBitcast, MergeArithBitcast>(context);
+ results.add<MergeComplexBitcast, MergeArithBitcast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/invalid.mlir b/mlir/test/Dialect/Complex/invalid.mlir
index 51b1b0fda202a..ba6995b727bc2 100644
--- a/mlir/test/Dialect/Complex/invalid.mlir
+++ b/mlir/test/Dialect/Complex/invalid.mlir
@@ -25,7 +25,7 @@ func.func @complex_constant_two_different_element_types() {
// -----
func.func @complex_bitcast_i64(%arg0 : i64) {
- // expected-error @+1 {{op requires input or output is a complex type}}
+ // expected-error @+1 {{op requires that either input or output has a complex type}}
%0 = complex.bitcast %arg0: i64 to f64
return
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/74271
More information about the Mlir-commits mailing list