[Mlir-commits] [mlir] 3036382 - [mlir][linalg] Add lowering of named ops on complex numbers
Benjamin Kramer
llvmlistbot at llvm.org
Thu May 12 04:46:57 PDT 2022
Author: Benjamin Kramer
Date: 2022-05-12T13:37:34+02:00
New Revision: 303638248ab1299b38cac2c76260a92202005642
URL: https://github.com/llvm/llvm-project/commit/303638248ab1299b38cac2c76260a92202005642
DIFF: https://github.com/llvm/llvm-project/commit/303638248ab1299b38cac2c76260a92202005642.diff
LOG: [mlir][linalg] Add lowering of named ops on complex numbers
This lets linalg.dot and related named contraction ops (such as linalg.matmul)
lower to a complex multiply-add using ops from the complex dialect.
Differential Revision: https://reviews.llvm.org/D125461
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/generalize-named-ops.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index fc99e290e5cc4..4729dea98402c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
@@ -127,7 +128,8 @@ static MatchContractionResult isContractionInterfaceImpl(Operation *op) {
return MatchContractionResult::NotProjectedPermutations;
// TODO: more fields than add/mul.
if (!isAddMul<arith::AddFOp, arith::MulFOp>(linalgOp->getRegion(0).front()) &&
- !isAddMul<arith::AddIOp, arith::MulIOp>(linalgOp->getRegion(0).front()))
+ !isAddMul<arith::AddIOp, arith::MulIOp>(linalgOp->getRegion(0).front()) &&
+ !isAddMul<complex::AddOp, complex::MulOp>(linalgOp->getRegion(0).front()))
return MatchContractionResult::NotAddMul;
return MatchContractionResult::Success;
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 0584043d80933..6d5378201742f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
@@ -320,37 +321,48 @@ class RegionBuilderHelper {
// Build the binary functions defined by OpDSL.
Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
+ bool allComplex = isComplex(arg0) && isComplex(arg1);
bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
bool allInteger = isInteger(arg0) && isInteger(arg1);
- if (!allFloatingPoint && !allInteger)
+ if (!allComplex && !allFloatingPoint && !allInteger)
llvm_unreachable("unsupported non numeric type");
OpBuilder builder = getBuilder();
switch (binaryFn) {
case BinaryFn::add:
+ if (allComplex)
+ return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::sub:
+ if (allComplex)
+ return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::mul:
+ if (allComplex)
+ return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
if (allFloatingPoint)
return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::max_signed:
+ assert(!allComplex);
if (allFloatingPoint)
return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::min_signed:
+ assert(!allComplex);
if (allFloatingPoint)
return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::max_unsigned:
+ assert(!allComplex);
if (allFloatingPoint)
return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
case BinaryFn::min_unsigned:
+ assert(!allComplex);
if (allFloatingPoint)
return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
@@ -447,6 +459,7 @@ class RegionBuilderHelper {
return operand;
}
+ bool isComplex(Value value) { return value.getType().isa<ComplexType>(); }
bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index a61b4fbc916a8..86bd070c8835d 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -49,6 +49,29 @@ func.func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>,
// -----
+func.func @generalize_matmul_tensor_complex(%A : tensor<16x8xcomplex<f32>>,
+ %B: tensor<8x32xcomplex<f32>>,
+ %C: tensor<16x32xcomplex<f32>>)
+ -> tensor<16x32xcomplex<f32>> {
+ %0 = linalg.matmul ins(%A, %B: tensor<16x8xcomplex<f32>>, tensor<8x32xcomplex<f32>>)
+ outs(%C: tensor<16x32xcomplex<f32>>) -> tensor<16x32xcomplex<f32>>
+ return %0: tensor<16x32xcomplex<f32>>
+}
+
+// CHECK: func @generalize_matmul_tensor_complex
+
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<16x8xcomplex<f32>>, tensor<8x32xcomplex<f32>>)
+// CHECK-SAME: outs(%{{.+}} : tensor<16x32xcomplex<f32>>)
+
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: complex<f32>, %[[B_ARG:.+]]: complex<f32>, %[[C_ARG:.+]]: complex<f32>)
+// CHECK-NEXT: %[[MUL:.+]] = complex.mul %[[A_ARG]], %[[B_ARG]] : complex<f32>
+// CHECK-NEXT: %[[ADD:.+]] = complex.add %[[C_ARG]], %[[MUL]] : complex<f32>
+// CHECK-NEXT: linalg.yield %[[ADD]] : complex<f32>
+// CHECK-NEXT: -> tensor<16x32xcomplex<f32>>
+
+// -----
+
func.func @depthwise_conv_2d_nhwc_hwcm(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
linalg.depthwise_conv_2d_nhwc_hwcm
{ dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 2138b73100a77..ea947228155dd 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1203,8 +1203,8 @@ cc_library(
hdrs = ["include/mlir/Dialect/AMDGPU/AMDGPUDialect.h"],
includes = ["include"],
deps = [
- ":IR",
":AMDGPUIncGen",
+ ":IR",
":SideEffectInterfaces",
"//llvm:Core",
"//llvm:Support",
@@ -2448,8 +2448,8 @@ cc_library(
hdrs = ["include/mlir/Conversion/Passes.h"],
includes = ["include"],
deps = [
- ":AffineToStandard",
":AMDGPUToROCDL",
+ ":AffineToStandard",
":ArithmeticToLLVM",
":ArithmeticToSPIRV",
":ArmNeon2dToIntr",
@@ -2646,6 +2646,7 @@ cc_library(
deps = [
":Affine",
":ArithmeticDialect",
+ ":ComplexDialect",
":DialectUtils",
":IR",
":InferTypeOpInterface",
@@ -3693,12 +3694,12 @@ cc_library(
]),
includes = ["include"],
deps = [
+ ":AMDGPU",
":ConversionPassIncGen",
":IR",
":LLVMCommonConversion",
- ":AMDGPU",
- ":ROCDLDialect",
":Pass",
+ ":ROCDLDialect",
":Transforms",
"//llvm:Support",
],
@@ -3799,8 +3800,8 @@ cc_library(
hdrs = ["include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"],
includes = ["include"],
deps = [
- ":ArithmeticToLLVM",
":AMDGPUToROCDL",
+ ":ArithmeticToLLVM",
":ControlFlowToLLVM",
":ConversionPassIncGen",
":FuncDialect",
@@ -6133,14 +6134,14 @@ cc_library(
"include/mlir/InitAllPasses.h",
],
deps = [
+ ":AMDGPU",
+ ":AMDGPUToROCDL",
":AMX",
":AMXTransforms",
":Affine",
":AffinePassIncGen",
":AffineToStandard",
":AffineTransforms",
- ":AMDGPU",
- ":AMDGPUToROCDL",
":ArithmeticDialect",
":ArithmeticToLLVM",
":ArithmeticToSPIRV",
@@ -7300,6 +7301,7 @@ cc_library(
":ArithmeticDialect",
":ArithmeticUtils",
":BufferizationDialect",
+ ":ComplexDialect",
":ControlFlowInterfaces",
":CopyOpInterface",
":DialectUtils",
More information about the Mlir-commits mailing list