[Mlir-commits] [mlir] a8df21f - [mlir][complex] Add a `complex.bitcast` operation

Rob Suderman llvmlistbot at llvm.org
Thu Jul 6 16:22:55 PDT 2023


Author: Rob Suderman
Date: 2023-07-06T16:21:57-07:00
New Revision: a8df21f433d2f7cd267ac6e5581163f15501918b

URL: https://github.com/llvm/llvm-project/commit/a8df21f433d2f7cd267ac6e5581163f15501918b
DIFF: https://github.com/llvm/llvm-project/commit/a8df21f433d2f7cd267ac6e5581163f15501918b.diff

LOG: [mlir][complex] Add a `complex.bitcast` operation

Converting between a complex<f32> and an i64 could be useful for handling interop
between the `arith` and `complex` dialects.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D154663

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
    mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
    mlir/test/Dialect/Complex/canonicalize.mlir
    mlir/test/Dialect/Complex/invalid.mlir
    mlir/test/Dialect/Complex/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 0e962df658ec36..b80d77996a20f5 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -95,6 +95,31 @@ def Atan2Op : ComplexArithmeticOp<"atan2"> {
   }];
 }
 
+
+//===----------------------------------------------------------------------===//
+// Bitcast
+//===----------------------------------------------------------------------===//
+
+def BitcastOp : Complex_Op<"bitcast", [Pure]> {
+  let summary = "computes bitcast between complex and equal-bitwidth arith types";
+  let description = [{
+    Reinterprets the bits of a complex value as an integer or floating-point
+    value, or vice versa. Exactly one of the operand and result must be a
+    complex type, and both sides must have the same total bitwidth; a complex
+    type is twice as wide as its element type. Casting a value to its own
+    type is also accepted and folds away.
+
+    Example:
+
+    ```mlir
+    %a = complex.bitcast %b : complex<f32> to i64
+    ```
+  }];
+  let assemblyFormat = "$operand attr-dict `:` type($operand) `to` type($result)";
+  let arguments = (ins AnyType:$operand);
+  let results = (outs AnyType:$result);
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 6041f494b930f7..8fd914dd107ffb 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -72,6 +72,109 @@ LogicalResult ConstantOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// BitcastOp
+//===----------------------------------------------------------------------===//
+
+/// Folds a bitcast whose result type equals its operand type: such a cast is
+/// the identity, so the operand value itself is the result. Any other bitcast
+/// is left untouched (a null OpFoldResult means "no fold").
+OpFoldResult BitcastOp::fold(FoldAdaptor bitcast) {
+  Value input = getOperand();
+  if (input.getType() != getType())
+    return {};
+  return input;
+}
+
+/// Verifies a complex.bitcast:
+///  - identical operand/result types are accepted (fold() removes them);
+///  - both sides must be int, float, or complex;
+///  - exactly one side must be complex;
+///  - the complex side (2 x element width) must match the scalar's width.
+LogicalResult BitcastOp::verify() {
+  Type operandType = getOperand().getType();
+  Type resultType = getType();
+
+  // Casting a value to its own type is accepted; fold() removes such ops.
+  if (operandType == resultType)
+    return success();
+
+  if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType))
+    return emitOpError("operand must be int/float/complex");
+
+  if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType))
+    return emitOpError("result must be int/float/complex");
+
+  // Exactly one side must be complex; int/float <-> int/float casts are the
+  // job of arith.bitcast.
+  if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType))
+    return emitOpError("requires input or output is a complex type");
+
+  // Identify the complex side and the int/float side. cast<> (not dyn_cast<>)
+  // is correct: the check above guarantees exactly one side is complex.
+  bool operandIsComplex = isa<ComplexType>(operandType);
+  auto complexType =
+      cast<ComplexType>(operandIsComplex ? operandType : resultType);
+  Type scalarType = operandIsComplex ? resultType : operandType;
+
+  // A complex value holds two elements (real and imaginary parts), so its
+  // width is twice the element width.
+  unsigned complexBitwidth =
+      complexType.getElementType().getIntOrFloatBitWidth() * 2;
+  unsigned scalarBitwidth = scalarType.getIntOrFloatBitWidth();
+
+  if (complexBitwidth != scalarBitwidth)
+    return emitOpError("casting bitwidths do not match");
+
+  return success();
+}
+
+/// Merges a complex.bitcast with the bitcast producing its operand:
+///  - complex.bitcast(complex.bitcast(x)) -> x, arith.bitcast(x), or
+///    complex.bitcast(x), depending on which ends of the chain are complex;
+///  - complex.bitcast(arith.bitcast(x))   -> complex.bitcast(x).
+/// Only rewrites that yield verifier-valid IR are performed.
+struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
+  using OpRewritePattern<BitcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BitcastOp op,
+                                PatternRewriter &rewriter) const override {
+    if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
+      Value source = defining.getOperand();
+      Type sourceType = source.getType();
+      Type resultType = op.getType();
+
+      // A round trip back to the original type reproduces the original bits.
+      if (sourceType == resultType) {
+        rewriter.replaceOp(op, source);
+        return success();
+      }
+
+      bool sourceIsComplex = isa<ComplexType>(sourceType);
+      bool resultIsComplex = isa<ComplexType>(resultType);
+
+      // complex.bitcast requires exactly one complex side, so two distinct
+      // complex types (e.g. complex<f32> -> i64 -> complex<i32>) cannot be
+      // expressed as a single bitcast. Keep the existing chain.
+      if (sourceIsComplex && resultIsComplex)
+        return failure();
+
+      // Neither end is complex: the merged cast is a plain arith.bitcast.
+      // Emitting it directly avoids creating an invalid complex.bitcast.
+      if (!sourceIsComplex && !resultIsComplex) {
+        rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, resultType, source);
+        return success();
+      }
+
+      rewriter.replaceOpWithNewOp<BitcastOp>(op, resultType, source);
+      return success();
+    }
+
+    // arith.bitcast only produces int/float values, so its source is a
+    // non-complex scalar and can feed complex.bitcast directly.
+    if (auto defining = op.getOperand().getDefiningOp<arith::BitcastOp>()) {
+      rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
+                                             defining.getOperand());
+      return success();
+    }
+
+    return failure();
+  }
+};
+
+/// Rewrites arith.bitcast(complex.bitcast(x)) into a single complex.bitcast
+/// from x directly to the arith.bitcast's result type.
+struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
+  using OpRewritePattern<arith::BitcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::BitcastOp op,
+                                PatternRewriter &rewriter) const override {
+    auto producer = op.getOperand().getDefiningOp<complex::BitcastOp>();
+    if (!producer)
+      return failure();
+
+    Value source = producer.getOperand();
+    rewriter.replaceOpWithNewOp<complex::BitcastOp>(op, op.getType(), source);
+    return success();
+  }
+};
+
+/// Replaces a complex.bitcast that has neither a complex operand nor a
+/// complex result with the equivalent arith.bitcast. The verifier rejects
+/// such complex.bitcast ops unless both types are equal, so this keeps
+/// rewritten IR expressible in the arith dialect.
+struct ArithBitcast final : OpRewritePattern<BitcastOp> {
+  using OpRewritePattern<BitcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BitcastOp op,
+                                PatternRewriter &rewriter) const override {
+    bool touchesComplex = isa<ComplexType>(op.getType()) ||
+                          isa<ComplexType>(op.getOperand().getType());
+    if (touchesComplex)
+      return failure();
+
+    Value input = op.getOperand();
+    rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(), input);
+    return success();
+  }
+};
+
+/// Registers the bitcast canonicalization patterns. MergeArithBitcast matches
+/// arith.bitcast rather than complex.bitcast; it is registered here so that
+/// chains mixing both bitcast kinds are merged from either end.
+void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<ArithBitcast>(context)
+      .add<MergeComplexBitcast>(context)
+      .add<MergeArithBitcast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // CreateOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
index e36805691e3fdf..64c3f313dda959 100644
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -232,3 +232,52 @@ func.func @mul_one_f128(%arg0: f128, %arg1: f128) -> complex<f128> {
   // CHECK-NEXT: return %[[CREATE]]
   return %mul : complex<f128>
 }
+
+// A complex<f32> -> i64 -> complex<f32> round trip canonicalizes to the argument.
+// CHECK-LABEL: func @fold_between_complex
+//  CHECK-SAME: %[[ARG0:.*]]: complex<f32>
+func.func @fold_between_complex(%arg0 : complex<f32>) -> complex<f32> {
+  %0 = complex.bitcast %arg0 : complex<f32> to i64
+  %1 = complex.bitcast %0 : i64 to complex<f32>
+  // CHECK: return %[[ARG0]] : complex<f32>
+  func.return %1 : complex<f32>
+}
+
+// An i64 -> complex<f32> -> i64 round trip canonicalizes to the argument.
+// CHECK-LABEL: func @fold_between_i64
+//  CHECK-SAME: %[[ARG0:.*]]: i64
+func.func @fold_between_i64(%arg0 : i64) -> i64 {
+  %0 = complex.bitcast %arg0 : i64 to complex<f32>
+  %1 = complex.bitcast %0 : complex<f32> to i64
+  // CHECK: return %[[ARG0]] : i64
+  func.return %1 : i64
+}
+
+// Two chained complex.bitcasts whose outer ends (f64 and i64) are both
+// non-complex collapse into a single arith.bitcast.
+// CHECK-LABEL: func @canon_arith_bitcast
+//  CHECK-SAME: %[[ARG0:.*]]: f64
+func.func @canon_arith_bitcast(%arg0 : f64) -> i64 {
+  %0 = complex.bitcast %arg0 : f64 to complex<f32>
+  %1 = complex.bitcast %0 : complex<f32> to i64
+  // CHECK: %[[R0:.+]] = arith.bitcast %[[ARG0]]
+  // CHECK: return %[[R0]] : i64
+  func.return %1 : i64
+}
+
+
+// An arith.bitcast feeding a complex.bitcast merges into one complex.bitcast
+// taking the original f64 argument.
+// CHECK-LABEL: func @double_bitcast
+//  CHECK-SAME: %[[ARG0:.*]]: f64
+func.func @double_bitcast(%arg0 : f64) -> complex<f32> {
+  // CHECK: %[[R0:.+]] = complex.bitcast %[[ARG0]]
+  %0 = arith.bitcast %arg0 : f64 to i64
+  %1 = complex.bitcast %0 : i64 to complex<f32>
+  // CHECK: return %[[R0]] : complex<f32>
+  func.return %1 : complex<f32>
+}
+
+// A complex.bitcast feeding an arith.bitcast merges into one complex.bitcast
+// producing the arith.bitcast's f64 result type.
+// CHECK-LABEL: func @double_reverse_bitcast
+//  CHECK-SAME: %[[ARG0:.*]]: complex<f32>
+func.func @double_reverse_bitcast(%arg0 : complex<f32>) -> f64 {
+  // CHECK: %[[R0:.+]] = complex.bitcast %[[ARG0]]
+  %0 = complex.bitcast %arg0 : complex<f32> to i64
+  %1 = arith.bitcast %0 : i64 to f64
+  // CHECK: return %[[R0]] : f64
+  func.return %1 : f64
+}

diff  --git a/mlir/test/Dialect/Complex/invalid.mlir b/mlir/test/Dialect/Complex/invalid.mlir
index 591ebe79f7b779..51b1b0fda202a0 100644
--- a/mlir/test/Dialect/Complex/invalid.mlir
+++ b/mlir/test/Dialect/Complex/invalid.mlir
@@ -21,3 +21,11 @@ func.func @complex_constant_two_different_element_types() {
   %0 = complex.constant [1.0 : f32, -1.0 : f64] : complex<f64>
   return
 }
+
+// -----
+
+func.func @complex_bitcast_i64(%arg0 : i64) {
+  // Casts with no complex operand or result belong to arith.bitcast.
+  // expected-error @+1 {{op requires input or output is a complex type}}
+  %0 = complex.bitcast %arg0: i64 to f64
+  return
+}
+
+// -----
+
+func.func @complex_bitcast_bitwidth_mismatch(%arg0 : complex<f32>) {
+  // complex<f32> is 64 bits wide, so it cannot be reinterpreted as i32.
+  // expected-error @+1 {{op casting bitwidths do not match}}
+  %0 = complex.bitcast %arg0 : complex<f32> to i32
+  return
+}
+
+// -----
+
+func.func @complex_bitcast_non_scalar(%arg0 : vector<2xf32>) {
+  // Only scalar int/float/complex values may be bitcast.
+  // expected-error @+1 {{op operand must be int/float/complex}}
+  %0 = complex.bitcast %arg0 : vector<2xf32> to complex<f32>
+  return
+}

diff  --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
index a78ad3efaa34aa..1050ad0dcd5305 100644
--- a/mlir/test/Dialect/Complex/ops.mlir
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -83,5 +83,8 @@ func.func @ops(%f: f32) {
   // CHECK: complex.tan %[[C]] : complex<f32>
   %tan = complex.tan %complex : complex<f32>
 
+  // CHECK: complex.bitcast %[[C]]
+  %i64 = complex.bitcast %complex : complex<f32> to i64
+
   return
 }


        


More information about the Mlir-commits mailing list