[Mlir-commits] [mlir] 47ec870 - [mlir][Linalg] Revisit 0-D abstraction

Nicolas Vasilache llvmlistbot at llvm.org
Tue Mar 10 12:19:48 PDT 2020


Author: Nicolas Vasilache
Date: 2020-03-10T15:14:09-04:00
New Revision: 47ec8702cbc6f607b2e5cc25270a560eb9e02710

URL: https://github.com/llvm/llvm-project/commit/47ec8702cbc6f607b2e5cc25270a560eb9e02710
DIFF: https://github.com/llvm/llvm-project/commit/47ec8702cbc6f607b2e5cc25270a560eb9e02710.diff

LOG: [mlir][Linalg] Revisit 0-D abstraction

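This revision uses the zero-result AffineMap (e.g. `(d0, d1) -> ()`) to
specify the 0-D edge case, instead of a null `AffineMap()` or a `(0)`
constant result. This allows removing a number of special cases (such as
the rank-0 branches in the Linalg verifier and loop lowering) that ended
up impacting users of Linalg.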

Differential Revision: https://reviews.llvm.org/D75831

Added: 
    

Modified: 
    mlir/docs/Dialects/Affine.md
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
    mlir/lib/IR/AffineMap.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/loops.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Dialects/Affine.md b/mlir/docs/Dialects/Affine.md
index 245ba33fed6c..4a7d5c30eda7 100644
--- a/mlir/docs/Dialects/Affine.md
+++ b/mlir/docs/Dialects/Affine.md
@@ -91,7 +91,8 @@ affine-expr ::= `(` affine-expr `)`
               | bare-id
               | `-`? integer-literal
 
-multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)`
+multi-dim-affine-expr ::= `(` `)`
+                        | `(` affine-expr (`,` affine-expr)* `)`
 ```
 
 `ceildiv` is the ceiling function which maps the result of the division of its

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index bd9ad75cc766..a93486744a2d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -184,7 +184,9 @@ def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
       MLIRContext *context = getContext();
       auto r_i = getAffineDimExpr(0, context);
       return SmallVector<AffineMap, 8>{
-        AffineMap::get(1, 0, {r_i}), AffineMap::get(1, 0, {r_i}), AffineMap()};
+        AffineMap::get(1, 0, {r_i}),
+        AffineMap::get(1, 0, {r_i}),
+        AffineMap::get(1, 0, context)};
     }
   }];
 

diff  --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 0bf52e32f3ab..14deb85fb2f0 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -44,6 +44,11 @@ class AffineMap {
   /// Returns a zero result affine map with no dimensions or symbols: () -> ().
   static AffineMap get(MLIRContext *context);
 
+  /// Returns a zero result affine map with `dimCount` dimensions and
+  /// `symbolCount` symbols, e.g.: `(...) -> ()`.
+  static AffineMap get(unsigned dimCount, unsigned symbolCount,
+                       MLIRContext *context);
+
   static AffineMap get(unsigned dimCount, unsigned symbolCount,
                        ArrayRef<AffineExpr> results);
 
@@ -275,8 +280,7 @@ inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
 namespace llvm {
 
 // AffineExpr hash just like pointers
-template <>
-struct DenseMapInfo<mlir::AffineMap> {
+template <> struct DenseMapInfo<mlir::AffineMap> {
   static mlir::AffineMap getEmptyKey() {
     auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
     return mlir::AffineMap(static_cast<mlir::AffineMap::ImplType *>(pointer));

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9d4032a8e8c7..19cf7f55bcc4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -356,15 +356,9 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
              << idx << " to have " << nLoops
              << " dim(s) to match the number of loops";
 
-    if (m.getNumResults() == 1 && view.getRank() == 0) {
-      auto cst = m.getResult(0).template dyn_cast<AffineConstantExpr>();
-      if (!cst || cst.getValue() != 0)
-        return op.emitOpError("expected indexing_map #")
-               << idx << " to be 0 to match 0-D view: " << view;
-    } else if (m.getNumResults() != view.getRank()) {
+    if (m.getNumResults() != view.getRank())
       return op.emitOpError("expected indexing_map #")
              << idx << " results to match view rank: " << view;
-    }
   }
 
   auto concatMap = concatAffineMaps(indexingMaps);
@@ -886,7 +880,7 @@ AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap,
   if (maybeMap)
     return maybeMap.getValue();
   if (rank == 0)
-    return AffineMap();
+    return AffineMap::get(context);
   return AffineMap::getMultiDimIdentityMap(rank, context);
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index 5701c37bf95f..05722036f8e5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -37,6 +37,8 @@ using edsc::op::operator==;
 static SmallVector<ValueHandle, 8>
 makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map,
                            ArrayRef<Value> vals) {
+  if (map.isEmpty())
+    return {};
   assert(map.getNumSymbols() == 0);
   assert(map.getNumInputs() == vals.size());
   SmallVector<ValueHandle, 8> res;
@@ -241,26 +243,17 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
 
     // 1.a. Emit std_load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
-      Value input = genericOp.getInput(i);
-      if (input.getType().cast<ShapedType>().getRank()) {
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, genericOp.getInputIndexingMap(i), allIvs));
-        indexedValues[i] = std_load(input, indexing);
-      } else {
-        indexedValues[i] = std_load(input);
-      }
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, genericOp.getInputIndexingMap(i), allIvs));
+      indexedValues[i] = std_load(genericOp.getInput(i), indexing);
     }
 
     // 1.b. Emit std_load from output views.
     for (unsigned i = 0; i < nOutputs; ++i) {
       Value output = genericOp.getOutputBuffer(i);
-      if (output.getType().cast<ShapedType>().getRank()) {
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-        indexedValues[nInputs + i] = std_load(output, indexing);
-      } else {
-        indexedValues[nInputs + i] = std_load(output);
-      }
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+      indexedValues[nInputs + i] = std_load(output, indexing);
     }
 
     auto funcOp = genericOp.getFunction();
@@ -272,13 +265,9 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
       // 3. Emit std_store.
       for (unsigned i = 0; i < nOutputs; ++i) {
         Value output = genericOp.getOutputBuffer(i);
-        if (output.getType().cast<ShapedType>().getRank()) {
-          ValueHandleArray indexing(makeCanonicalAffineApplies(
-              b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-          std_store(callOp->getResult(i), output, indexing);
-        } else {
-          std_store(callOp->getResult(i), output);
-        }
+        ValueHandleArray indexing(makeCanonicalAffineApplies(
+            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+        std_store(callOp->getResult(i), output, indexing);
       }
       return;
     }
@@ -297,15 +286,10 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
     auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
     assert(yieldOp->getNumOperands() == nOutputs);
     for (unsigned i = 0; i < nOutputs; ++i) {
-      Value output = genericOp.getOutputBuffer(i);
-      if (output.getType().cast<ShapedType>().getRank()) {
-        ValueHandleArray indexing(makeCanonicalAffineApplies(
-            b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-        std_store(map.lookup(yieldOp->getOperand(i)),
-                  genericOp.getOutputBuffer(i), indexing);
-      } else {
-        std_store(map.lookup(yieldOp->getOperand(i)), output);
-      }
+      ValueHandleArray indexing(makeCanonicalAffineApplies(
+          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
+      std_store(map.lookup(yieldOp->getOperand(i)),
+                genericOp.getOutputBuffer(i), indexing);
     }
   }
 };

diff  --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index bbc88e9156b6..94b7d4d05eb1 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -281,7 +281,8 @@ AffineMap AffineMap::compose(AffineMap map) {
   exprs.reserve(getResults().size());
   for (auto expr : getResults())
     exprs.push_back(expr.compose(newMap));
-  return AffineMap::get(numDims, numSymbols, exprs);
+  return exprs.empty() ? AffineMap::get(numDims, 0, map.getContext())
+                       : AffineMap::get(numDims, numSymbols, exprs);
 }
 
 bool AffineMap::isProjectedPermutation() {
@@ -325,7 +326,7 @@ AffineMap mlir::simplifyAffineMap(AffineMap map) {
 }
 
 AffineMap mlir::inversePermutation(AffineMap map) {
-  if (!map)
+  if (map.isEmpty())
     return map;
   assert(map.getNumSymbols() == 0 && "expected map without symbols");
   SmallVector<AffineExpr, 4> exprs(map.getNumDims());
@@ -351,18 +352,18 @@ AffineMap mlir::inversePermutation(AffineMap map) {
 AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
   unsigned numResults = 0;
   for (auto m : maps)
-    numResults += (m && !m.isSingleConstant()) ? m.getNumResults() : 0;
+    numResults += m.getNumResults();
   unsigned numDims = 0;
   SmallVector<AffineExpr, 8> results;
   results.reserve(numResults);
   for (auto m : maps) {
-    if (!m || m.isSingleConstant())
-      continue;
     assert(m.getNumSymbols() == 0 && "expected map without symbols");
     results.append(m.getResults().begin(), m.getResults().end());
     numDims = std::max(m.getNumDims(), numDims);
   }
-  return numDims == 0 ? AffineMap() : AffineMap::get(numDims, 0, results);
+  return results.empty() ? AffineMap::get(numDims, /*numSymbols=*/0,
+                                          maps.front().getContext())
+                         : AffineMap::get(numDims, /*numSymbols=*/0, results);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index d38fdb00cd17..198becaed5cc 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -611,6 +611,11 @@ AffineMap AffineMap::get(MLIRContext *context) {
   return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
 }
 
+AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
+                         MLIRContext *context) {
+  return getImpl(dimCount, /*symbolCount=*/0, /*results=*/{}, context);
+}
+
 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
                          ArrayRef<AffineExpr> results) {
   // The number of results can't be zero.

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 9c699284b746..209668adc6f3 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -3068,14 +3068,16 @@ AffineParser::parseAffineMapOfSSAIds(AffineMap &map,
   };
 
   // Parse a multi-dimensional affine expression (a comma-separated list of
-  // 1-d affine expressions); the list cannot be empty. Grammar:
-  // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
+  // 1-d affine expressions); the list can be empty. Grammar:
+  // multi-dim-affine-expr ::= `(` `)`
+  //                         | `(` affine-expr (`,` affine-expr)* `)`
   if (parseCommaSeparatedListUntil(rightToken, parseElt,
                                    /*allowEmptyList=*/true))
     return failure();
   // Parsed a valid affine map.
   if (exprs.empty())
-    map = AffineMap::get(getContext());
+    map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands,
+                         getContext());
   else
     map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands,
                          exprs);
@@ -3101,13 +3103,14 @@ AffineMap AffineParser::parseAffineMapRange(unsigned numDims,
   };
 
   // Parse a multi-dimensional affine expression (a comma-separated list of
-  // 1-d affine expressions); the list cannot be empty. Grammar:
-  // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
+  // 1-d affine expressions). Grammar:
+  // multi-dim-affine-expr ::= `(` `)`
+  //                         | `(` affine-expr (`,` affine-expr)* `)`
   if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
     return AffineMap();
 
   if (exprs.empty())
-    return AffineMap::get(getContext());
+    return AffineMap::get(numDims, numSymbols, getContext());
 
   // Parsed a valid affine map.
   return AffineMap::get(numDims, numSymbols, exprs);

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index c7c0752e112f..444b91bbe19e 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -170,7 +170,7 @@ func @generic_symbol_in_map(%arg0: memref<i32>) {
 
 func @foo(%0: i32) -> i32 { return %0: i32 }
 
-func @generic_wrong_dim_in_map(%arg0: memref<i32>) {
+func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
   // expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
   linalg.generic {
     args_in = 0,
@@ -178,22 +178,7 @@ func @generic_wrong_dim_in_map(%arg0: memref<i32>) {
     fun = @foo,
     indexing_maps =  [ affine_map<() -> (0)> ],
     iterator_types = ["parallel"]
-  } %arg0: memref<i32>
-}
-
-// -----
-
-func @foo(%0: i32) -> i32 { return %0: i32 }
-
-func @generic_zero_d_view(%arg0: memref<i32>) {
-  // expected-error @+1 {{op expected indexing_map #0 to be 0 to match 0-D view: 'memref<i32>'}}
-  linalg.generic {
-    args_in = 0,
-    args_out = 1,
-    fun = @foo,
-    indexing_maps =  [ affine_map<() -> (1)> ],
-    iterator_types = []
-  } %arg0: memref<i32>
+  } %arg0: memref<1xi32>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 59487c71eedb..f0c9a8bf6e16 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -360,7 +360,7 @@ func @indexed_generic_region(
 // -----
 
 #broadcast_access = [
-  affine_map<(i, j) -> (0)>,
+  affine_map<(i, j) -> ()>,
   affine_map<(i, j) -> (i, j)>
 ]
 
@@ -414,7 +414,7 @@ func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
 
 #reduce_1D_access = [
   affine_map<(i) -> (i)>,
-  affine_map<(i) -> (0)>
+  affine_map<(i) -> ()>
 ]
 
 #trait_reduce_1D = {
@@ -446,8 +446,8 @@ func @generic_op_1D_reduce(%arg0: memref<?xf32>, %arg1: memref<f32>)
 
 #reduce_init_1D_access = [
   affine_map<(i) -> (i)>,
-  affine_map<(i) -> (0)>,
-  affine_map<(i) -> (0)>
+  affine_map<(i) -> ()>,
+  affine_map<(i) -> ()>
 ]
 
 #trait_reduce_init_1D = {

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 450422411b22..5cc3ab621df5 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -346,7 +346,7 @@ func @indexed_generic_with_tensor_input_and_output(
 // -----
 
 #broadcast_access = [
-  affine_map<(i, j) -> (0)>,
+  affine_map<(i, j) -> ()>,
   affine_map<(i, j) -> (i, j)>
 ]
 


        


More information about the Mlir-commits mailing list