[Mlir-commits] [mlir] 28e0449 - [mlir][Linalg] Allow specifying zero-rank shaped type operands to linalg.indexed_generic ops.

Hanhan Wang llvmlistbot at llvm.org
Wed Feb 19 16:26:58 PST 2020


Author: Hanhan Wang
Date: 2020-02-19T19:24:27-05:00
New Revision: 28e0449ec690cc828fb0d94ecee30c8680e0a3d7

URL: https://github.com/llvm/llvm-project/commit/28e0449ec690cc828fb0d94ecee30c8680e0a3d7
DIFF: https://github.com/llvm/llvm-project/commit/28e0449ec690cc828fb0d94ecee30c8680e0a3d7.diff

LOG: [mlir][Linalg] Allow specifying zero-rank shaped type operands to linalg.indexed_generic ops.

Patch D74638 allowed linalg.generic ops to use zero-rank shaped type operands;
this patch applies the same change to linalg.indexed_generic ops.

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
    mlir/test/Dialect/Linalg/loops.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index 4d3ee89908ff..4ebd51c0d9ba 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -344,10 +344,14 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
 
     // 1.a. Emit std_load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
-      indexedValues[nLoops + i] =
-          std_load(indexedGenericOp.getInput(i), indexing);
+      Value input = indexedGenericOp.getInput(i);
+      if (!input.getType().cast<ShapedType>().getRank()) {
+        indexedValues[nLoops + i] = std_load(input);
+      } else {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs));
+        indexedValues[nLoops + i] = std_load(input, indexing);
+      }
     }
 
     // 1.b. Emit std_load from output views.

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 3d8d7f52445e..6aaa1ba37aa8 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -359,10 +359,9 @@ func @indexed_generic_region(
 
 // -----
 
-
 #broadcast_access = [
   affine_map<(i, j) -> (0)>,
-  affine_map<(i,j) -> (i,j)>
+  affine_map<(i, j) -> (i, j)>
 ]
 
 #trait_broadcast = {
@@ -373,10 +372,10 @@ func @indexed_generic_region(
   library_call = "some_broadcast_external_fn"
 }
 
-func @generic_op_zero_rank(%arg0 : memref<f32>, %arg1:  memref<3x4xf32>)
+func @generic_op_zero_rank(%arg0: memref<f32>, %arg1: memref<3x4xf32>)
 {
   linalg.generic #trait_broadcast %arg0, %arg1 {
-    ^bb(%a: f32, %b : f32) :
+    ^bb(%a: f32, %b: f32) :
       linalg.yield %a : f32
   } : memref<f32>, memref<3x4xf32>
   return
@@ -389,3 +388,26 @@ func @generic_op_zero_rank(%arg0 : memref<f32>, %arg1:  memref<3x4xf32>)
 // CHECK:   loop.for %[[j:.*]] = {{.*}}
 // CHECK:     %[[a:.*]] = load %[[ARG0]][]
 // CHECK:     store %[[a]], %[[ARG1]][%[[i]], %[[j]]]
+
+func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
+{
+  linalg.indexed_generic #trait_broadcast %arg0, %arg1 {
+    ^bb(%i: index, %j: index, %a: i32, %b: i32) :
+      %ij = addi %i, %j : index
+      %ij_int = index_cast %ij : index to i32
+      %result = addi %a, %ij_int : i32
+      linalg.yield %result : i32
+  } : memref<i32>, memref<3x4xi32>
+  return
+}
+
+// CHECK-LABEL: @indexed_generic_op_zero_rank
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<i32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xi32>
+// CHECK: loop.for %[[i:.*]] = {{.*}}
+// CHECK:   loop.for %[[j:.*]] = {{.*}}
+// CHECK:     %[[a:.*]] = load %[[ARG0]][
+// CHECK:     %[[ij:.*]] = addi %[[i]], %[[j]] : index
+// CHECK:     %[[ij_int:.*]] = index_cast %[[ij]] : index to i32
+// CHECK:     %[[result:.*]] = addi %[[a]], %[[ij_int]] : i32
+// CHECK:     store %[[result]], %[[ARG1]][%[[i]], %[[j]]]

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index ec28510d5060..450422411b22 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -347,7 +347,7 @@ func @indexed_generic_with_tensor_input_and_output(
 
 #broadcast_access = [
   affine_map<(i, j) -> (0)>,
-  affine_map<(i,j) -> (i,j)>
+  affine_map<(i, j) -> (i, j)>
 ]
 
 #trait_broadcast = {
@@ -358,7 +358,7 @@ func @indexed_generic_with_tensor_input_and_output(
   library_call = "some_broadcast_external_fn"
 }
 
-func @generic_op_zero_rank(%arg0 : tensor<f32>) ->  (tensor<3x4xf32>)
+func @generic_op_zero_rank(%arg0: tensor<f32>) -> (tensor<3x4xf32>)
 {
   %0 = linalg.generic #trait_broadcast %arg0 {
     ^bb(%a: f32) :
@@ -367,6 +367,15 @@ func @generic_op_zero_rank(%arg0 : tensor<f32>) ->  (tensor<3x4xf32>)
   return %0 : tensor<3x4xf32>
 }
 
+func @indexed_generic_op_zero_rank(%arg0: tensor<f32>) -> (tensor<3x4xf32>)
+{
+  %0 = linalg.indexed_generic #trait_broadcast %arg0 {
+    ^bb(%i: index, %j: index, %a: f32) :
+      linalg.yield %a : f32
+  } : tensor<f32> -> tensor<3x4xf32>
+  return %0 : tensor<3x4xf32>
+}
+
 // -----
 
 // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>


        


More information about the Mlir-commits mailing list