[Mlir-commits] [mlir] a8355b5 - [mlir][Linalg] Allow specifying zero-rank shaped type operands to linalg.generic ops.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 18 13:23:53 PST 2020


Author: MaheshRavishankar
Date: 2020-02-18T13:23:28-08:00
New Revision: a8355b5c0f67d560ed8ec28133c58442bd5b93be

URL: https://github.com/llvm/llvm-project/commit/a8355b5c0f67d560ed8ec28133c58442bd5b93be
DIFF: https://github.com/llvm/llvm-project/commit/a8355b5c0f67d560ed8ec28133c58442bd5b93be.diff

LOG: [mlir][Linalg] Allow specifying zero-rank shaped type operands to linalg.generic ops.

Fix a bug where passing a zero-rank shaped-type operand to a
linalg.generic op hit an unrelated assert. As a consequence, lowering
such an operation to loops was also unsupported. Add roundtrip and
lower-to-loops tests for zero-rank shaped-type operands, along with the
fixes needed to make these tests pass.

Differential Revision: https://reviews.llvm.org/D74638

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
    mlir/lib/IR/AffineMap.cpp
    mlir/test/Dialect/Linalg/loops.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9adc20c4a79e..a2fe01847edb 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -361,11 +361,10 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
       if (!cst || cst.getValue() != 0)
         return op.emitOpError("expected indexing_map #")
                << idx << " to be 0 to match 0-D view: " << view;
-    }
-
-    if (m.getNumResults() != view.getRank())
+    } else if (m.getNumResults() != view.getRank()) {
       return op.emitOpError("expected indexing_map #")
              << idx << " results to match view rank: " << view;
+    }
   }
 
   auto concatMap = concatAffineMaps(indexingMaps);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index a160ccd1e5c6..4d3ee89908ff 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -238,9 +238,14 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
 
     // 1.a. Emit std_load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
-      ValueHandleArray indexing(makeCanonicalAffineApplies(
-          b, loc, genericOp.getInputIndexingMap(i), allIvs));
-      indexedValues[i] = std_load(genericOp.getInput(i), indexing);
+      Value input = genericOp.getInput(i);
+      if (!input.getType().cast<ShapedType>().getRank()) {
+        indexedValues[i] = std_load(input);
+      } else {
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, genericOp.getInputIndexingMap(i), allIvs));
+        indexedValues[i] = std_load(input, indexing);
+      }
     }
 
     // 1.b. Emit std_load from output views.

diff  --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 53e69498ed7f..bbc88e9156b6 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -351,12 +351,12 @@ AffineMap mlir::inversePermutation(AffineMap map) {
 AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
   unsigned numResults = 0;
   for (auto m : maps)
-    numResults += m ? m.getNumResults() : 0;
+    numResults += (m && !m.isSingleConstant()) ? m.getNumResults() : 0;
   unsigned numDims = 0;
   SmallVector<AffineExpr, 8> results;
   results.reserve(numResults);
   for (auto m : maps) {
-    if (!m)
+    if (!m || m.isSingleConstant())
       continue;
     assert(m.getNumSymbols() == 0 && "expected map without symbols");
     results.append(m.getResults().begin(), m.getResults().end());

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 260f602e4ede..3d8d7f52445e 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -356,3 +356,36 @@ func @indexed_generic_region(
 // CHECK:       %[[result_2:.*]] = addf %[[c]], %[[ijk_float]] : f32
 // CHECK:       store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]]
 // CHECK:       store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]]
+
+// -----
+
+
+#broadcast_access = [
+  affine_map<(i, j) -> (0)>,
+  affine_map<(i,j) -> (i,j)>
+]
+
+#trait_broadcast = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = #broadcast_access,
+  iterator_types = ["parallel", "parallel"],
+  library_call = "some_broadcast_external_fn"
+}
+
+func @generic_op_zero_rank(%arg0 : memref<f32>, %arg1:  memref<3x4xf32>)
+{
+  linalg.generic #trait_broadcast %arg0, %arg1 {
+    ^bb(%a: f32, %b : f32) :
+      linalg.yield %a : f32
+  } : memref<f32>, memref<3x4xf32>
+  return
+}
+
+// CHECK-LABEL: @generic_op_zero_rank
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<f32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xf32>
+// CHECK: loop.for %[[i:.*]] = {{.*}}
+// CHECK:   loop.for %[[j:.*]] = {{.*}}
+// CHECK:     %[[a:.*]] = load %[[ARG0]][]
+// CHECK:     store %[[a]], %[[ARG1]][%[[i]], %[[j]]]

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 4b81dce48c8e..ec28510d5060 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -345,6 +345,30 @@ func @indexed_generic_with_tensor_input_and_output(
 
 // -----
 
+#broadcast_access = [
+  affine_map<(i, j) -> (0)>,
+  affine_map<(i,j) -> (i,j)>
+]
+
+#trait_broadcast = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = #broadcast_access,
+  iterator_types = ["parallel", "parallel"],
+  library_call = "some_broadcast_external_fn"
+}
+
+func @generic_op_zero_rank(%arg0 : tensor<f32>) ->  (tensor<3x4xf32>)
+{
+  %0 = linalg.generic #trait_broadcast %arg0 {
+    ^bb(%a: f32) :
+      linalg.yield %a : f32
+  } : tensor<f32> -> tensor<3x4xf32>
+  return %0 : tensor<3x4xf32>
+}
+
+// -----
+
 // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
 


        


More information about the Mlir-commits mailing list