[Mlir-commits] [mlir] 13d3307 - [mlir][linalg] Add a few unary operations.
Bixia Zheng
llvmlistbot at llvm.org
Thu Mar 10 09:39:04 PST 2022
Author: Bixia Zheng
Date: 2022-03-10T09:38:58-08:00
New Revision: 13d330717666646443946e13df90c84a44ab4722
URL: https://github.com/llvm/llvm-project/commit/13d330717666646443946e13df90c84a44ab4722
DIFF: https://github.com/llvm/llvm-project/commit/13d330717666646443946e13df90c84a44ab4722.diff
LOG: [mlir][linalg] Add a few unary operations.
Add operations abs, ceil, floor, and negf to the C++ API and Python API.
Add test cases.
Reviewed By: gysit
Differential Revision: https://reviews.llvm.org/D121339
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
mlir/test/python/dialects/linalg/opdsl/emit_misc.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index f962eb6b2a869..a1a8477b9bee4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -61,7 +61,11 @@ def Linalg_Dialect : Dialect {
// Define the function attribute enums matching the OpDSL functions.
def UnaryFn : I32EnumAttr<"UnaryFn", "", [
I32EnumAttrCase<"exp", 0>,
- I32EnumAttrCase<"log", 1>
+ I32EnumAttrCase<"log", 1>,
+ I32EnumAttrCase<"abs", 2>,
+ I32EnumAttrCase<"ceil", 3>,
+ I32EnumAttrCase<"floor", 4>,
+ I32EnumAttrCase<"negf", 5>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 02ed7555a418c..8880c16b8ccd4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -144,6 +144,14 @@ class RegionBuilderHelper {
return builder.create<math::ExpOp>(arg.getLoc(), arg);
case UnaryFn::log:
return builder.create<math::LogOp>(arg.getLoc(), arg);
+ case UnaryFn::abs:
+ return builder.create<math::AbsOp>(arg.getLoc(), arg);
+ case UnaryFn::ceil:
+ return builder.create<math::CeilOp>(arg.getLoc(), arg);
+ case UnaryFn::floor:
+ return builder.create<math::FloorOp>(arg.getLoc(), arg);
+ case UnaryFn::negf:
+ return builder.create<arith::NegFOp>(arg.getLoc(), arg);
}
llvm_unreachable("unsupported unary function");
}
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 47083de625def..135f55ea516d0 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -274,6 +274,10 @@ class UnaryFn:
"""Unary function namespace."""
exp = UnaryFnType("exp")
log = UnaryFnType("log")
+ abs = UnaryFnType("abs")
+ ceil = UnaryFnType("ceil")
+ floor = UnaryFnType("floor")
+ negf = UnaryFnType("negf")
class BinaryFnType:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 93baef14bc197..2e71e561a7f54 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -390,6 +390,26 @@ def _unary_log(self, x: Value) -> Value:
return math.LogOp(x).result
raise NotImplementedError("Unsupported 'log' operand: {x}")
+ def _unary_abs(self, x: Value) -> Value:
+ if _is_floating_point_type(x.type):
+ return math.AbsOp(x).result
+ raise NotImplementedError("Unsupported 'abs' operand: {x}")
+
+ def _unary_ceil(self, x: Value) -> Value:
+ if _is_floating_point_type(x.type):
+ return math.CeilOp(x).result
+ raise NotImplementedError("Unsupported 'ceil' operand: {x}")
+
+ def _unary_floor(self, x: Value) -> Value:
+ if _is_floating_point_type(x.type):
+ return math.FloorOp(x).result
+ raise NotImplementedError("Unsupported 'floor' operand: {x}")
+
+ def _unary_negf(self, x: Value) -> Value:
+ if _is_floating_point_type(x.type):
+ return arith.NegFOp(x).result
+ raise NotImplementedError("Unsupported 'negf' operand: {x}")
+
def _binary_add(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.AddFOp(lhs, rhs).result
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index ebb9a87696ad0..3ac2d752bdd93 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -298,6 +298,54 @@ func @generalize_elemwise_log(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>)
// -----
+// Verifies the fun attribute controls the unary function used.
+func @generalize_elemwise_abs(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<abs>}
+ ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+ return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_abs
+// CHECK: = math.abs
+
+// -----
+
+// Verifies the fun attribute controls the unary function used.
+func @generalize_elemwise_ceil(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<ceil>}
+ ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+ return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_ceil
+// CHECK: = math.ceil
+
+// -----
+
+// Verifies the fun attribute controls the unary function used.
+func @generalize_elemwise_floor(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<floor>}
+ ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+ return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_floor
+// CHECK: = math.floor
+
+// -----
+
+// Verifies the fun attribute controls the unary function used.
+func @generalize_elemwise_negf(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<negf>}
+ ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
+ return %0: tensor<4x8xf32>
+}
+
+// CHECK-LABEL: @generalize_elemwise_negf
+// CHECK: = arith.negf
+
+// -----
+
// Verifies the default value of the fun attribute is an add op.
func @generalize_elemwise_add(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
%0 = linalg.elemwise_binary ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
index e69d71dcb337d..e57a49bec7b82 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -11,7 +11,7 @@
# fill, matmul, convolution, or pooling tests. The features include:
# - constant defined in the body
# - fix/predefined types
-# - exponential functions
+# - some math/arith functions, including abs, ceil, exp, floor, log, and negf
# - custom op names.
@@ -89,6 +89,46 @@ def test_f32_elemwise_exp(input, init_result):
def test_f32_elemwise_log(input, init_result):
return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
+ # CHECK-LABEL: @test_f32_elemwise_abs
+ # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+ # CHECK-NEXT: %[[EXP:.+]] = math.abs %[[IN]] : f32
+ # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+ # CHECK-NEXT: -> tensor<4x16xf32>
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+ def test_f32_elemwise_abs(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
+
+ # CHECK-LABEL: @test_f32_elemwise_ceil
+ # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+ # CHECK-NEXT: %[[EXP:.+]] = math.ceil %[[IN]] : f32
+ # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+ # CHECK-NEXT: -> tensor<4x16xf32>
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+ def test_f32_elemwise_ceil(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.ceil)
+
+ # CHECK-LABEL: @test_f32_elemwise_floor
+ # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+ # CHECK-NEXT: %[[EXP:.+]] = math.floor %[[IN]] : f32
+ # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+ # CHECK-NEXT: -> tensor<4x16xf32>
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+ def test_f32_elemwise_floor(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.floor)
+
+ # CHECK-LABEL: @test_f32_elemwise_neg
+ # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
+ # CHECK-NEXT: %[[EXP:.+]] = arith.negf %[[IN]] : f32
+ # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+ # CHECK-NEXT: -> tensor<4x16xf32>
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32))
+ def test_f32_elemwise_neg(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
+
# Just check that we don't assert out on name mismatch.
# CHECK-LABEL: @test_non_default_op_name
@builtin.FuncOp.from_py_func(
More information about the Mlir-commits
mailing list