[Mlir-commits] [mlir] a744c7e - [mlir][linalg] Update OpDSL to use the newly introduced min and max ops.
Tobias Gysi
llvmlistbot at llvm.org
Wed Oct 6 00:05:02 PDT 2021
Author: Tobias Gysi
Date: 2021-10-06T06:45:53Z
New Revision: a744c7e962d85a6c0b2de19eff840755ef5c2a1d
URL: https://github.com/llvm/llvm-project/commit/a744c7e962d85a6c0b2de19eff840755ef5c2a1d
DIFF: https://github.com/llvm/llvm-project/commit/a744c7e962d85a6c0b2de19eff840755ef5c2a1d.diff
LOG: [mlir][linalg] Update OpDSL to use the newly introduced min and max ops.
Implement min and max using the newly introduced std operations instead of relying on compare and select.
Reviewed By: dcaballe
Differential Revision: https://reviews.llvm.org/D111170
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
mlir/test/python/integration/dialects/linalg/opsrun.py
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index af292878d1f6e..69cd9e25e5d94 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -276,18 +276,20 @@ class RegionBuilderHelper {
}
Value applyfn__max(Value lhs, Value rhs) {
+ OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
- return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OGT);
+ return builder.create<MaxFOp>(lhs.getLoc(), lhs, rhs);
if (isInteger(lhs))
- return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::sgt);
+ return builder.create<MaxSIOp>(lhs.getLoc(), lhs, rhs);
llvm_unreachable("unsupported non numeric type");
}
Value applyfn__min(Value lhs, Value rhs) {
+ OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
- return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OLT);
+ return builder.create<MinFOp>(lhs.getLoc(), lhs, rhs);
if (isInteger(lhs))
- return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::slt);
+ return builder.create<MinSIOp>(lhs.getLoc(), lhs, rhs);
llvm_unreachable("unsupported non numeric type");
}
@@ -324,17 +326,6 @@ class RegionBuilderHelper {
MLIRContext *context;
Block &block;
- Value emitCmpFAndSelect(Value lhs, Value rhs, CmpFPredicate predicate) {
- OpBuilder builder = getBuilder();
- Value condition = builder.create<CmpFOp>(lhs.getLoc(), predicate, lhs, rhs);
- return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
- }
- Value emitCmpIAndSelect(Value lhs, Value rhs, CmpIPredicate predicate) {
- OpBuilder builder = getBuilder();
- Value condition = builder.create<CmpIOp>(lhs.getLoc(), predicate, lhs, rhs);
- return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
- }
-
bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index b151a9ba9f39f..4a883e79037b5 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -319,20 +319,16 @@ def _eval_mul(self, lhs: Value, rhs: Value) -> Value:
def _eval_max(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
- ogt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
- return _emit_cmpf_and_select(lhs, rhs, ogt_attr)
+ return std.MaxFOp(lhs.type, lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- sgt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
- return _emit_cmpi_and_select(lhs, rhs, sgt_attr)
+ return std.MaxSIOp(lhs.type, lhs, rhs).result
raise NotImplementedError("Unsupported 'max' operand: {lhs}")
def _eval_min(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
- olt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
- return _emit_cmpf_and_select(lhs, rhs, olt_attr)
+ return std.MinFOp(lhs.type, lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- slt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
- return _emit_cmpi_and_select(lhs, rhs, slt_attr)
+ return std.MinSIOp(lhs.type, lhs, rhs).result
raise NotImplementedError("Unsupported 'min' operand: {lhs}")
@@ -413,13 +409,3 @@ def _get_floating_point_width(t: Type) -> int:
if BF16Type.isinstance(t):
return 16
raise NotImplementedError(f"Unhandled floating point type switch {t}")
-
-
-def _emit_cmpf_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
- cond = std.CmpFOp(IntegerType.get_signless(1), pred, lhs, rhs).result
- return std.SelectOp(lhs.type, cond, lhs, rhs).result
-
-
-def _emit_cmpi_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
- cond = std.CmpIOp(IntegerType.get_signless(1), pred, lhs, rhs).result
- return std.SelectOp(lhs.type, cond, lhs, rhs).result
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 3e934d42012c4..89fd83e585eef 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -38,8 +38,7 @@ func @generalize_pooling_nhwc_max_f32(%input : tensor<1x4x16x1xf32>, %shape: ten
// CHECK-LABEL: @generalize_pooling_nhwc_max_f32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
-// CHECK-NEXT: %[[COND:.+]] = cmpf ogt, %[[OUT_ARG]], %[[IN_ARG]] : f32
-// CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT: %[[MAX:.+]] = maxf %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT: linalg.yield %[[MAX]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>
@@ -53,8 +52,7 @@ func @generalize_pooling_nhwc_max_i32(%input : tensor<1x4x16x1xi32>, %shape: ten
// CHECK-LABEL: @generalize_pooling_nhwc_max_i32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
-// CHECK-NEXT: %[[COND:.+]] = cmpi sgt, %[[OUT_ARG]], %[[IN_ARG]] : i32
-// CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT: %[[MAX:.+]] = maxsi %[[OUT_ARG]], %[[IN_ARG]] : i32
// CHECK-NEXT: linalg.yield %[[MAX]] : i32
// CHECK-NEXT: -> tensor<1x2x4x1xi32>
@@ -68,9 +66,8 @@ func @generalize_pooling_nhwc_min_f32(%input : tensor<1x4x16x1xf32>, %shape: ten
// CHECK-LABEL: @generalize_pooling_nhwc_min_f32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
-// CHECK-NEXT: %[[COND:.+]] = cmpf olt, %[[OUT_ARG]], %[[IN_ARG]] : f32
-// CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
-// CHECK-NEXT: linalg.yield %[[MAX]] : f32
+// CHECK-NEXT: %[[MIN:.+]] = minf %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT: linalg.yield %[[MIN]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>
// -----
@@ -83,9 +80,8 @@ func @generalize_pooling_nhwc_min_i32(%input : tensor<1x4x16x1xi32>, %shape: ten
// CHECK-LABEL: @generalize_pooling_nhwc_min_i32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
-// CHECK-NEXT: %[[COND:.+]] = cmpi slt, %[[OUT_ARG]], %[[IN_ARG]] : i32
-// CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
-// CHECK-NEXT: linalg.yield %[[MAX]] : i32
+// CHECK-NEXT: %[[MIN:.+]] = minsi %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT: linalg.yield %[[MIN]] : i32
// CHECK-NEXT: -> tensor<1x2x4x1xi32>
// -----
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index ed33644859012..16a82f63dbc83 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -242,8 +242,7 @@ def test_f32i32_conv(input, filter, init_result):
# CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
# CHECK-NEXT: %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
- # CHECK-NEXT: %[[COND:.+]] = cmpi sgt, %[[OUT]], %[[IN_CAST:.+]] : i32
- # CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT]], %[[IN_CAST:.+]] : i32
+ # CHECK-NEXT: %[[MAX:.+]] = maxsi %[[OUT]], %[[IN_CAST:.+]] : i32
# CHECK-NEXT: linalg.yield %[[MAX]] : i32
# CHECK-NEXT: -> tensor<2x4xi32>
@builtin.FuncOp.from_py_func(
@@ -258,8 +257,7 @@ def test_f32i32_max_pooling(input, shape, init_result):
# CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
# CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32)
- # CHECK-NEXT: %[[COND:.+]] = cmpf ogt, %[[OUT]], %[[IN:.+]] : f32
- # CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT]], %[[IN:.+]] : f32
+ # CHECK-NEXT: %[[MAX:.+]] = maxf %[[OUT]], %[[IN:.+]] : f32
# CHECK-NEXT: linalg.yield %[[MAX]] : f32
# CHECK-NEXT: -> tensor<2x4xf32>
@builtin.FuncOp.from_py_func(
@@ -270,7 +268,7 @@ def test_f32f32_max_pooling(input, shape, init_result):
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
# CHECK-LABEL: @test_f32i32_min_pooling
- # CHECK: = cmpi slt,
+ # CHECK: = minsi
@builtin.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
RankedTensorType.get((2, 4), i32))
@@ -279,7 +277,7 @@ def test_f32i32_min_pooling(input, shape, init_result):
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
# CHECK-LABEL: @test_f32f32_min_pooling
- # CHECK: = cmpf olt,
+ # CHECK: = minf
@builtin.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
RankedTensorType.get((2, 4), f32))
diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index 04ee6c8dc5ee8..5491193fa992b 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -118,7 +118,7 @@ def log(*args):
def transform(module, boilerplate):
import mlir.conversions
- import mlir.dialects.linalg.passes
+ import mlir.all_passes_registration
import mlir.transforms
# TODO: Allow cloning functions from one module to another.
@@ -128,8 +128,8 @@ def transform(module, boilerplate):
boilerplate)
pm = PassManager.parse(
"builtin.func(convert-linalg-to-loops, lower-affine, " +
- "convert-scf-to-std), convert-vector-to-llvm," +
- "convert-memref-to-llvm,convert-std-to-llvm," +
+ "convert-scf-to-std, std-expand), convert-vector-to-llvm," +
+ "convert-memref-to-llvm, convert-std-to-llvm," +
"reconcile-unrealized-casts")
pm.run(mod)
return mod
More information about the Mlir-commits
mailing list