[Mlir-commits] [mlir] 3036382 - [mlir][linalg] Add lowering of named ops on complex numbers

Benjamin Kramer llvmlistbot at llvm.org
Thu May 12 04:46:57 PDT 2022


Author: Benjamin Kramer
Date: 2022-05-12T13:37:34+02:00
New Revision: 303638248ab1299b38cac2c76260a92202005642

URL: https://github.com/llvm/llvm-project/commit/303638248ab1299b38cac2c76260a92202005642
DIFF: https://github.com/llvm/llvm-project/commit/303638248ab1299b38cac2c76260a92202005642.diff

LOG: [mlir][linalg] Add lowering of named ops on complex numbers

This lets linalg.dot and related named ops (e.g. linalg.matmul) lower to a
complex multiply-add built from ops in the complex dialect.

Differential Revision: https://reviews.llvm.org/D125461

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/generalize-named-ops.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index fc99e290e5cc4..4729dea98402c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineExprVisitor.h"
@@ -127,7 +128,8 @@ static MatchContractionResult isContractionInterfaceImpl(Operation *op) {
     return MatchContractionResult::NotProjectedPermutations;
   // TODO: more fields than add/mul.
   if (!isAddMul<arith::AddFOp, arith::MulFOp>(linalgOp->getRegion(0).front()) &&
-      !isAddMul<arith::AddIOp, arith::MulIOp>(linalgOp->getRegion(0).front()))
+      !isAddMul<arith::AddIOp, arith::MulIOp>(linalgOp->getRegion(0).front()) &&
+      !isAddMul<complex::AddOp, complex::MulOp>(linalgOp->getRegion(0).front()))
     return MatchContractionResult::NotAddMul;
   return MatchContractionResult::Success;
 }

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 0584043d80933..6d5378201742f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
@@ -320,37 +321,48 @@ class RegionBuilderHelper {
 
   // Build the binary functions defined by OpDSL.
   Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
+    bool allComplex = isComplex(arg0) && isComplex(arg1);
     bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
     bool allInteger = isInteger(arg0) && isInteger(arg1);
-    if (!allFloatingPoint && !allInteger)
+    if (!allComplex && !allFloatingPoint && !allInteger)
       llvm_unreachable("unsupported non numeric type");
     OpBuilder builder = getBuilder();
     switch (binaryFn) {
     case BinaryFn::add:
+      if (allComplex)
+        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
       if (allFloatingPoint)
         return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
       return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::sub:
+      if (allComplex)
+        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
       if (allFloatingPoint)
         return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
       return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::mul:
+      if (allComplex)
+        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
       if (allFloatingPoint)
         return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
       return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::max_signed:
+      assert(!allComplex);
       if (allFloatingPoint)
         return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
       return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::min_signed:
+      assert(!allComplex);
       if (allFloatingPoint)
         return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
       return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::max_unsigned:
+      assert(!allComplex);
       if (allFloatingPoint)
         return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
       return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::min_unsigned:
+      assert(!allComplex);
       if (allFloatingPoint)
         return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
       return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
@@ -447,6 +459,7 @@ class RegionBuilderHelper {
     return operand;
   }
 
+  bool isComplex(Value value) { return value.getType().isa<ComplexType>(); }
   bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
   bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
 

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index a61b4fbc916a8..86bd070c8835d 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -49,6 +49,29 @@ func.func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>,
 
 // -----
 
+func.func @generalize_matmul_tensor_complex(%A : tensor<16x8xcomplex<f32>>,
+                                            %B: tensor<8x32xcomplex<f32>>,
+                                            %C: tensor<16x32xcomplex<f32>>)
+          -> tensor<16x32xcomplex<f32>> {
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xcomplex<f32>>, tensor<8x32xcomplex<f32>>)
+                    outs(%C: tensor<16x32xcomplex<f32>>) -> tensor<16x32xcomplex<f32>>
+  return %0: tensor<16x32xcomplex<f32>>
+}
+
+// CHECK: func @generalize_matmul_tensor_complex
+
+// CHECK: linalg.generic
+// CHECK-SAME:  ins(%{{.+}}, %{{.+}} : tensor<16x8xcomplex<f32>>, tensor<8x32xcomplex<f32>>)
+// CHECK-SAME: outs(%{{.+}} : tensor<16x32xcomplex<f32>>)
+
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: complex<f32>, %[[B_ARG:.+]]: complex<f32>, %[[C_ARG:.+]]: complex<f32>)
+// CHECK-NEXT:   %[[MUL:.+]] = complex.mul %[[A_ARG]], %[[B_ARG]] : complex<f32>
+// CHECK-NEXT:   %[[ADD:.+]] = complex.add %[[C_ARG]], %[[MUL]] : complex<f32>
+// CHECK-NEXT:   linalg.yield %[[ADD]] : complex<f32>
+// CHECK-NEXT: -> tensor<16x32xcomplex<f32>>
+
+// -----
+
 func.func @depthwise_conv_2d_nhwc_hwcm(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
   linalg.depthwise_conv_2d_nhwc_hwcm
      { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 2138b73100a77..ea947228155dd 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1203,8 +1203,8 @@ cc_library(
     hdrs = ["include/mlir/Dialect/AMDGPU/AMDGPUDialect.h"],
     includes = ["include"],
     deps = [
-        ":IR",
         ":AMDGPUIncGen",
+        ":IR",
         ":SideEffectInterfaces",
         "//llvm:Core",
         "//llvm:Support",
@@ -2448,8 +2448,8 @@ cc_library(
     hdrs = ["include/mlir/Conversion/Passes.h"],
     includes = ["include"],
     deps = [
-        ":AffineToStandard",
         ":AMDGPUToROCDL",
+        ":AffineToStandard",
         ":ArithmeticToLLVM",
         ":ArithmeticToSPIRV",
         ":ArmNeon2dToIntr",
@@ -2646,6 +2646,7 @@ cc_library(
     deps = [
         ":Affine",
         ":ArithmeticDialect",
+        ":ComplexDialect",
         ":DialectUtils",
         ":IR",
         ":InferTypeOpInterface",
@@ -3693,12 +3694,12 @@ cc_library(
     ]),
     includes = ["include"],
     deps = [
+        ":AMDGPU",
         ":ConversionPassIncGen",
         ":IR",
         ":LLVMCommonConversion",
-        ":AMDGPU",
-        ":ROCDLDialect",
         ":Pass",
+        ":ROCDLDialect",
         ":Transforms",
         "//llvm:Support",
     ],
@@ -3799,8 +3800,8 @@ cc_library(
     hdrs = ["include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"],
     includes = ["include"],
     deps = [
-        ":ArithmeticToLLVM",
         ":AMDGPUToROCDL",
+        ":ArithmeticToLLVM",
         ":ControlFlowToLLVM",
         ":ConversionPassIncGen",
         ":FuncDialect",
@@ -6133,14 +6134,14 @@ cc_library(
         "include/mlir/InitAllPasses.h",
     ],
     deps = [
+        ":AMDGPU",
+        ":AMDGPUToROCDL",
         ":AMX",
         ":AMXTransforms",
         ":Affine",
         ":AffinePassIncGen",
         ":AffineToStandard",
         ":AffineTransforms",
-        ":AMDGPU",
-        ":AMDGPUToROCDL",
         ":ArithmeticDialect",
         ":ArithmeticToLLVM",
         ":ArithmeticToSPIRV",
@@ -7300,6 +7301,7 @@ cc_library(
         ":ArithmeticDialect",
         ":ArithmeticUtils",
         ":BufferizationDialect",
+        ":ComplexDialect",
         ":ControlFlowInterfaces",
         ":CopyOpInterface",
         ":DialectUtils",


        


More information about the Mlir-commits mailing list