[Mlir-commits] [mlir] abfa950 - [mlir][linalg][python] Add exp and log to the OpDSL.
Tobias Gysi
llvmlistbot at llvm.org
Thu Jul 8 02:11:09 PDT 2021
Author: Tobias Gysi
Date: 2021-07-08T08:48:23Z
New Revision: abfa950d86da1737a7dd52ba262fa39dd2e937fa
URL: https://github.com/llvm/llvm-project/commit/abfa950d86da1737a7dd52ba262fa39dd2e937fa
DIFF: https://github.com/llvm/llvm-project/commit/abfa950d86da1737a7dd52ba262fa39dd2e937fa.diff
LOG: [mlir][linalg][python] Add exp and log to the OpDSL.
Introduce the exp and log functions in OpDSL. Add the soft plus operator to test the emitted IR in Python and C++.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D105420
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 092d22983d3f2..49ececc0790aa 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -33,8 +33,8 @@ def Linalg_Dialect : Dialect {
}];
let cppNamespace = "::mlir::linalg";
let dependentDialects = [
- "AffineDialect", "memref::MemRefDialect", "StandardOpsDialect",
- "tensor::TensorDialect"
+ "AffineDialect", "math::MathDialect", "memref::MemRefDialect", // math: exp/log
+ "StandardOpsDialect", "tensor::TensorDialect"
];
let hasCanonicalizer = 1;
let hasOperationAttrVerify = 1;
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 1e4277ecd7bdf..04f9776005c4e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -887,3 +887,58 @@ structured_op: !LinalgStructuredOpConfig
scalar_const: '2.3283063999999999E-10 : f64'
- !ScalarExpression
scalar_arg: min
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: soft_plus_2d  # O = log(cast(U, 1.0) + exp(cast(U, I))); mirrors core_named_ops.py
+ cpp_class_name: SoftPlus2DOp
+ doc: |-
+ Implements the soft plus operator.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ usage: InputOperand
+ type_var: T
+ shape_map: affine_map<()[s0, s1] -> (s0, s1)>
+ - !LinalgOperandDefConfig
+ name: O
+ usage: OutputOperand
+ type_var: U
+ shape_map: affine_map<()[s0, s1] -> (s0, s1)>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+ - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+ iterator_types:
+ - parallel
+ - parallel
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_apply:
+ fn_name: log
+ operands:
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_const: '1.000000e+00 : f64'
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: exp
+ operands:
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
index c5cfdd15c00a8..f5913e6ad6164 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_LINALG_LINALGTYPES_H_
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Math/IR/Math.h" // math::MathDialect (dependent dialect)
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index 21104281b8120..14187f400e726 100644
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRLinalg
MLIRSideEffectInterfaces
MLIRViewLikeInterface
MLIRStandard
+ MLIRMath # region builders create math::ExpOp / math::LogOp
MLIRMemRef
MLIRTensor
)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 93062b10ccc63..ea12a312d9c01 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -256,6 +256,20 @@ class RegionBuilderHelper {
llvm_unreachable("unsupported non numeric type");
}
+ Value applyfn__exp(Value x) { // fn_name "exp" -> math.exp; float operands only.
+ OpBuilder builder = getBuilder();
+ if (isFloatingPoint(x))
+ return builder.create<math::ExpOp>(x.getLoc(), x);
+ llvm_unreachable("unsupported non numeric type");
+ }
+
+ Value applyfn__log(Value x) { // fn_name "log" -> math.log; float operands only.
+ OpBuilder builder = getBuilder();
+ if (isFloatingPoint(x))
+ return builder.create<math::LogOp>(x.getLoc(), x);
+ llvm_unreachable("unsupported non numeric type");
+ }
+
Value applyfn__sub(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 61d2260587116..3810df9dff74a 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -7,6 +7,7 @@
from mlir.ir import *
from mlir.dialects import linalg
from mlir.dialects import std
+from mlir.dialects import math
# TODO: resolve name collision for Linalg functionality that is injected inside
# the _mlir.dialects.linalg directly via pybind.
from _mlir.dialects.linalg import fill_builtin_region
@@ -293,6 +294,18 @@ def _eval_add(self, lhs: Value, rhs: Value) -> Value:
return std.AddIOp(lhs.type, lhs, rhs).result
raise NotImplementedError("Unsupported 'add' operand: {lhs}")
+ # Emits math.exp for the 'exp' scalar function; non-float operands raise.
+ def _eval_exp(self, x: Value) -> Value:
+ if _is_floating_point_type(x.type):
+ return math.ExpOp(x.type, x).result
+ raise NotImplementedError(f"Unsupported 'exp' operand: {x}")
+
+ # Emits math.log for the 'log' scalar function; non-float operands raise.
+ def _eval_log(self, x: Value) -> Value:
+ if _is_floating_point_type(x.type):
+ return math.LogOp(x.type, x).result
+ raise NotImplementedError(f"Unsupported 'log' operand: {x}")
+
def _eval_sub(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return std.SubFOp(lhs.type, lhs, rhs).result
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index a37e1944c1f75..72793cbf9c726 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -209,3 +209,16 @@ def fill_rng_2d(
offset = cast(F64, const(2147483647))
scaling = (max - min) * inv_range
O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min)
+
+
+@linalg_structured_op
+def soft_plus_2d(
+ I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)):
+ """Implements the soft plus operator.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ domain(D.m, D.n)
+ O[D.m, D.n] = \
+ PrimFn.log(cast(U, const(1.0)) + PrimFn.exp(cast(U, I[D.m, D.n])))
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 0e1c6a62a7b10..aed3585d4f547 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -188,6 +188,23 @@ func @generalize_fill_rng_2d_i32(%min: f64, %max: f64, %seed: i32, %O: tensor<16
// CHECK-NEXT: linalg.yield %[[VAL6]] : i32
// CHECK-NEXT: -> tensor<16x32xi32>
+// -----
+
+func @generalize_soft_plus_2d_f32(%input: tensor<16x32xf32>, %output: tensor<16x32xf32>) -> tensor<16x32xf32> { // log(1 + exp(x)), f32 in/out
+ %0 = linalg.soft_plus_2d ins(%input: tensor<16x32xf32>) outs(%output: tensor<16x32xf32>) -> tensor<16x32xf32>
+ return %0: tensor<16x32xf32>
+}
+
+// CHECK-LABEL: @generalize_soft_plus_2d_f32
+// CHECK: %[[C1:.+]] = constant 1.000000e+00 : f64
+// CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32
+// CHECK-NEXT: %[[C1_CAST:.+]] = fptrunc %[[C1]] : f64 to f32
+// CHECK-NEXT: %[[EXP:.+]] = math.exp %[[IN]] : f32
+// CHECK-NEXT: %[[SUM:.+]] = addf %[[C1_CAST]], %[[EXP]] : f32
+// CHECK-NEXT: %[[LOG:.+]] = math.log %[[SUM]] : f32
+// CHECK-NEXT: linalg.yield %[[LOG]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
+
// -----
// Verifies floating point to integer cast.
func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index 44ac4e8e8c5b4..ed33644859012 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -84,6 +84,13 @@ def fill_rng_poly(
O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min)
+@linalg_structured_op
+def soft_plus_poly(
+ I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)):
+ O[D.m, D.n] = \
+ PrimFn.log(cast(U, const(1.0)) + cast(U, PrimFn.exp(I[D.m, D.n])))
+
+
with Context() as ctx, Location.unknown():
module = Module.create()
f16 = F16Type.get()
@@ -299,5 +306,19 @@ def test_f32f32_min_pooling(input, shape, init_result):
def test_i32_fill_rng(min, max, seed, init_result):
return fill_rng_poly(min, max, seed, outs=[init_result])
+ # CHECK-LABEL: @test_f32_soft_plus
+ # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+ # CHECK-NEXT: %[[C1:.+]] = constant 1.000000e+00 : f64
+ # CHECK-NEXT: %[[C1_CAST:.+]] = fptrunc %[[C1]] : f64 to f32
+ # CHECK-NEXT: %[[EXP:.+]] = math.exp %[[IN]] : f32
+ # CHECK-NEXT: %[[SUM:.+]] = addf %[[C1_CAST]], %[[EXP]] : f32
+ # CHECK-NEXT: %[[LOG:.+]] = math.log %[[SUM]] : f32
+ # CHECK-NEXT: linalg.yield %[[LOG]] : f32
+ # CHECK-NEXT: -> tensor<4x16xf32>
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+ def test_f32_soft_plus(input, init_result):
+ return soft_plus_poly(input, outs=[init_result])
+
print(module)
More information about the Mlir-commits mailing list