[Mlir-commits] [mlir] a744c7e - [mlir][linalg] Update OpDSL to use the newly introduced min and max ops.

Tobias Gysi <llvmlistbot at llvm.org>
Wed Oct 6 00:05:02 PDT 2021


Author: Tobias Gysi
Date: 2021-10-06T06:45:53Z
New Revision: a744c7e962d85a6c0b2de19eff840755ef5c2a1d

URL: https://github.com/llvm/llvm-project/commit/a744c7e962d85a6c0b2de19eff840755ef5c2a1d
DIFF: https://github.com/llvm/llvm-project/commit/a744c7e962d85a6c0b2de19eff840755ef5c2a1d.diff

LOG: [mlir][linalg] Update OpDSL to use the newly introduced min and max ops.

Implement min and max using the newly introduced std min and max operations instead of relying on a compare-and-select idiom.
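
As a rough illustration (not part of the patch), the emitter now builds a single op where it previously built a compare-and-select pair. The following standalone sketch assumes the MLIR Python bindings of this vintage; the driver code and the function name `fmax` are hypothetical:

    from mlir.ir import Context, InsertionPoint, Location, Module, F32Type
    from mlir.dialects import builtin, std

    with Context(), Location.unknown():
      module = Module.create()
      f32 = F32Type.get()
      with InsertionPoint(module.body):
        # Previously the OpDSL emitter produced cmpf "ogt" + select;
        # it now emits the dedicated op directly.
        @builtin.FuncOp.from_py_func(f32, f32)
        def fmax(lhs, rhs):
          return std.MaxFOp(lhs.type, lhs, rhs).result
      print(module)  # the function body holds a single maxf, not cmpf + select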

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D111170

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
    mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
    mlir/test/python/integration/dialects/linalg/opsrun.py

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index af292878d1f6e..69cd9e25e5d94 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -276,18 +276,20 @@ class RegionBuilderHelper {
   }
 
   Value applyfn__max(Value lhs, Value rhs) {
+    OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))
-      return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OGT);
+      return builder.create<MaxFOp>(lhs.getLoc(), lhs, rhs);
     if (isInteger(lhs))
-      return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::sgt);
+      return builder.create<MaxSIOp>(lhs.getLoc(), lhs, rhs);
     llvm_unreachable("unsupported non numeric type");
   }
 
   Value applyfn__min(Value lhs, Value rhs) {
+    OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))
-      return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OLT);
+      return builder.create<MinFOp>(lhs.getLoc(), lhs, rhs);
     if (isInteger(lhs))
-      return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::slt);
+      return builder.create<MinSIOp>(lhs.getLoc(), lhs, rhs);
     llvm_unreachable("unsupported non numeric type");
   }
 
@@ -324,17 +326,6 @@ class RegionBuilderHelper {
   MLIRContext *context;
   Block &block;
 
-  Value emitCmpFAndSelect(Value lhs, Value rhs, CmpFPredicate predicate) {
-    OpBuilder builder = getBuilder();
-    Value condition = builder.create<CmpFOp>(lhs.getLoc(), predicate, lhs, rhs);
-    return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
-  }
-  Value emitCmpIAndSelect(Value lhs, Value rhs, CmpIPredicate predicate) {
-    OpBuilder builder = getBuilder();
-    Value condition = builder.create<CmpIOp>(lhs.getLoc(), predicate, lhs, rhs);
-    return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
-  }
-
   bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
   bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
 

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index b151a9ba9f39f..4a883e79037b5 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -319,20 +319,16 @@ def _eval_mul(self, lhs: Value, rhs: Value) -> Value:
 
   def _eval_max(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
-      ogt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
-      return _emit_cmpf_and_select(lhs, rhs, ogt_attr)
+      return std.MaxFOp(lhs.type, lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      sgt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
-      return _emit_cmpi_and_select(lhs, rhs, sgt_attr)
+      return std.MaxSIOp(lhs.type, lhs, rhs).result
     raise NotImplementedError("Unsupported 'max' operand: {lhs}")
 
   def _eval_min(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
-      olt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
-      return _emit_cmpf_and_select(lhs, rhs, olt_attr)
+      return std.MinFOp(lhs.type, lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      slt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
-      return _emit_cmpi_and_select(lhs, rhs, slt_attr)
+      return std.MinSIOp(lhs.type, lhs, rhs).result
     raise NotImplementedError("Unsupported 'min' operand: {lhs}")
 
 
@@ -413,13 +409,3 @@ def _get_floating_point_width(t: Type) -> int:
   if BF16Type.isinstance(t):
     return 16
   raise NotImplementedError(f"Unhandled floating point type switch {t}")
-
-
-def _emit_cmpf_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
-  cond = std.CmpFOp(IntegerType.get_signless(1), pred, lhs, rhs).result
-  return std.SelectOp(lhs.type, cond, lhs, rhs).result
-
-
-def _emit_cmpi_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
-  cond = std.CmpIOp(IntegerType.get_signless(1), pred, lhs, rhs).result
-  return std.SelectOp(lhs.type, cond, lhs, rhs).result

diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 3e934d42012c4..89fd83e585eef 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -38,8 +38,7 @@ func @generalize_pooling_nhwc_max_f32(%input : tensor<1x4x16x1xf32>, %shape: ten
 
 // CHECK-LABEL: @generalize_pooling_nhwc_max_f32
 // CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
-// CHECK-NEXT:   %[[COND:.+]] = cmpf ogt, %[[OUT_ARG]], %[[IN_ARG]] : f32
-// CHECK-NEXT:   %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT:   %[[MAX:.+]] = maxf %[[OUT_ARG]], %[[IN_ARG]] : f32
 // CHECK-NEXT:   linalg.yield %[[MAX]] : f32
 // CHECK-NEXT: -> tensor<1x2x4x1xf32>
 
@@ -53,8 +52,7 @@ func @generalize_pooling_nhwc_max_i32(%input : tensor<1x4x16x1xi32>, %shape: ten
 
 // CHECK-LABEL: @generalize_pooling_nhwc_max_i32
 // CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
-// CHECK-NEXT:   %[[COND:.+]] = cmpi sgt, %[[OUT_ARG]], %[[IN_ARG]] : i32
-// CHECK-NEXT:   %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT:   %[[MAX:.+]] = maxsi %[[OUT_ARG]], %[[IN_ARG]] : i32
 // CHECK-NEXT:   linalg.yield %[[MAX]] : i32
 // CHECK-NEXT: -> tensor<1x2x4x1xi32>
 
@@ -68,9 +66,8 @@ func @generalize_pooling_nhwc_min_f32(%input : tensor<1x4x16x1xf32>, %shape: ten
 
 // CHECK-LABEL: @generalize_pooling_nhwc_min_f32
 // CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
-// CHECK-NEXT:   %[[COND:.+]] = cmpf olt, %[[OUT_ARG]], %[[IN_ARG]] : f32
-// CHECK-NEXT:   %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
-// CHECK-NEXT:   linalg.yield %[[MAX]] : f32
+// CHECK-NEXT:   %[[MIN:.+]] = minf %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT:   linalg.yield %[[MIN]] : f32
 // CHECK-NEXT: -> tensor<1x2x4x1xf32>
 
 // -----
@@ -83,9 +80,8 @@ func @generalize_pooling_nhwc_min_i32(%input : tensor<1x4x16x1xi32>, %shape: ten
 
 // CHECK-LABEL: @generalize_pooling_nhwc_min_i32
 // CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
-// CHECK-NEXT:   %[[COND:.+]] = cmpi slt, %[[OUT_ARG]], %[[IN_ARG]] : i32
-// CHECK-NEXT:   %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
-// CHECK-NEXT:   linalg.yield %[[MAX]] : i32
+// CHECK-NEXT:   %[[MIN:.+]] = minsi %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT:   linalg.yield %[[MIN]] : i32
 // CHECK-NEXT: -> tensor<1x2x4x1xi32>
 
 // -----

diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index ed33644859012..16a82f63dbc83 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -242,8 +242,7 @@ def test_f32i32_conv(input, filter, init_result):
     # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
     # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
     # CHECK-NEXT:   %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
-    # CHECK-NEXT:   %[[COND:.+]] = cmpi sgt, %[[OUT]], %[[IN_CAST:.+]] : i32
-    # CHECK-NEXT:   %[[MAX:.+]] = select %[[COND]], %[[OUT]], %[[IN_CAST:.+]] : i32
+    # CHECK-NEXT:   %[[MAX:.+]] = maxsi %[[OUT]], %[[IN_CAST:.+]] : i32
     # CHECK-NEXT:   linalg.yield %[[MAX]] : i32
     # CHECK-NEXT: -> tensor<2x4xi32>
     @builtin.FuncOp.from_py_func(
@@ -258,8 +257,7 @@ def test_f32i32_max_pooling(input, shape, init_result):
     # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
     # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
     # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32)
-    # CHECK-NEXT:   %[[COND:.+]] = cmpf ogt, %[[OUT]], %[[IN:.+]] : f32
-    # CHECK-NEXT:   %[[MAX:.+]] = select %[[COND]], %[[OUT]], %[[IN:.+]] : f32
+    # CHECK-NEXT:   %[[MAX:.+]] = maxf %[[OUT]], %[[IN:.+]] : f32
     # CHECK-NEXT:   linalg.yield %[[MAX]] : f32
     # CHECK-NEXT: -> tensor<2x4xf32>
     @builtin.FuncOp.from_py_func(
@@ -270,7 +268,7 @@ def test_f32f32_max_pooling(input, shape, init_result):
           input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
 
     # CHECK-LABEL: @test_f32i32_min_pooling
-    # CHECK:   = cmpi slt,
+    # CHECK:   = minsi
     @builtin.FuncOp.from_py_func(
         RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
         RankedTensorType.get((2, 4), i32))
@@ -279,7 +277,7 @@ def test_f32i32_min_pooling(input, shape, init_result):
           input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
 
     # CHECK-LABEL: @test_f32f32_min_pooling
-    # CHECK:   = cmpf olt,
+    # CHECK:   = minf
     @builtin.FuncOp.from_py_func(
         RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
         RankedTensorType.get((2, 4), f32))

diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index 04ee6c8dc5ee8..5491193fa992b 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -118,7 +118,7 @@ def log(*args):
 
 def transform(module, boilerplate):
   import mlir.conversions
-  import mlir.dialects.linalg.passes
+  import mlir.all_passes_registration
   import mlir.transforms
 
   # TODO: Allow cloning functions from one module to another.
@@ -128,8 +128,8 @@ def transform(module, boilerplate):
       boilerplate)
   pm = PassManager.parse(
       "builtin.func(convert-linalg-to-loops, lower-affine, " +
-      "convert-scf-to-std), convert-vector-to-llvm," +
-      "convert-memref-to-llvm,convert-std-to-llvm," +
+      "convert-scf-to-std, std-expand), convert-vector-to-llvm," +
+      "convert-memref-to-llvm, convert-std-to-llvm," +
       "reconcile-unrealized-casts")
   pm.run(mod)
   return mod
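
For context on the pipeline change above: the min/max ops emitted by the updated OpDSL do not appear to have a direct convert-std-to-llvm lowering (at least the integer variants), so std-expand is added to rewrite them, presumably back into compare-and-select form, before the final conversion runs. A standalone sketch of the updated pipeline, reusing only imports and the pass string that appear in this patch:

    from mlir.ir import Context
    from mlir.passmanager import PassManager
    import mlir.conversions              # registers the conversion passes
    import mlir.all_passes_registration  # registers std-expand, among others
    import mlir.transforms

    with Context():
      pm = PassManager.parse(
          "builtin.func(convert-linalg-to-loops, lower-affine, "
          "convert-scf-to-std, std-expand), convert-vector-to-llvm, "
          "convert-memref-to-llvm, convert-std-to-llvm, "
          "reconcile-unrealized-casts")
      # pm.run(some_module) would then lower a linalg module to the LLVM
      # dialect (some_module is a hypothetical mlir.ir.Module).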
