[Mlir-commits] [mlir] f513b70 - [mlir] Add `ComplexType` conversion support for `convertScalarToDtype`
Rob Suderman
llvmlistbot at llvm.org
Mon Jul 17 14:05:30 PDT 2023
Author: Rob Suderman
Date: 2023-07-17T14:00:58-07:00
New Revision: f513b70d43c39e830888c9aa5a4765b449e8c4ad
URL: https://github.com/llvm/llvm-project/commit/f513b70d43c39e830888c9aa5a4765b449e8c4ad
DIFF: https://github.com/llvm/llvm-project/commit/f513b70d43c39e830888c9aa5a4765b449e8c4ad.diff
LOG: [mlir] Add `ComplexType` conversion support for `convertScalarToDtype`
Linalg operations can include `complex` element types in their src/target
types, so constructing `linalg` operations should support conversion
between `arith` and `complex` types.
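As an illustration, here is a minimal caller-side sketch (the helper name
`widenToComplexF32` and the standalone function are assumptions for
illustration, not part of this patch) of targeting a `ComplexType` through
the extended `convertScalarToDtype`:

    #include "mlir/Dialect/Arith/Utils/Utils.h"
    #include "mlir/IR/Builders.h"
    #include "mlir/IR/BuiltinTypes.h"

    // Hypothetical helper: widen an f16 scalar into a complex<f32> value.
    // With this patch, convertScalarToDtype extends the scalar with
    // arith.extf and pairs it with a 0.0 imaginary part via complex.create.
    static mlir::Value widenToComplexF32(mlir::OpBuilder &builder,
                                         mlir::Location loc,
                                         mlir::Value scalarF16) {
      auto toType = mlir::ComplexType::get(builder.getF32Type());
      return mlir::convertScalarToDtype(builder, loc, scalarF16, toType,
                                        /*isUnsignedCast=*/false);
    }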
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D154740
Added:
Modified:
mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
mlir/lib/Dialect/Arith/Utils/Utils.cpp
mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt b/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
index 6767050ede6135..2be2724d4a9172 100644
--- a/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Utils/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRArithUtils
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRComplexDialect
MLIRDialect
MLIRIR
)
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 965ef117d79efa..d3e61030d8183a 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -12,6 +12,8 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "llvm/ADT/SmallBitVector.h"
using namespace mlir;
@@ -84,45 +86,122 @@ Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
}
-Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
- Type toType, bool isUnsignedCast) {
- if (operand.getType() == toType)
- return operand;
- if (auto toIntType = dyn_cast<IntegerType>(toType)) {
- // If operand is floating point, cast directly to the int type.
- if (isa<FloatType>(operand.getType())) {
- if (isUnsignedCast)
- return b.create<arith::FPToUIOp>(loc, toType, operand);
- return b.create<arith::FPToSIOp>(loc, toType, operand);
+static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand,
+ IntegerType toType, bool isUnsigned) {
+ // If operand is floating point, cast directly to the int type.
+ if (isa<FloatType>(operand.getType())) {
+ if (isUnsigned)
+ return b.create<arith::FPToUIOp>(toType, operand);
+ return b.create<arith::FPToSIOp>(toType, operand);
+ }
+ // Cast index operands directly to the int type.
+ if (operand.getType().isIndex())
+ return b.create<arith::IndexCastOp>(toType, operand);
+ if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
+ // Either extend or truncate.
+ if (toType.getWidth() > fromIntType.getWidth()) {
+ if (isUnsigned)
+ return b.create<arith::ExtUIOp>(toType, operand);
+ return b.create<arith::ExtSIOp>(toType, operand);
}
- // Cast index operands directly to the int type.
- if (operand.getType().isIndex())
- return b.create<arith::IndexCastOp>(loc, toType, operand);
- if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
- // Either extend or truncate.
- if (toIntType.getWidth() > fromIntType.getWidth()) {
- if (isUnsignedCast)
- return b.create<arith::ExtUIOp>(loc, toType, operand);
- return b.create<arith::ExtSIOp>(loc, toType, operand);
+ if (toType.getWidth() < fromIntType.getWidth())
+ return b.create<arith::TruncIOp>(toType, operand);
+ return operand;
+ }
+
+ return {};
+}
+
+static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand,
+ FloatType toType, bool isUnsigned) {
+ // If operand is integer, cast directly to the float type.
+ // Note that it is unclear how to cast from BF16<->FP16.
+ if (isa<IntegerType>(operand.getType())) {
+ if (isUnsigned)
+ return b.create<arith::UIToFPOp>(toType, operand);
+ return b.create<arith::SIToFPOp>(toType, operand);
+ }
+ if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) {
+ if (toType.getWidth() > fromFpTy.getWidth())
+ return b.create<arith::ExtFOp>(toType, operand);
+ if (toType.getWidth() < fromFpTy.getWidth())
+ return b.create<arith::TruncFOp>(toType, operand);
+ return operand;
+ }
+
+ return {};
+}
+
+static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand,
+ ComplexType targetType,
+ bool isUnsigned) {
+ if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) {
+ if (isa<FloatType>(targetType.getElementType()) &&
+ isa<FloatType>(fromComplexType.getElementType())) {
+ Value real = b.create<complex::ReOp>(operand);
+ Value imag = b.create<complex::ImOp>(operand);
+ Type targetETy = targetType.getElementType();
+ if (targetType.getElementType().getIntOrFloatBitWidth() <
+ fromComplexType.getElementType().getIntOrFloatBitWidth()) {
+ real = b.create<arith::TruncFOp>(targetETy, real);
+ imag = b.create<arith::TruncFOp>(targetETy, imag);
+ } else {
+ real = b.create<arith::ExtFOp>(targetETy, real);
+ imag = b.create<arith::ExtFOp>(targetETy, imag);
}
- if (toIntType.getWidth() < fromIntType.getWidth())
- return b.create<arith::TruncIOp>(loc, toType, operand);
+ return b.create<complex::CreateOp>(targetType, real, imag);
}
- } else if (auto toFloatType = dyn_cast<FloatType>(toType)) {
- // If operand is integer, cast directly to the float type.
- // Note that it is unclear how to cast from BF16<->FP16.
- if (isa<IntegerType>(operand.getType())) {
- if (isUnsignedCast)
- return b.create<arith::UIToFPOp>(loc, toFloatType, operand);
- return b.create<arith::SIToFPOp>(loc, toFloatType, operand);
+ }
+
+ if (auto fromFpType = dyn_cast<FloatType>(operand.getType())) {
+ FloatType toFpTy = cast<FloatType>(targetType.getElementType());
+ auto toBitwidth = toFpTy.getIntOrFloatBitWidth();
+ Value from = operand;
+ if (from.getType().getIntOrFloatBitWidth() < toBitwidth) {
+ from = b.create<arith::ExtFOp>(toFpTy, from);
}
- if (auto fromFloatType = dyn_cast<FloatType>(operand.getType())) {
- if (toFloatType.getWidth() > fromFloatType.getWidth())
- return b.create<arith::ExtFOp>(loc, toFloatType, operand);
- if (toFloatType.getWidth() < fromFloatType.getWidth())
- return b.create<arith::TruncFOp>(loc, toFloatType, operand);
+ if (from.getType().getIntOrFloatBitWidth() > toBitwidth) {
+ from = b.create<arith::TruncFOp>(toFpTy, from);
+ }
+ Value zero = b.create<mlir::arith::ConstantFloatOp>(
+ mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
+ return b.create<complex::CreateOp>(targetType, from, zero);
+ }
+
+ if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
+ FloatType toFpTy = cast<FloatType>(targetType.getElementType());
+ Value from = operand;
+ if (isUnsigned) {
+ from = b.create<arith::UIToFPOp>(toFpTy, from);
+ } else {
+ from = b.create<arith::SIToFPOp>(toFpTy, from);
}
+ Value zero = b.create<mlir::arith::ConstantFloatOp>(
+ mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
+ return b.create<complex::CreateOp>(targetType, from, zero);
}
+
+ return {};
+}
+
+Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
+ Type toType, bool isUnsignedCast) {
+ if (operand.getType() == toType)
+ return operand;
+ ImplicitLocOpBuilder ib(loc, b);
+ Value result;
+ if (auto intTy = dyn_cast<IntegerType>(toType)) {
+ result = convertScalarToIntDtype(ib, operand, intTy, isUnsignedCast);
+ } else if (auto floatTy = dyn_cast<FloatType>(toType)) {
+ result = convertScalarToFpDtype(ib, operand, floatTy, isUnsignedCast);
+ } else if (auto complexTy = dyn_cast<ComplexType>(toType)) {
+ result =
+ convertScalarToComplexDtype(ib, operand, complexTy, isUnsignedCast);
+ }
+
+ if (result)
+ return result;
+
emitWarning(loc) << "could not cast operand of type " << operand.getType()
<< " to " << toType;
return operand;
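A similar sketch (hypothetical names; assumes the `arith` and `complex`
dialects are loaded) for the new complex-to-complex path added above, which
widens the element type via complex.re/complex.im, arith.extf, and
complex.create:

    // Hypothetical helper: convert a complex<f16> value to complex<f32>.
    static mlir::Value widenComplexElements(mlir::OpBuilder &builder,
                                            mlir::Location loc,
                                            mlir::Value v) {
      auto toType = mlir::ComplexType::get(builder.getF32Type());
      // Per convertScalarToComplexDtype, this emits complex.re,
      // complex.im, two arith.extf ops, then complex.create.
      return mlir::convertScalarToDtype(builder, loc, v, toType,
                                        /*isUnsignedCast=*/false);
    }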
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 036d76c12eba6e..275e78aaa73dde 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -29,9 +30,10 @@ static bool hasAllOneValues(DenseIntElementsAttr attr) {
}
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
- bool isInt = isa<IntegerType>(x.getType());
- if (isInt)
+ if (isa<IntegerType>(x.getType()))
return builder.create<arith::AddIOp>(loc, x, y);
+ if (isa<ComplexType>(x.getType()))
+ return builder.create<complex::AddOp>(loc, x, y);
return builder.create<arith::AddFOp>(loc, x, y);
}
@@ -42,6 +44,8 @@ static Value createMul(Location loc, Value x, Value y, Type accType,
convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false);
Value yConvert =
convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
+ if (isa<ComplexType>(accType))
+ return builder.create<complex::MulOp>(loc, xConvert, yConvert);
if (isa<IntegerType>(accType))
return builder.create<arith::MulIOp>(loc, xConvert, yConvert);
return builder.create<arith::MulFOp>(loc, xConvert, yConvert);
@@ -111,7 +115,7 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
// Reshape output and filter to the LHS and result of a (B)MNK matmul.
SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
auto reshapedFilterType =
- RankedTensorType::get({fh * fw * ic, oc}, inputType.getElementType());
+ RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
loc, reshapedFilterType, filter, filterReassocIndices);
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index d87b4a3956a751..657cf83f25460f 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -314,3 +314,119 @@ transform.sequence failures(propagate) {
transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
transform.print %transformed {name = "transformed"}: !transform.any_op
}
+
+// -----
+
+// Check for compatible complex case.
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: @conv_complex
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x36xcomplex<f32>>, tensor<36x16xcomplex<f32>>)
+// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xcomplex<f32>>)
+// CHECK: ^bb0(%[[ARG0:.+]]: complex<f32>, %[[ARG1:.+]]: complex<f32>, %[[ARG2:.+]]: complex<f32>)
+// CHECK: %[[MUL:.+]] = complex.mul %[[ARG0]], %[[ARG1]] : complex<f32>
+// CHECK: %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex<f32>
+// CHECK: linalg.yield %[[ADD]] : complex<f32>
+// CHECK: } -> tensor<1x196x16xcomplex<f32>>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
+// CHECK: return %[[RESULT]]
+
+func.func @conv_complex(%arg0: tensor<1x16x16x4xcomplex<f32>>, %arg1: tensor<3x3x4x16xcomplex<f32>>, %arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>> {
+ %0 = linalg.conv_2d_nhwc_hwcf
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x16x4xcomplex<f32>>, tensor<3x3x4x16xcomplex<f32>>)
+ outs(%arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>>
+ return %0 : tensor<1x14x14x16xcomplex<f32>>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
+ transform.print %transformed {name = "transformed"}: !transform.any_op
+}
+
+// -----
+
+// Check the case where a complex<f16> filter is extended to the complex<f32> accumulator type.
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: @conv_complex_extended
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x36xcomplex<f32>>, tensor<36x16xcomplex<f16>>)
+// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xcomplex<f32>>)
+// CHECK: ^bb0(%[[ARG0:.+]]: complex<f32>, %[[ARG1:.+]]: complex<f16>, %[[ARG2:.+]]: complex<f32>)
+// CHECK: %[[REAL:.+]] = complex.re %[[ARG1]] : complex<f16>
+// CHECK: %[[IMAG:.+]] = complex.im %[[ARG1]] : complex<f16>
+// CHECK: %[[REEXT:.+]] = arith.extf %[[REAL]] : f16 to f32
+// CHECK: %[[IMEXT:.+]] = arith.extf %[[IMAG]] : f16 to f32
+// CHECK: %[[COMPLEX:.+]] = complex.create %[[REEXT]], %[[IMEXT]] : complex<f32>
+// CHECK: %[[MUL:.+]] = complex.mul %[[ARG0]], %[[COMPLEX]] : complex<f32>
+// CHECK: %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex<f32>
+// CHECK: linalg.yield %[[ADD]] : complex<f32>
+// CHECK: } -> tensor<1x196x16xcomplex<f32>>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
+// CHECK: return %[[RESULT]]
+
+func.func @conv_complex_extended(%arg0: tensor<1x16x16x4xcomplex<f32>>, %arg1: tensor<3x3x4x16xcomplex<f16>>, %arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>> {
+ %0 = linalg.conv_2d_nhwc_hwcf
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x16x4xcomplex<f32>>, tensor<3x3x4x16xcomplex<f16>>)
+ outs(%arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>>
+ return %0 : tensor<1x14x14x16xcomplex<f32>>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
+ transform.print %transformed {name = "transformed"}: !transform.any_op
+}
+
+// -----
+
+// Check the case where a real f16 filter is extended into the complex<f32> accumulator type.
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: @conv_complex_f16_extended
+// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x36xcomplex<f32>>, tensor<36x16xf16>)
+// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xcomplex<f32>>)
+// CHECK: ^bb0(%[[ARG0:.+]]: complex<f32>, %[[ARG1:.+]]: f16, %[[ARG2:.+]]: complex<f32>)
+// CHECK: %[[EXT:.+]] = arith.extf %[[ARG1]] : f16 to f32
+// CHECK: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[COMPLEX:.+]] = complex.create %[[EXT]], %[[ZERO]]
+// CHECK: %[[MUL:.+]] = complex.mul %[[ARG0]], %[[COMPLEX]] : complex<f32>
+// CHECK: %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex<f32>
+// CHECK: linalg.yield %[[ADD]] : complex<f32>
+// CHECK: } -> tensor<1x196x16xcomplex<f32>>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
+// CHECK: return %[[RESULT]]
+
+func.func @conv_complex_f16_extended(%arg0: tensor<1x16x16x4xcomplex<f32>>, %arg1: tensor<3x3x4x16xf16>, %arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>> {
+ %0 = linalg.conv_2d_nhwc_hwcf
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x16x4xcomplex<f32>>, tensor<3x3x4x16xf16>)
+ outs(%arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>>
+ return %0 : tensor<1x14x14x16xcomplex<f32>>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.print %img2col_tensor_producer {name = "tensor_producer"}: !transform.any_op
+ transform.print %transformed {name = "transformed"}: !transform.any_op
+}