[Mlir-commits] [mlir] 28e0449 - [mlir][Linalg] Allow specifying zero-rank shaped type operands to linalg.indexed_generic ops.
Hanhan Wang
llvmlistbot at llvm.org
Wed Feb 19 16:26:58 PST 2020
Author: Hanhan Wang
Date: 2020-02-19T19:24:27-05:00
New Revision: 28e0449ec690cc828fb0d94ecee30c8680e0a3d7
URL: https://github.com/llvm/llvm-project/commit/28e0449ec690cc828fb0d94ecee30c8680e0a3d7
DIFF: https://github.com/llvm/llvm-project/commit/28e0449ec690cc828fb0d94ecee30c8680e0a3d7.diff
LOG: [mlir][Linalg] Allow specifying zero-rank shaped type operands to linalg.indexed_generic ops.
Patch D74638 allowed linalg.generic ops to use zero-rank shaped-type operands;
this patch extends the same support to linalg.indexed_generic ops.
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
mlir/test/Dialect/Linalg/loops.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index 4d3ee89908ff..4ebd51c0d9ba 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -344,10 +344,14 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
// 1.a. Emit std_load from input views.
for (unsigned i = 0; i < nInputs; ++i) {
- ValueHandleArray indexing(makeCanonicalAffineApplies(
- b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
- indexedValues[nLoops + i] =
- std_load(indexedGenericOp.getInput(i), indexing);
+ Value input = indexedGenericOp.getInput(i);
+ if (!input.getType().cast<ShapedType>().getRank()) {
+ indexedValues[nLoops + i] = std_load(input);
+ } else {
+ ValueHandleArray indexing(makeCanonicalAffineApplies(
+ b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
+ indexedValues[nLoops + i] = std_load(input, indexing);
+ }
}
// 1.b. Emit std_load from output views.
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 3d8d7f52445e..6aaa1ba37aa8 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -359,10 +359,9 @@ func @indexed_generic_region(
// -----
-
#broadcast_access = [
affine_map<(i, j) -> (0)>,
- affine_map<(i,j) -> (i,j)>
+ affine_map<(i, j) -> (i, j)>
]
#trait_broadcast = {
@@ -373,10 +372,10 @@ func @indexed_generic_region(
library_call = "some_broadcast_external_fn"
}
-func @generic_op_zero_rank(%arg0 : memref<f32>, %arg1: memref<3x4xf32>)
+func @generic_op_zero_rank(%arg0: memref<f32>, %arg1: memref<3x4xf32>)
{
linalg.generic #trait_broadcast %arg0, %arg1 {
- ^bb(%a: f32, %b : f32) :
+ ^bb(%a: f32, %b: f32) :
linalg.yield %a : f32
} : memref<f32>, memref<3x4xf32>
return
@@ -389,3 +388,26 @@ func @generic_op_zero_rank(%arg0 : memref<f32>, %arg1: memref<3x4xf32>)
// CHECK: loop.for %[[j:.*]] = {{.*}}
// CHECK: %[[a:.*]] = load %[[ARG0]][]
// CHECK: store %[[a]], %[[ARG1]][%[[i]], %[[j]]]
+
+func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
+{
+ linalg.indexed_generic #trait_broadcast %arg0, %arg1 {
+ ^bb(%i: index, %j: index, %a: i32, %b: i32) :
+ %ij = addi %i, %j : index
+ %ij_int = index_cast %ij : index to i32
+ %result = addi %a, %ij_int : i32
+ linalg.yield %result : i32
+ } : memref<i32>, memref<3x4xi32>
+ return
+}
+
+// CHECK-LABEL: @indexed_generic_op_zero_rank
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<i32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xi32>
+// CHECK: loop.for %[[i:.*]] = {{.*}}
+// CHECK: loop.for %[[j:.*]] = {{.*}}
+// CHECK: %[[a:.*]] = load %[[ARG0]][]
+// CHECK: %[[ij:.*]] = addi %[[i]], %[[j]] : index
+// CHECK: %[[ij_int:.*]] = index_cast %[[ij]] : index to i32
+// CHECK: %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
+// CHECK: store %[[result]], %[[ARG1]][%[[i]], %[[j]]]
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index ec28510d5060..450422411b22 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -347,7 +347,7 @@ func @indexed_generic_with_tensor_input_and_output(
#broadcast_access = [
affine_map<(i, j) -> (0)>,
- affine_map<(i,j) -> (i,j)>
+ affine_map<(i, j) -> (i, j)>
]
#trait_broadcast = {
@@ -358,7 +358,7 @@ func @indexed_generic_with_tensor_input_and_output(
library_call = "some_broadcast_external_fn"
}
-func @generic_op_zero_rank(%arg0 : tensor<f32>) -> (tensor<3x4xf32>)
+func @generic_op_zero_rank(%arg0: tensor<f32>) -> (tensor<3x4xf32>)
{
%0 = linalg.generic #trait_broadcast %arg0 {
^bb(%a: f32) :
@@ -367,6 +367,15 @@ func @generic_op_zero_rank(%arg0 : tensor<f32>) -> (tensor<3x4xf32>)
return %0 : tensor<3x4xf32>
}
+func @indexed_generic_op_zero_rank(%arg0: tensor<f32>) -> (tensor<3x4xf32>)
+{
+ %0 = linalg.indexed_generic #trait_broadcast %arg0 {
+ ^bb(%i: index, %j: index, %a: f32) : // %i, %j: indices of the (i, j) iteration space; %a: scalar from the zero-rank input
+ linalg.yield %a : f32
+ } : tensor<f32> -> tensor<3x4xf32>
+ return %0 : tensor<3x4xf32>
+}
+
// -----
// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
More information about the Mlir-commits
mailing list