[Mlir-commits] [mlir] 6d6d5db - [MLIR][Linalg] Generate the right type of load/store when lowering max/min pooling ops

Uday Bondhugula llvmlistbot at llvm.org
Sat Jul 4 02:28:31 PDT 2020


Author: Uday Bondhugula
Date: 2020-07-04T14:55:02+05:30
New Revision: 6d6d5db251e88b0c40f7a6951b51b9e4a1812c8c

URL: https://github.com/llvm/llvm-project/commit/6d6d5db251e88b0c40f7a6951b51b9e4a1812c8c
DIFF: https://github.com/llvm/llvm-project/commit/6d6d5db251e88b0c40f7a6951b51b9e4a1812c8c.diff

LOG: [MLIR][Linalg] Generate the right type of load/store when lowering max/min pooling ops

While lowering min/max pooling ops to loops, generate the right kind of
loads and stores (std or affine) by going through the IndexedValueType
template parameter, instead of always generating std loads and stores.

Differential Revision: https://reviews.llvm.org/D83080

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/test/Dialect/Linalg/affine.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index ec57717eaca9..575115c0fbed 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -333,23 +333,28 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) {
 
 template <typename IndexedValueType>
 void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
-  auto indices = getInputAndOutputIndices(allIvs, op);
+  InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
   // Emit scalar form.
-  Value lhs = std_load(op.output(), indices.outputs);
-  Value rhs = std_load(op.input(), indices.inputs);
+  IndexedValueType output(op.output());
+  IndexedValueType input(op.input());
+  Value lhs = output(indices.outputs);
+  Value rhs = input(indices.inputs);
   using edsc::op::sgt;
   Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs);
-  std_store(maxValue, op.output(), indices.outputs);
+  output(indices.outputs) = maxValue;
 }
+
 template <typename IndexedValueType>
 void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
-  auto indices = getInputAndOutputIndices(allIvs, op);
+  InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
   // Emit scalar form.
-  Value lhs = std_load(op.output(), indices.outputs);
-  Value rhs = std_load(op.input(), indices.inputs);
+  IndexedValueType output(op.output());
+  IndexedValueType input(op.input());
+  Value lhs = output(indices.outputs);
+  Value rhs = input(indices.inputs);
   using edsc::op::slt;
   Value minValue = std_select(slt(lhs, rhs), lhs, rhs);
-  std_store(minValue, op.output(), indices.outputs);
+  output(indices.outputs) = minValue;
 }
 template <typename IndexedValueType>
 void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {

diff  --git a/mlir/test/Dialect/Linalg/affine.mlir b/mlir/test/Dialect/Linalg/affine.mlir
index cb2064602c47..13f37d844b8a 100644
--- a/mlir/test/Dialect/Linalg/affine.mlir
+++ b/mlir/test/Dialect/Linalg/affine.mlir
@@ -123,3 +123,27 @@ func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memre
 //       CHECK:       %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
 //       CHECK:       %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
 //       CHECK:       affine.store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32>
+
+// CHECK-LABEL: func @pooling_max_min
+func @pooling_max_min(%arg0: memref<?x?xf32>,
+                      %arg1: memref<?x?xi32>,
+                      %arg2: memref<?x?xf32>) {
+  linalg.pooling_max(%arg0, %arg1, %arg2) { strides = [2, 1] }:
+    memref<?x?xf32>, memref<?x?xi32>, memref<?x?xf32>
+  linalg.pooling_min(%arg0, %arg1, %arg2) { strides = [2, 1] }:
+    memref<?x?xf32>, memref<?x?xi32>, memref<?x?xf32>
+  return
+}
+// This is a basic check to make sure the right load/stores are used. loops.mlir
+// checks for the rest.
+// CHECK:      affine.load
+// CHECK-NEXT: affine.load
+// CHECK-NEXT: cmpf
+// CHECK-NEXT: select
+// CHECK-NEXT: affine.store
+// The min pooling body.
+// CHECK:      affine.load
+// CHECK-NEXT: affine.load
+// CHECK-NEXT: cmpf
+// CHECK-NEXT: select
+// CHECK-NEXT: affine.store


        


More information about the Mlir-commits mailing list