[Mlir-commits] [mlir] 755c050 - [mlir][Linalg] Fix load/store operations generated while lowering to loops when output has zero rank
llvmlistbot at llvm.org
Wed Mar 4 17:05:17 PST 2020
Author: MaheshRavishankar
Date: 2020-03-04T17:04:30-08:00
New Revision: 755c050200bad608dffd376929a230cd5d9936d7
URL: https://github.com/llvm/llvm-project/commit/755c050200bad608dffd376929a230cd5d9936d7
DIFF: https://github.com/llvm/llvm-project/commit/755c050200bad608dffd376929a230cd5d9936d7.diff
LOG: [mlir][Linalg] Fix load/store operations generated while lowering to
loops when output has zero rank.
When lowering to loops, load and store operations on a zero-rank buffer
must be emitted without indices.
Differential Revision: https://reviews.llvm.org/D75391
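
For context, a hedged sketch (not part of the commit) of the IR shape this
fix produces: with a zero-rank output such as memref<f32>, the emitted
load/store takes an empty index list, while ranked buffers keep their
affine-map-derived indices. The SSA names below are illustrative only:

    %0 = load %in[%i] : memref<?xf32>   // ranked input: indexed as before
    %1 = load %out[] : memref<f32>      // zero-rank output: no indices
    %2 = addf %0, %1 : f32
    store %2, %out[] : memref<f32>      // zero-rank store: no indices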
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
mlir/test/Dialect/Linalg/loops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index 3c1ab7842ac3..5701c37bf95f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -242,21 +242,25 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
// 1.a. Emit std_load from input views.
for (unsigned i = 0; i < nInputs; ++i) {
Value input = genericOp.getInput(i);
- if (!input.getType().cast<ShapedType>().getRank()) {
- indexedValues[i] = std_load(input);
- } else {
+ if (input.getType().cast<ShapedType>().getRank()) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getInputIndexingMap(i), allIvs));
indexedValues[i] = std_load(input, indexing);
+ } else {
+ indexedValues[i] = std_load(input);
}
}
// 1.b. Emit std_load from output views.
for (unsigned i = 0; i < nOutputs; ++i) {
- ValueHandleArray indexing(makeCanonicalAffineApplies(
- b, loc, genericOp.getOutputIndexingMap(i), allIvs));
- indexedValues[nInputs + i] =
- std_load(genericOp.getOutputBuffer(i), indexing);
+ Value output = genericOp.getOutputBuffer(i);
+ if (output.getType().cast<ShapedType>().getRank()) {
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+ indexedValues[nInputs + i] = std_load(output, indexing);
+ } else {
+ indexedValues[nInputs + i] = std_load(output);
+ }
}
auto funcOp = genericOp.getFunction();
@@ -267,9 +271,14 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
// 3. Emit std_store.
for (unsigned i = 0; i < nOutputs; ++i) {
- ValueHandleArray indexing(makeCanonicalAffineApplies(
- b, loc, genericOp.getOutputIndexingMap(i), allIvs));
- std_store(callOp->getResult(i), genericOp.getOutputBuffer(i), indexing);
+ Value output = genericOp.getOutputBuffer(i);
+ if (output.getType().cast<ShapedType>().getRank()) {
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+ std_store(callOp->getResult(i), output, indexing);
+ } else {
+ std_store(callOp->getResult(i), output);
+ }
}
return;
}
@@ -288,10 +297,15 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
assert(yieldOp->getNumOperands() == nOutputs);
for (unsigned i = 0; i < nOutputs; ++i) {
- ValueHandleArray indexing(makeCanonicalAffineApplies(
- b, loc, genericOp.getOutputIndexingMap(i), allIvs));
- std_store(map.lookup(yieldOp->getOperand(i)),
- genericOp.getOutputBuffer(i), indexing);
+ Value output = genericOp.getOutputBuffer(i);
+ if (output.getType().cast<ShapedType>().getRank()) {
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+ std_store(map.lookup(yieldOp->getOperand(i)),
+ genericOp.getOutputBuffer(i), indexing);
+ } else {
+ std_store(map.lookup(yieldOp->getOperand(i)), output);
+ }
}
}
};
@@ -348,21 +362,25 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
// 1.a. Emit std_load from input views.
for (unsigned i = 0; i < nInputs; ++i) {
Value input = indexedGenericOp.getInput(i);
- if (!input.getType().cast<ShapedType>().getRank()) {
- indexedValues[nLoops + i] = std_load(input);
- } else {
+ if (input.getType().cast<ShapedType>().getRank()) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
indexedValues[nLoops + i] = std_load(input, indexing);
+ } else {
+ indexedValues[nLoops + i] = std_load(input);
}
}
// 1.b. Emit std_load from output views.
for (unsigned i = 0; i < nOutputs; ++i) {
- ValueHandleArray indexing(makeCanonicalAffineApplies(
- b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
- indexedValues[nLoops + nInputs + i] =
- std_load(indexedGenericOp.getOutputBuffer(i), indexing);
+ Value output = indexedGenericOp.getOutputBuffer(i);
+ if (output.getType().cast<ShapedType>().getRank()) {
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+ indexedValues[nLoops + nInputs + i] = std_load(output, indexing);
+ } else {
+ indexedValues[nLoops + nInputs + i] = std_load(output);
+ }
}
if (auto funcOp = indexedGenericOp.getFunction()) {
@@ -372,10 +390,14 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
// 3. Emit std_store.
for (unsigned i = 0; i < nOutputs; ++i) {
- ValueHandleArray indexing(makeCanonicalAffineApplies(
- b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
- std_store(callOp->getResult(i), indexedGenericOp.getOutputBuffer(i),
- indexing);
+ Value output = indexedGenericOp.getOutputBuffer(i);
+ if (output.getType().cast<ShapedType>().getRank()) {
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+ std_store(callOp->getResult(i), output, indexing);
+ } else {
+ std_store(callOp->getResult(i), output);
+ }
}
return;
}
@@ -394,10 +416,14 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
assert(yieldOp->getNumOperands() == nOutputs);
for (unsigned i = 0; i < nOutputs; ++i) {
- ValueHandleArray indexing(makeCanonicalAffineApplies(
- b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
- std_store(map.lookup(yieldOp->getOperand(i)),
- indexedGenericOp.getOutputBuffer(i), indexing);
+ Value output = indexedGenericOp.getOutputBuffer(i);
+ if (output.getType().cast<ShapedType>().getRank()) {
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
+ std_store(map.lookup(yieldOp->getOperand(i)), output, indexing);
+ } else {
+ std_store(map.lookup(yieldOp->getOperand(i)), output);
+ }
}
}
};
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 6aaa1ba37aa8..59487c71eedb 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -411,3 +411,75 @@ func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
// CHECK: %[[ij_int:.*]] = index_cast %[[ij]] : index to i32
// CHECK: %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
// CHECK: store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
+
+#reduce_1D_access = [
+ affine_map<(i) -> (i)>,
+ affine_map<(i) -> (0)>
+]
+
+#trait_reduce_1D = {
+ args_in = 1,
+ args_out = 1,
+ indexing_maps = #reduce_1D_access,
+ iterator_types = ["reduction"],
+ library_call = "some_reduce_external_fn"
+}
+
+func @generic_op_1D_reduce(%arg0: memref<?xf32>, %arg1: memref<f32>)
+{
+ linalg.generic #trait_reduce_1D %arg0, %arg1 {
+ ^bb(%a: f32, %b: f32) :
+ %0 = addf %a, %b : f32
+ linalg.yield %0 : f32
+ } : memref<?xf32>, memref<f32>
+ return
+}
+// CHECK-LABEL: @generic_op_1D_reduce
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECK: loop.for %[[i:.*]] = {{.*}}
+// CHECK: %[[a:.*]] = load %[[ARG0]][%[[i]]]
+// CHECK: %[[b:.*]] = load %[[ARG1]][]
+// CHECK: %[[c:.*]] = addf %[[a]], %[[b]] : f32
+// CHECK: store %[[c]], %[[ARG1]][]
+
+
+#reduce_init_1D_access = [
+ affine_map<(i) -> (i)>,
+ affine_map<(i) -> (0)>,
+ affine_map<(i) -> (0)>
+]
+
+#trait_reduce_init_1D = {
+ args_in = 2,
+ args_out = 1,
+ indexing_maps = #reduce_init_1D_access,
+ iterator_types = ["reduction"],
+ library_call = "some_reduce_external_fn"
+}
+
+func @indexed_generic_op_1D_reduce(%arg0: memref<?xf32>,
+ %arg1: memref<f32>,
+ %arg2: memref<f32>)
+{
+ linalg.indexed_generic #trait_reduce_init_1D %arg0, %arg1, %arg2 {
+ ^bb(%i : index, %a: f32, %b: f32, %c: f32) :
+ %0 = constant 0 : index
+ %1 = cmpi "eq", %0, %i : index
+ %2 = select %1, %b, %c : f32
+ %3 = addf %a, %2 : f32
+ linalg.yield %3 : f32
+ } : memref<?xf32>, memref<f32>, memref<f32>
+ return
+}
+// CHECK-LABEL: @indexed_generic_op_1D_reduce
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECK: loop.for %[[i:.*]] = {{.*}}
+// CHECK: %[[a:.*]] = load %[[ARG0]][%[[i]]]
+// CHECK: %[[b:.*]] = load %[[ARG1]][]
+// CHECK: %[[c:.*]] = load %[[ARG2]][]
+// CHECK: %[[d:.*]] = select %{{.*}}, %[[b]], %[[c]]
+// CHECK: %[[e:.*]] = addf %[[a]], %[[d]]
+// CHECK: store %[[e]], %[[ARG2]][]
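+
Taken together, the CHECK lines for @indexed_generic_op_1D_reduce correspond
to a loop nest of roughly the following shape (a hedged reconstruction, not
output captured from the test; the bound, constant, and condition names are
illustrative):

    loop.for %i = %c0 to %dim step %c1 {
      %a = load %arg0[%i] : memref<?xf32>   // ranked input: indexed load
      %b = load %arg1[] : memref<f32>       // zero-rank init: no indices
      %c = load %arg2[] : memref<f32>       // zero-rank output: no indices
      %d = select %cmp, %b, %c : f32        // %cmp from cmpi on the IV
      %e = addf %a, %d : f32
      store %e, %arg2[] : memref<f32>       // zero-rank store: no indices
    }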