[Mlir-commits] [mlir] 755c050 - [mlir][Linalg] Fix load/store operations generated while lowering to loops when

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 4 17:05:17 PST 2020


Author: MaheshRavishankar
Date: 2020-03-04T17:04:30-08:00
New Revision: 755c050200bad608dffd376929a230cd5d9936d7

URL: https://github.com/llvm/llvm-project/commit/755c050200bad608dffd376929a230cd5d9936d7
DIFF: https://github.com/llvm/llvm-project/commit/755c050200bad608dffd376929a230cd5d9936d7.diff

LOG: [mlir][Linalg] Fix load/store operations generated while lowering to loops when
the output has zero rank.

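While lowering to loops, no indices should be used in the load/store
operations if the buffer is zero-rank. Previously, output buffers were
always indexed through their output indexing maps, even when the buffer
had rank zero; only input loads handled the zero-rank case.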

Differential Revision: https://reviews.llvm.org/D75391

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
    mlir/test/Dialect/Linalg/loops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index 3c1ab7842ac3..5701c37bf95f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -242,21 +242,25 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
     // 1.a. Emit std_load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
       Value input = genericOp.getInput(i);
-      if (!input.getType().cast<ShapedType>().getRank()) {
-        indexedValues[i] = std_load(input);
-      } else {
+      if (input.getType().cast<ShapedType>().getRank()) {
         ValueHandleArray indexing(makeCanonicalAffineApplies(
             b, loc, genericOp.getInputIndexingMap(i), allIvs));
         indexedValues[i] = std_load(input, indexing);
+      } else {
+        indexedValues[i] = std_load(input);
       }
     }
 
     // 1.b. Emit std_load from output views.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-      indexedValues[nInputs + i] =
-          std_load(genericOp.getOutputBuffer(i), indexing);
+      Value output = genericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+        indexedValues[nInputs + i] = std_load(output, indexing);
+      } else {
+        indexedValues[nInputs + i] = std_load(output);
+      }
     }
 
     auto funcOp = genericOp.getFunction();
@@ -267,9 +271,14 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
 
       // 3. Emit std_store.
       for (unsigned i = 0; i < nOutputs; ++i) {
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-        std_store(callOp->getResult(i), genericOp.getOutputBuffer(i), indexing);
+        Value output = genericOp.getOutputBuffer(i);
+        if (output.getType().cast<ShapedType>().getRank()) {
+          ValueHandleArray indexing(makeCanonicalAffineApplies(
+              b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+          std_store(callOp->getResult(i), output, indexing);
+        } else {
+          std_store(callOp->getResult(i), output);
+        }
       }
       return;
     }
@@ -288,10 +297,15 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
     auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
     assert(yieldOp->getNumOperands() == nOutputs);
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-      std_store(map.lookup(yieldOp->getOperand(i)),
-                genericOp.getOutputBuffer(i), indexing);
+      Value output = genericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+        std_store(map.lookup(yieldOp->getOperand(i)),
+                  genericOp.getOutputBuffer(i), indexing);
+      } else {
+        std_store(map.lookup(yieldOp->getOperand(i)), output);
+      }
     }
   }
 };
@@ -348,21 +362,25 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
     // 1.a. Emit std_load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
       Value input = indexedGenericOp.getInput(i);
-      if (!input.getType().cast<ShapedType>().getRank()) {
-        indexedValues[nLoops + i] = std_load(input);
-      } else {
+      if (input.getType().cast<ShapedType>().getRank()) {
         ValueHandleArray indexing(makeCanonicalAffineApplies(
             b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
         indexedValues[nLoops + i] = std_load(input, indexing);
+      } else {
+        indexedValues[nLoops + i] = std_load(input);
       }
     }
 
     // 1.b. Emit std_load from output views.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-      indexedValues[nLoops + nInputs + i] =
-          std_load(indexedGenericOp.getOutputBuffer(i), indexing);
+      Value output = indexedGenericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+        indexedValues[nLoops + nInputs + i] = std_load(output, indexing);
+      } else {
+        indexedValues[nLoops + nInputs + i] = std_load(output);
+      }
     }
 
     if (auto funcOp = indexedGenericOp.getFunction()) {
@@ -372,10 +390,14 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
 
       // 3. Emit std_store.
       for (unsigned i = 0; i < nOutputs; ++i) {
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-        std_store(callOp->getResult(i), indexedGenericOp.getOutputBuffer(i),
-                  indexing);
+        Value output = indexedGenericOp.getOutputBuffer(i);
+        if (output.getType().cast<ShapedType>().getRank()) {
+          ValueHandleArray indexing(makeCanonicalAffineApplies(
+              b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+          std_store(callOp->getResult(i), output, indexing);
+        } else {
+          std_store(callOp->getResult(i), output);
+        }
       }
       return;
     }
@@ -394,10 +416,14 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
     auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
     assert(yieldOp->getNumOperands() == nOutputs);
     for (unsigned i = 0; i < nOutputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
-      std_store(map.lookup(yieldOp->getOperand(i)),
-                indexedGenericOp.getOutputBuffer(i), indexing);
+      Value output = indexedGenericOp.getOutputBuffer(i);
+      if (output.getType().cast<ShapedType>().getRank()) {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+        std_store(map.lookup(yieldOp->getOperand(i)), output, indexing);
+      } else {
+        std_store(map.lookup(yieldOp->getOperand(i)), output);
+      }
     }
   }
 };

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 6aaa1ba37aa8..59487c71eedb 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -411,3 +411,75 @@ func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
 // CHECK:     %[[ij_int:.*]] = index_cast %[[ij]] : index to i32
 // CHECK:     %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
 // CHECK:     store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
+
+#reduce_1D_access = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (0)>
+]
+
+#trait_reduce_1D = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = #reduce_1D_access,
+  iterator_types = ["reduction"],
+  library_call = "some_reduce_external_fn"
+}
+
+func @generic_op_1D_reduce(%arg0: memref<?xf32>, %arg1: memref<f32>)
+{
+  linalg.generic #trait_reduce_1D %arg0, %arg1 {
+    ^bb(%a: f32, %b: f32) :
+      %0 = addf %a, %b : f32
+      linalg.yield %0 : f32
+  } : memref<?xf32>, memref<f32>
+  return
+}
+// CHECK-LABEL: @generic_op_1D_reduce
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECK: loop.for %[[i:.*]] = {{.*}}
+// CHECK:   %[[a:.*]] = load %[[ARG0]][%[[i]]]
+// CHECK:   %[[b:.*]] = load %[[ARG1]][]
+// CHECK:   %[[c:.*]] = addf %[[a]], %[[b]] : f32
+// CHECK:   store %[[c]], %[[ARG1]][]
+
+
+#reduce_init_1D_access = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (0)>,
+  affine_map<(i) -> (0)>
+]
+
+#trait_reduce_init_1D = {
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = #reduce_init_1D_access,
+  iterator_types = ["reduction"],
+  library_call = "some_reduce_external_fn"
+}
+
+func @indexed_generic_op_1D_reduce(%arg0: memref<?xf32>,
+                                   %arg1: memref<f32>,
+                                   %arg2: memref<f32>)
+{
+  linalg.indexed_generic #trait_reduce_init_1D %arg0, %arg1, %arg2 {
+    ^bb(%i : index, %a: f32, %b: f32, %c: f32) :
+      %0 = constant 0 : index
+      %1 = cmpi "eq", %0, %i : index
+      %2 = select %1, %b, %c : f32
+      %3 = addf %a, %2 : f32
+      linalg.yield %3 : f32
+  } : memref<?xf32>, memref<f32>, memref<f32>
+  return
+}
+// CHECK-LABEL: @indexed_generic_op_1D_reduce
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECK: loop.for %[[i:.*]] = {{.*}}
+// CHECK:   %[[a:.*]] = load %[[ARG0]][%[[i]]]
+// CHECK:   %[[b:.*]] = load %[[ARG1]][]
+// CHECK:   %[[c:.*]] = load %[[ARG2]][]
+// CHECK:   %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]]
+// CHECK:   %[[e:.*]] = addf %[[a]], %[[d]]
+// CHECK:   store %[[e]], %[[ARG2]][]


        


More information about the Mlir-commits mailing list