[Mlir-commits] [mlir] abfa950 - [mlir][linalg][python] Add exp and log to the OpDSL.

Tobias Gysi llvmlistbot at llvm.org
Thu Jul 8 02:11:09 PDT 2021


Author: Tobias Gysi
Date: 2021-07-08T08:48:23Z
New Revision: abfa950d86da1737a7dd52ba262fa39dd2e937fa

URL: https://github.com/llvm/llvm-project/commit/abfa950d86da1737a7dd52ba262fa39dd2e937fa
DIFF: https://github.com/llvm/llvm-project/commit/abfa950d86da1737a7dd52ba262fa39dd2e937fa.diff

LOG: [mlir][linalg][python] Add exp and log to the OpDSL.

Introduce the exp and log functions in OpDSL. Add the soft plus operator to test the emitted IR in Python and C++.
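
For reference, soft plus is defined as softplus(x) = log(1 + exp(x)), a smooth approximation of ReLU; its body exercises both new functions in a single expression.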

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D105420

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
    mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
    mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 092d22983d3f2..49ececc0790aa 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -33,8 +33,8 @@ def Linalg_Dialect : Dialect {
   }];
   let cppNamespace = "::mlir::linalg";
   let dependentDialects = [
-    "AffineDialect", "memref::MemRefDialect", "StandardOpsDialect",
-    "tensor::TensorDialect"
+    "AffineDialect", "math::MathDialect", "memref::MemRefDialect",
+    "StandardOpsDialect", "tensor::TensorDialect"
   ];
   let hasCanonicalizer = 1;
   let hasOperationAttrVerify = 1;

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 1e4277ecd7bdf..04f9776005c4e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -887,3 +887,58 @@ structured_op: !LinalgStructuredOpConfig
                           scalar_const: '2.3283063999999999E-10 : f64'
             - !ScalarExpression
               scalar_arg: min
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: soft_plus_2d
+  cpp_class_name: SoftPlus2DOp
+  doc: |-
+    Implements the soft plus operator.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T
+    shape_map: affine_map<()[s0, s1] -> (s0, s1)>
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1] -> (s0, s1)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+  iterator_types:
+  - parallel
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: log
+        operands:
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: add
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_const: '1.000000e+00 : f64'
+            - !ScalarExpression
+              scalar_apply:
+                fn_name: exp
+                operands:
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: I

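The scalar expression tree above encodes O = log(cast<U>(1.0) + exp(cast<U>(I))): the symbolic_cast nodes promote the constant and the input to the output element type U before exp, add, and log are applied.
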
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
index c5cfdd15c00a8..f5913e6ad6164 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_LINALG_LINALGTYPES_H_
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"

diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index 21104281b8120..14187f400e726 100644
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRLinalg
   MLIRSideEffectInterfaces
   MLIRViewLikeInterface
   MLIRStandard
+  MLIRMath
   MLIRMemRef
   MLIRTensor
   )

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 93062b10ccc63..ea12a312d9c01 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -256,6 +256,20 @@ class RegionBuilderHelper {
     llvm_unreachable("unsupported non numeric type");
   }
 
+  Value applyfn__exp(Value x) {
+    OpBuilder builder = getBuilder();
+    if (isFloatingPoint(x))
+      return builder.create<math::ExpOp>(x.getLoc(), x);
+    llvm_unreachable("unsupported non numeric type");
+  }
+
+  Value applyfn__log(Value x) {
+    OpBuilder builder = getBuilder();
+    if (isFloatingPoint(x))
+      return builder.create<math::LogOp>(x.getLoc(), x);
+    llvm_unreachable("unsupported non numeric type");
+  }
+
   Value applyfn__sub(Value lhs, Value rhs) {
     OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 61d2260587116..3810df9dff74a 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -7,6 +7,7 @@
 from mlir.ir import *
 from mlir.dialects import linalg
 from mlir.dialects import std
+from mlir.dialects import math
 # TODO: resolve name collision for Linalg functionality that is injected inside
 # the _mlir.dialects.linalg directly via pybind.
 from _mlir.dialects.linalg import fill_builtin_region
@@ -293,6 +294,16 @@ def _eval_add(self, lhs: Value, rhs: Value) -> Value:
       return std.AddIOp(lhs.type, lhs, rhs).result
     raise NotImplementedError("Unsupported 'add' operand: {lhs}")
 
+  def _eval_exp(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return math.ExpOp(x.type, x).result
+    raise NotImplementedError("Unsupported 'exp' operand: {x}")
+
+  def _eval_log(self, x: Value) -> Value:
+    if _is_floating_point_type(x.type):
+      return math.LogOp(x.type, x).result
+    raise NotImplementedError("Unsupported 'log' operand: {x}")
+
   def _eval_sub(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return std.SubFOp(lhs.type, lhs, rhs).result

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index a37e1944c1f75..72793cbf9c726 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -209,3 +209,16 @@ def fill_rng_2d(
   offset = cast(F64, const(2147483647))
   scaling = (max - min) * inv_range
   O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min)
+
+
+@linalg_structured_op
+def soft_plus_2d(
+    I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)):
+  """Implements the soft plus operator.
+
+  Numeric casting is performed on the input operand, promoting it to the same
+  data type as the accumulator/output.
+  """
+  domain(D.m, D.n)
+  O[D.m, D.n] = \
+      PrimFn.log(cast(U, const(1.0)) + PrimFn.exp(cast(U, I[D.m, D.n])))

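A minimal usage sketch for the new named op (not part of this commit), assuming an MLIR build with the Python bindings enabled; the function name soft_plus_example is hypothetical, and the builder pattern mirrors the test file below:

  from mlir.ir import (Context, InsertionPoint, Location, Module, F32Type,
                       RankedTensorType)
  from mlir.dialects import builtin
  from mlir.dialects.linalg.opdsl.ops.core_named_ops import soft_plus_2d

  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(
          RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
      def soft_plus_example(input, init_result):
        # Emits a linalg.generic computing log(1 + exp(input)).
        return soft_plus_2d(input, outs=[init_result])

  print(module)
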
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 0e1c6a62a7b10..aed3585d4f547 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -188,6 +188,23 @@ func @generalize_fill_rng_2d_i32(%min: f64, %max: f64, %seed: i32, %O: tensor<16
 // CHECK-NEXT:   linalg.yield %[[VAL6]] : i32
 // CHECK-NEXT: -> tensor<16x32xi32>
 
+// -----
+
+func @generalize_soft_plus_2d_f32(%input: tensor<16x32xf32>, %output: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.soft_plus_2d ins(%input: tensor<16x32xf32>) outs(%output: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK-LABEL: @generalize_soft_plus_2d_f32
+//      CHECK: %[[C1:.+]] = constant 1.000000e+00 : f64
+//      CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32
+// CHECK-NEXT:   %[[C1_CAST:.+]] = fptrunc %[[C1]] : f64 to f32
+// CHECK-NEXT:   %[[EXP:.+]] = math.exp %[[IN]] : f32
+// CHECK-NEXT:   %[[SUM:.+]] = addf %[[C1_CAST]], %[[EXP]] : f32
+// CHECK-NEXT:   %[[LOG:.+]] = math.log %[[SUM]] : f32
+// CHECK-NEXT:   linalg.yield %[[LOG]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
+
 // -----
 // Verifies floating point to integer cast.
 func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {

diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index 44ac4e8e8c5b4..ed33644859012 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -84,6 +84,13 @@ def fill_rng_poly(
   O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min)
 
 
+@linalg_structured_op
+def soft_plus_poly(
+    I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)):
+  O[D.m, D.n] = \
+      PrimFn.log(cast(U, const(1.0)) + cast(U, PrimFn.exp(I[D.m, D.n])))
+
+
 with Context() as ctx, Location.unknown():
   module = Module.create()
   f16 = F16Type.get()
@@ -299,5 +306,19 @@ def test_f32f32_min_pooling(input, shape, init_result):
     def test_i32_fill_rng(min, max, seed, init_result):
       return fill_rng_poly(min, max, seed, outs=[init_result])
 
+    # CHECK-LABEL: @test_f32_soft_plus
+    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+    # CHECK-NEXT:   %[[C1:.+]] = constant 1.000000e+00 : f64
+    # CHECK-NEXT:   %[[C1_CAST:.+]] = fptrunc %[[C1]] : f64 to f32
+    # CHECK-NEXT:   %[[EXP:.+]] = math.exp %[[IN]] : f32
+    # CHECK-NEXT:   %[[SUM:.+]] = addf %[[C1_CAST]], %[[EXP]] : f32
+    # CHECK-NEXT:   %[[LOG:.+]] = math.log %[[SUM]] : f32
+    # CHECK-NEXT:   linalg.yield %[[LOG]] : f32
+    # CHECK-NEXT: -> tensor<4x16xf32>
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+    def test_f32_soft_plus(input, init_result):
+      return soft_plus_poly(input, outs=[init_result])
+
 
 print(module)

