[Mlir-commits] [mlir] a8df21f - [mlir][complex] Add a `complex.bitcast` operation
Rob Suderman
llvmlistbot at llvm.org
Thu Jul 6 16:22:55 PDT 2023
Author: Rob Suderman
Date: 2023-07-06T16:21:57-07:00
New Revision: a8df21f433d2f7cd267ac6e5581163f15501918b
URL: https://github.com/llvm/llvm-project/commit/a8df21f433d2f7cd267ac6e5581163f15501918b
DIFF: https://github.com/llvm/llvm-project/commit/a8df21f433d2f7cd267ac6e5581163f15501918b.diff
LOG: [mlir][complex] Add a `complex.bitcast` operation
Converting between complex<f32> and i64 could be useful for handling interop
between the `arith` and `complex` dialects.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D154663
Added:
Modified:
mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
mlir/test/Dialect/Complex/canonicalize.mlir
mlir/test/Dialect/Complex/invalid.mlir
mlir/test/Dialect/Complex/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 0e962df658ec36..b80d77996a20f5 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -95,6 +95,31 @@ def Atan2Op : ComplexArithmeticOp<"atan2"> {
}];
}
+
+//===----------------------------------------------------------------------===//
+// Bitcast
+//===----------------------------------------------------------------------===//
+
+def BitcastOp : Complex_Op<"bitcast", [Pure]> {
+
+  let summary = "computes bitcast between complex and equal arith types";
+  let description = [{
+    The `bitcast` op reinterprets the bits of its operand as the result type.
+    Exactly one of the operand and result types must be a complex type; the
+    other must be an integer or floating-point type whose bitwidth equals
+    twice the bitwidth of the complex element type. A cast between identical
+    types is also accepted, since it folds away.
+
+    Example:
+
+    ```mlir
+    %a = complex.bitcast %b : complex<f32> to i64
+    ```
+  }];
+  let assemblyFormat = "$operand attr-dict `:` type($operand) `to` type($result)";
+  let arguments = (ins AnyType:$operand);
+  let results = (outs AnyType:$result);
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 6041f494b930f7..8fd914dd107ffb 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -72,6 +72,109 @@ LogicalResult ConstantOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// BitcastOp
+//===----------------------------------------------------------------------===//
+
+/// Folds a bitcast whose operand and result types are identical to the
+/// operand itself; any other bitcast is left alone (null fold result).
+OpFoldResult BitcastOp::fold(FoldAdaptor bitcast) {
+  Value source = getOperand();
+  const bool isNoOpCast = source.getType() == getType();
+  return isNoOpCast ? OpFoldResult(source) : OpFoldResult();
+}
+
+/// Verifies the operand/result type pairing of `complex.bitcast`.
+///
+/// Identical operand and result types are always accepted. Otherwise both
+/// types must be int/float/complex, exactly one of them must be complex, and
+/// the complex type's total bitwidth (twice its element bitwidth) must equal
+/// the bitwidth of the non-complex type.
+LogicalResult BitcastOp::verify() {
+  auto operandType = getOperand().getType();
+  auto resultType = getType();
+
+  // We allow this to be legal as it can be folded away.
+  if (operandType == resultType)
+    return success();
+
+  if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType)) {
+    return emitOpError("operand must be int/float/complex");
+  }
+
+  if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType)) {
+    return emitOpError("result must be int/float/complex");
+  }
+
+  // Rejects both "neither side is complex" and "both sides are complex".
+  if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
+    return emitOpError("requires input or output is a complex type");
+  }
+
+  // Normalize so that `operandType` always names the complex side.
+  if (isa<ComplexType>(resultType))
+    std::swap(operandType, resultType);
+
+  // Exactly one side is complex at this point, so `cast` is the right tool:
+  // it asserts on a mismatch instead of producing a null type that would be
+  // dereferenced, as the unchecked `dyn_cast` result was.
+  unsigned operandBitwidth = cast<ComplexType>(operandType)
+                                 .getElementType()
+                                 .getIntOrFloatBitWidth() *
+                             2;
+  unsigned resultBitwidth = resultType.getIntOrFloatBitWidth();
+
+  if (operandBitwidth != resultBitwidth) {
+    return emitOpError("casting bitwidths do not match");
+  }
+
+  return success();
+}
+
+/// Collapses a `complex.bitcast` whose operand is produced by another bitcast
+/// (`complex.bitcast` or `arith.bitcast`) into a single `complex.bitcast`
+/// applied directly to that producer's operand.
+struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
+  using OpRewritePattern<BitcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BitcastOp op,
+                                PatternRewriter &rewriter) const override {
+    Value castInput = op.getOperand();
+
+    // Look through one producing bitcast; the complex form is tried first.
+    Value original;
+    if (auto complexCast = castInput.getDefiningOp<BitcastOp>())
+      original = complexCast.getOperand();
+    else if (auto arithCast = castInput.getDefiningOp<arith::BitcastOp>())
+      original = arithCast.getOperand();
+
+    // No bitcast producer: nothing to merge.
+    if (!original)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(), original);
+    return success();
+  }
+};
+
+/// Rewrites an `arith.bitcast` fed by a `complex.bitcast` into a single
+/// `complex.bitcast` from the inner operand to the outer result type.
+struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
+  using OpRewritePattern<arith::BitcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::BitcastOp op,
+                                PatternRewriter &rewriter) const override {
+    auto producer = op.getOperand().getDefiningOp<complex::BitcastOp>();
+    if (!producer)
+      return failure();
+
+    Value original = producer.getOperand();
+    rewriter.replaceOpWithNewOp<complex::BitcastOp>(op, op.getType(),
+                                                    original);
+    return success();
+  }
+};
+
+/// Replaces a `complex.bitcast` that has no complex type on either side with
+/// an equivalent `arith.bitcast`.
+struct ArithBitcast final : OpRewritePattern<BitcastOp> {
+  using OpRewritePattern<BitcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BitcastOp op,
+                                PatternRewriter &rewriter) const override {
+    const bool resultIsComplex = isa<ComplexType>(op.getType());
+    const bool operandIsComplex = isa<ComplexType>(op.getOperand().getType());
+
+    // Only casts that do not involve a complex type are handed to `arith`.
+    if (resultIsComplex || operandIsComplex)
+      return failure();
+
+    Value source = op.getOperand();
+    rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(), source);
+    return success();
+  }
+};
+
+/// Registers the bitcast canonicalization patterns defined above, in the
+/// same order as before: ArithBitcast, MergeComplexBitcast, MergeArithBitcast.
+void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<ArithBitcast>(context);
+  results.add<MergeComplexBitcast>(context);
+  results.add<MergeArithBitcast>(context);
+}
+
//===----------------------------------------------------------------------===//
// CreateOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
index e36805691e3fdf..64c3f313dda959 100644
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -232,3 +232,52 @@ func.func @mul_one_f128(%arg0: f128, %arg1: f128) -> complex<f128> {
// CHECK-NEXT: return %[[CREATE]]
return %mul : complex<f128>
}
+
+// A complex<f32> -> i64 -> complex<f32> round trip of complex.bitcast is
+// expected to canonicalize away, returning the original argument.
+// CHECK-LABEL: func @fold_between_complex
+// CHECK-SAME: %[[ARG0:.*]]: complex<f32>
+func.func @fold_between_complex(%arg0 : complex<f32>) -> complex<f32> {
+ %0 = complex.bitcast %arg0 : complex<f32> to i64
+ %1 = complex.bitcast %0 : i64 to complex<f32>
+ // CHECK: return %[[ARG0]] : complex<f32>
+ func.return %1 : complex<f32>
+}
+
+// An i64 -> complex<f32> -> i64 round trip of complex.bitcast is expected to
+// canonicalize away, returning the original argument.
+// CHECK-LABEL: func @fold_between_i64
+// CHECK-SAME: %[[ARG0:.*]]: i64
+func.func @fold_between_i64(%arg0 : i64) -> i64 {
+ %0 = complex.bitcast %arg0 : i64 to complex<f32>
+ %1 = complex.bitcast %0 : complex<f32> to i64
+ // CHECK: return %[[ARG0]] : i64
+ func.return %1 : i64
+}
+
+// Two chained complex.bitcast ops whose overall cast (f64 -> i64) involves no
+// complex type are expected to canonicalize to a single arith.bitcast.
+// CHECK-LABEL: func @canon_arith_bitcast
+// CHECK-SAME: %[[ARG0:.*]]: f64
+func.func @canon_arith_bitcast(%arg0 : f64) -> i64 {
+ %0 = complex.bitcast %arg0 : f64 to complex<f32>
+ %1 = complex.bitcast %0 : complex<f32> to i64
+ // CHECK: %[[R0:.+]] = arith.bitcast %[[ARG0]]
+ // CHECK: return %[[R0]] : i64
+ func.return %1 : i64
+}
+
+
+// An arith.bitcast feeding a complex.bitcast is expected to merge into one
+// complex.bitcast applied directly to the function argument.
+// CHECK-LABEL: func @double_bitcast
+// CHECK-SAME: %[[ARG0:.*]]: f64
+func.func @double_bitcast(%arg0 : f64) -> complex<f32> {
+ // CHECK: %[[R0:.+]] = complex.bitcast %[[ARG0]]
+ %0 = arith.bitcast %arg0 : f64 to i64
+ %1 = complex.bitcast %0 : i64 to complex<f32>
+ // CHECK: return %[[R0]] : complex<f32>
+ func.return %1 : complex<f32>
+}
+
+// A complex.bitcast feeding an arith.bitcast is expected to merge into one
+// complex.bitcast applied directly to the function argument.
+// CHECK-LABEL: func @double_reverse_bitcast
+// CHECK-SAME: %[[ARG0:.*]]: complex<f32>
+func.func @double_reverse_bitcast(%arg0 : complex<f32>) -> f64 {
+ // CHECK: %[[R0:.+]] = complex.bitcast %[[ARG0]]
+ %0 = complex.bitcast %arg0 : complex<f32> to i64
+ %1 = arith.bitcast %0 : i64 to f64
+ // CHECK: return %[[R0]] : f64
+ func.return %1 : f64
+}
diff --git a/mlir/test/Dialect/Complex/invalid.mlir b/mlir/test/Dialect/Complex/invalid.mlir
index 591ebe79f7b779..51b1b0fda202a0 100644
--- a/mlir/test/Dialect/Complex/invalid.mlir
+++ b/mlir/test/Dialect/Complex/invalid.mlir
@@ -21,3 +21,11 @@ func.func @complex_constant_two_different_element_types() {
%0 = complex.constant [1.0 : f32, -1.0 : f64] : complex<f64>
return
}
+
+// -----
+
+// Negative test: a complex.bitcast between two different non-complex types
+// (i64 to f64) must be rejected with the diagnostic below.
+func.func @complex_bitcast_i64(%arg0 : i64) {
+ // expected-error @+1 {{op requires input or output is a complex type}}
+ %0 = complex.bitcast %arg0: i64 to f64
+ return
+}
diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
index a78ad3efaa34aa..1050ad0dcd5305 100644
--- a/mlir/test/Dialect/Complex/ops.mlir
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -83,5 +83,8 @@ func.func @ops(%f: f32) {
// CHECK: complex.tan %[[C]] : complex<f32>
%tan = complex.tan %complex : complex<f32>
+ // CHECK: complex.bitcast %[[C]]
+ %i64 = complex.bitcast %complex : complex<f32> to i64
+
return
}
More information about the Mlir-commits
mailing list