[Mlir-commits] [mlir] 3b95400 - [mlir][linalg][python] Add max operation in OpDSL
Tobias Gysi
llvmlistbot at llvm.org
Fri Jul 2 00:13:05 PDT 2021
Author: Tobias Gysi
Date: 2021-07-02T07:12:37Z
New Revision: 3b95400f78a9824172629123580c0a0df36cbc70
URL: https://github.com/llvm/llvm-project/commit/3b95400f78a9824172629123580c0a0df36cbc70
DIFF: https://github.com/llvm/llvm-project/commit/3b95400f78a9824172629123580c0a0df36cbc70.diff
LOG: [mlir][linalg][python] Add max operation in OpDSL
Add the max operation to OpDSL and introduce a max pooling operation to test the implementation. As MLIR has no builtin max operation, the max function is lowered to a compare-and-select pair.
Differential Revision: https://reviews.llvm.org/D105203
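
For reference, a minimal sketch of driving the new pooling_nhwc_max_poly op from the Python bindings, modeled on the opsrun.py test updated below. The import paths and the Context/Module setup are assumptions added for illustration and are not part of this commit:

    from mlir.ir import (Context, Location, Module, InsertionPoint, F64Type,
                         IntegerType, MemRefType)
    from mlir.dialects import builtin, linalg

    with Context(), Location.unknown():
        module = Module.create()
        f64 = F64Type.get()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):

            # Mirrors pooling_on_buffers in opsrun.py; when the op is
            # generalized, the max reduction becomes the cmpf/cmpi + select
            # pair introduced in RegionBuilderHelper below.
            @builtin.FuncOp.from_py_func(
                MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
                MemRefType.get((1, 2, 4, 1), i32))
            def pooling_on_buffers(input, shape, output):
                linalg.pooling_nhwc_max_poly(
                    input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])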
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
mlir/test/python/dialects/linalg/opsrun.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index a8baf23bbfaab..39045a212ce11 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1,4 +1,3 @@
-
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul
@@ -594,6 +593,77 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: pooling_nhwc_max_poly
+ cpp_class_name: PoolingNhwcMaxPolyOp
+ doc: |-
+ Performs max pooling.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ usage: InputOperand
+ type_var: T1
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+ (s0, s1, s2, s3)>
+ - !LinalgOperandDefConfig
+ name: K
+ usage: InputOperand
+ type_var: T2
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+ (s4, s5)>
+ - !LinalgOperandDefConfig
+ name: O
+ usage: OutputOperand
+ type_var: U
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+ (s0, s6, s7, s3)>
+ - !LinalgOperandDefConfig
+ name: strides
+ usage: IndexAttribute
+ type_var: I64
+ attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+ -> (s8, s9)>
+ - !LinalgOperandDefConfig
+ name: dilations
+ usage: IndexAttribute
+ type_var: I64
+ attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+ -> (s10, s11)>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+ s10, s11] -> (d0, d1 * s8 + d3 * s10, d2 * s9 + d4 * s11, d5)>
+ - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+ s10, s11] -> (d3, d4)>
+ - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+ s10, s11] -> (d0, d1, d2, d5)>
+ iterator_types:
+ - parallel
+ - parallel
+ - parallel
+ - reduction
+ - reduction
+ - parallel
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_apply:
+ fn_name: max
+ operands:
+ - !ScalarExpression
+ scalar_arg: O
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: fill_rng_2d
cpp_class_name: FillRng2DOp
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d0c69b4148345..9b729b9db5d10 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -274,6 +274,21 @@ class RegionBuilderHelper {
llvm_unreachable("unsupported non numeric type");
}
+ Value applyfn__max(Value lhs, Value rhs) {
+ OpBuilder builder = getBuilder();
+ if (isFloatingPoint(lhs)) {
+ Value condition =
+ builder.create<CmpFOp>(lhs.getLoc(), CmpFPredicate::OGT, lhs, rhs);
+ return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
+ }
+ if (isInteger(lhs)) {
+ Value condition =
+ builder.create<CmpIOp>(lhs.getLoc(), CmpIPredicate::sgt, lhs, rhs);
+ return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
+ }
+ llvm_unreachable("unsupported non numeric type");
+ }
+
void yieldOutputs(ValueRange values) {
assert(!values.empty() && "linalg ops must yield outputs");
if (values.empty())
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index f6fb0cc7d0d0e..9489dec522716 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -307,6 +307,18 @@ def _eval_mul(self, lhs: Value, rhs: Value) -> Value:
return std.MulIOp(lhs.type, lhs, rhs).result
raise NotImplementedError("Unsupported 'mul' operand: {lhs}")
+ def _eval_max(self, lhs: Value, rhs: Value) -> Value:
+ i1 = IntegerType.get_signless(1)
+ if _is_floating_point_type(lhs.type):
+ ogt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
+ cond = std.CmpFOp(i1, ogt_attr, lhs, rhs).result
+ return std.SelectOp(lhs.type, cond, lhs, rhs).result
+ if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ sgt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
+ cond = std.CmpIOp(i1, sgt_attr, lhs, rhs).result
+ return std.SelectOp(lhs.type, cond, lhs, rhs).result
+ raise NotImplementedError("Unsupported 'max' operand: {lhs}")
+
def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
in_arg_defs: Sequence[OperandDefConfig],
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 095d94956f5b7..04c950e0a44db 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -148,6 +148,24 @@ def pooling_nhwc_sum_poly(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
+@linalg_structured_op
+def pooling_nhwc_max_poly(
+ I=TensorDef(T1, S.N, S.H, S.W, S.C),
+ K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+ O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+ strides=AttributeDef(S.SH, S.SW),
+ dilations=AttributeDef(S.DH, S.DW)):
+ """Performs max pooling.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+ O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)(
+ cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+ D.c]))
+
+
@linalg_structured_op
def fill_rng_2d(
min=ScalarDef(F64),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 723859c913c04..4a1cb8dbcfa58 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -60,6 +60,36 @@ func @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_i32(%input : tenso
// -----
+func @generalize_pooling_nhwc_max_poly_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
+ %0 = linalg.pooling_nhwc_max_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+ ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
+ return %0: tensor<1x2x4x1xf32>
+}
+
+// CHECK-LABEL: @generalize_pooling_nhwc_max_poly_f32
+// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
+// CHECK-NEXT: %[[COND:.+]] = cmpf ogt, %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT: linalg.yield %[[MAX]] : f32
+// CHECK-NEXT: -> tensor<1x2x4x1xf32>
+
+// -----
+
+func @generalize_pooling_nhwc_max_poly_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
+ %0 = linalg.pooling_nhwc_max_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+ ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
+ return %0: tensor<1x2x4x1xi32>
+}
+
+// CHECK-LABEL: @generalize_pooling_nhwc_max_poly_i32
+// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
+// CHECK-NEXT: %[[COND:.+]] = cmpi sgt, %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT: linalg.yield %[[MAX]] : i32
+// CHECK-NEXT: -> tensor<1x2x4x1xi32>
+
+// -----
+
func @generalize_pooling_nhwc_sum_poly_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
%0 = linalg.pooling_nhwc_sum_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index f7db532dced5c..12f6c560cfecc 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -50,8 +50,9 @@ def pooling_poly(
strides=AttributeDef(S.SH, S.SW),
dilations=AttributeDef(S.DH, S.DW)):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
- O[D.n, D.oh, D.ow, D.c] += cast(
- U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
+ O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)(
+ cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+ D.c]))
@linalg_structured_op
@@ -221,8 +222,9 @@ def test_f32i32_conv(input, filter, init_result):
# CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
# CHECK-NEXT: %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
- # CHECK-NEXT: %[[SUM:.+]] = addi %[[OUT]], %[[IN_CAST]] : i32
- # CHECK-NEXT: linalg.yield %[[SUM]] : i32
+ # CHECK-NEXT: %[[COND:.+]] = cmpi sgt, %[[OUT]], %[[IN_CAST:.+]] : i32
+ # CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT]], %[[IN_CAST:.+]] : i32
+ # CHECK-NEXT: linalg.yield %[[MAX]] : i32
# CHECK-NEXT: -> tensor<2x4xi32>
@builtin.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
@@ -231,6 +233,22 @@ def test_f32i32_pooling(input, shape, init_result):
return pooling_poly(
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+ # CHECK-LABEL: @test_f32f32_pooling
+ # CHECK: linalg.generic
+ # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
+ # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
+ # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32)
+ # CHECK-NEXT: %[[COND:.+]] = cmpf ogt, %[[OUT]], %[[IN:.+]] : f32
+ # CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT]], %[[IN:.+]] : f32
+ # CHECK-NEXT: linalg.yield %[[MAX]] : f32
+ # CHECK-NEXT: -> tensor<2x4xf32>
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
+ RankedTensorType.get((2, 4), f32))
+ def test_f32f32_pooling(input, shape, init_result):
+ return pooling_poly(
+ input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+
# CHECK-LABEL: @test_i32_fill_rng
# CHECK: ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %{{.*}}
# CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
diff --git a/mlir/test/python/dialects/linalg/opsrun.py b/mlir/test/python/dialects/linalg/opsrun.py
index 08b13a5352984..c6d26d1c6b858 100644
--- a/mlir/test/python/dialects/linalg/opsrun.py
+++ b/mlir/test/python/dialects/linalg/opsrun.py
@@ -85,6 +85,7 @@ def log(*args):
pooling_boiler = """
func @main() -> i32 attributes {llvm.emit_c_interface} {
%v0 = constant 0 : i32
+ %v42 = constant 42.0 : f64
%v1 = constant 1.0 : f64
%input = memref.alloc() : memref<1x4x16x1xf64>
@@ -94,10 +95,12 @@ def log(*args):
linalg.fill(%v1, %shape) : f64, memref<2x2xf64>
linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>
+ %c0 = constant 0 : index
+ memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
+
call @pooling_on_buffers(%input, %shape, %output) :
(memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()
- %c0 = constant 0 : index
%0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>
// TODO: FFI-based solution to allow testing and printing with python code.
@@ -105,6 +108,7 @@ def log(*args):
}
"""
+
def transform(module, boilerplate):
import mlir.conversions
import mlir.dialects.linalg.passes
@@ -308,12 +312,8 @@ def test_pooling_builtin():
MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
MemRefType.get((1, 2, 4, 1), i32))
def pooling_on_buffers(input, shape, output):
- linalg.pooling_nhwc_sum_poly(
- input,
- shape,
- outs=[output],
- strides=[2, 4],
- dilations=[1, 2])
+ linalg.pooling_nhwc_max_poly(
+ input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])
execution_engine = ExecutionEngine(transform(module, pooling_boiler))
@@ -325,7 +325,7 @@ def pooling_on_buffers(input, shape, output):
execution_engine.invoke("main", res)
log("RESULT: ", res[0])
- # CHECK: RESULT: 4
+ # CHECK: RESULT: 42
test_pooling_builtin()
@@ -342,7 +342,7 @@ def test_pooling_generic():
MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
MemRefType.get((1, 2, 4, 1), i32))
def pooling_on_buffers(input, shape, output):
- linalg.pooling_nhwc_sum_poly(
+ linalg.pooling_nhwc_max_poly(
input,
shape,
outs=[output],
@@ -360,7 +360,7 @@ def pooling_on_buffers(input, shape, output):
execution_engine.invoke("main", res)
log("RESULT: ", res[0])
- # CHECK: RESULT: 4
+ # CHECK: RESULT: 42
test_pooling_generic()