[Mlir-commits] [mlir] 9f3ab92 - [MLIR] Improve support for 0-dimensional Affine Maps.

Jeremy Bruestle llvmlistbot at llvm.org
Wed Apr 15 14:15:34 PDT 2020


Author: Jeremy Bruestle
Date: 2020-04-15T14:15:02-07:00
New Revision: 9f3ab92ec86953e310d0814a95d9c0213bfe05d4

URL: https://github.com/llvm/llvm-project/commit/9f3ab92ec86953e310d0814a95d9c0213bfe05d4
DIFF: https://github.com/llvm/llvm-project/commit/9f3ab92ec86953e310d0814a95d9c0213bfe05d4.diff

LOG: [MLIR] Improve support for 0-dimensional Affine Maps.

Summary:
Modified AffineMap::get to remove the overload that took an ArrayRef of
AffineExpr but no context. That overload inferred the context from the first
result expression, which caused bugs when there were zero results.

Instead, we support only an overload that takes an ArrayRef together with an
explicit context, plus a convenience overload that takes a single AffineExpr.

Additionally, removed some now-needless special-case logic that previously
chose which AffineMap::get overload to call based on whether the result list
was empty.

Reviewers: flaub, bondhugula, rriddle!, nicolasvasilache, ftynse, ulysseB, mravishankar, antiagainst, aartbik

Subscribers: mehdi_amini, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, bader, grosul1, frgossen, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D78226

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/Analysis/AffineStructures.cpp
    mlir/lib/Analysis/LoopAnalysis.cpp
    mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
    mlir/lib/Dialect/Affine/EDSC/Builders.cpp
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/Affine/IR/AffineValueMap.cpp
    mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/lib/Dialect/Vector/VectorUtils.cpp
    mlir/lib/IR/AffineMap.cpp
    mlir/lib/IR/Builders.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/IR/StandardTypes.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/lib/Transforms/LoopFusion.cpp
    mlir/lib/Transforms/PipelineDataTransfer.cpp
    mlir/lib/Transforms/Utils/LoopUtils.cpp
    mlir/test/Dialect/Affine/simplify-affine-structures.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 641039afd15d..0ff455391cb4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -188,9 +188,9 @@ def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
       MLIRContext *context = getContext();
       auto r_i = getAffineDimExpr(0, context);
       return SmallVector<AffineMap, 8>{
-        AffineMap::get(1, 0, {r_i}),
-        AffineMap::get(1, 0, {r_i}),
-        AffineMap::get(1, 0, context)};
+        AffineMap::get(1, 0, {r_i}, context),
+        AffineMap::get(1, 0, {r_i}, context),
+        AffineMap::get(1, 0, {}, context)};
     }
   }];
 
@@ -215,8 +215,10 @@ def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> {
       AffineExpr i, r_j;
       bindDims(context, i, r_j);
       return SmallVector<AffineMap, 8>{
-        AffineMap::get(2, 0, {i, r_j}), AffineMap::get(2, 0, {r_j}),
-        AffineMap::get(2, 0, {i})};
+        AffineMap::get(2, 0, {i, r_j}, context), 
+        AffineMap::get(2, 0, {r_j}, context),
+        AffineMap::get(2, 0, {i}, context)
+      };
     }
   }];
 
@@ -242,9 +244,11 @@ def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
       MLIRContext *context = getContext();
       AffineExpr i, j, r_k;
       bindDims(context, i, j, r_k);
-      return SmallVector<AffineMap, 8>{AffineMap::get(3, 0, {i, r_k}),
-                                       AffineMap::get(3, 0, {r_k, j}),
-                                       AffineMap::get(3, 0, {i, j})};
+      return SmallVector<AffineMap, 8>{
+        AffineMap::get(3, 0, {i, r_k}, context),
+        AffineMap::get(3, 0, {r_k, j},context),
+        AffineMap::get(3, 0, {i, j}, context)
+      };
     }
   }];
 
@@ -403,15 +407,15 @@ def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
       auto ws = weightedPoolingInputIndex(*this, xs, zs);
       return SmallVector<AffineMap, 8>{
         // filter[z[0], ..., z[N-1], q, k]
-        AffineMap::get(idx, 0, concat(concat(zs, qs), ks)),
+        AffineMap::get(idx, 0, concat(concat(zs, qs), ks), context),
         // input[b,
         //       x[0]*s[0] + d[0]*z[0] - pad_low[0],
         //       ...
         //       x[N-1]*s[N-1] + d[N-1]*z[N-1] - pad_low[N-1],
         //       q]
-        AffineMap::get(idx, 0, concat(concat(bs, ws), qs)),
+        AffineMap::get(idx, 0, concat(concat(bs, ws), qs), context),
         // output[b, x[0], ..., x[N-1], k]
-        AffineMap::get(idx, 0, concat(concat(bs, xs), ks))};
+        AffineMap::get(idx, 0, concat(concat(bs, xs), ks), context)};
     }
   }];
 
@@ -465,11 +469,11 @@ class SingleInputPoolingBase_Op<string mnemonic>
           weightedPoolingInputIndex(*this, outputDims, windowDims);
       return SmallVector<AffineMap, 8>{
         // input
-        AffineMap::get(idx, 0, inputDims),
+        AffineMap::get(idx, 0, inputDims, context),
         // windowDims
-        AffineMap::get(idx, 0, windowDims),
+        AffineMap::get(idx, 0, windowDims, context),
         // output
-        AffineMap::get(idx, 0, outputDims)
+        AffineMap::get(idx, 0, outputDims, context)
         };
     }
   }];

diff  --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
index bb37bb28a18c..6262e7757c6c 100644
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -25,22 +25,24 @@
 namespace mlir {
 
 inline bool isRowMajorMatmul(ArrayAttr indexingMaps) {
+  auto context = indexingMaps.getContext();
   AffineExpr m, n, k;
-  bindDims(indexingMaps.getContext(), m, n, k);
-  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}));
-  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}));
-  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}));
-  auto maps = ArrayAttr::get({mapA, mapB, mapC}, indexingMaps.getContext());
+  bindDims(context, m, n, k);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
+  auto maps = ArrayAttr::get({mapA, mapB, mapC}, context);
   return indexingMaps == maps;
 }
 
 inline bool isColumnMajorMatmul(ArrayAttr indexingMaps) {
+  auto context = indexingMaps.getContext();
   AffineExpr m, n, k;
-  bindDims(indexingMaps.getContext(), m, n, k);
-  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}));
-  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}));
-  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}));
-  auto maps = ArrayAttr::get({mapA, mapB, mapC}, indexingMaps.getContext());
+  bindDims(context, m, n, k);
+  auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
+  auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
+  auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context));
+  auto maps = ArrayAttr::get({mapA, mapB, mapC}, context);
   return indexingMaps == maps;
 }
 

diff  --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 79960cbb61f6..21c39baffeac 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -49,13 +49,13 @@ class AffineMap {
   static AffineMap get(unsigned dimCount, unsigned symbolCount,
                        MLIRContext *context);
 
-  /// Returns an affine map with `dimCount` dimensions and `symbolCount` symbols
-  /// mapping to the given results. The array of results cannot be empty.
+  /// Returns an affine map with `dimCount` dimensions and `symbolCount` symbols
+  /// mapping to the single given result.
   static AffineMap get(unsigned dimCount, unsigned symbolCount,
-                       ArrayRef<AffineExpr> results);
+                       AffineExpr result);
 
   /// Returns an affine map with `dimCount` dimensions and `symbolCount` mapping
-  /// to the given results, where the number of results can be zero.
+  /// to the given results.
   static AffineMap get(unsigned dimCount, unsigned symbolCount,
                        ArrayRef<AffineExpr> results, MLIRContext *context);
 

diff  --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index f9e8bf52deee..17e380d60b05 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1464,11 +1464,8 @@ std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
     lbExprs.push_back(expr);
   }
 
-  auto lbMap = lbExprs.empty() ? AffineMap()
-                               : AffineMap::get(dimCount, symCount, lbExprs);
-
-  auto ubMap = ubExprs.empty() ? AffineMap()
-                               : AffineMap::get(dimCount, symCount, ubExprs);
+  auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context);
+  auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context);
 
   return {lbMap, ubMap};
 }

diff  --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
index 40e176367d18..0c8ff3167e69 100644
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -62,8 +62,8 @@ void mlir::buildTripCountMapAndOperands(
 
   SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
                                          lbMap.getResult(0));
-  auto lbMapSplat =
-      AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(), lbSplatExpr);
+  auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
+                                   lbSplatExpr, b.getContext());
   AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());
 
   AffineValueMap tripCountValueMap;

diff  --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
index e9acab21fc62..14fdc9b207fb 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
@@ -99,11 +99,11 @@ SingleWorkgroupReduction::matchAsPerformingReduction(
       genericOp.indexing_maps().getValue()[1].cast<AffineMapAttr>();
   // The indexing map for the input should be `(i) -> (i)`.
   if (inputMap.getValue() !=
-      AffineMap::get(1, 0, {getAffineDimExpr(0, op->getContext())}))
+      AffineMap::get(1, 0, getAffineDimExpr(0, op->getContext())))
     return llvm::None;
   // The indexing map for the input should be `(i) -> (0)`.
   if (outputMap.getValue() !=
-      AffineMap::get(1, 0, {getAffineConstantExpr(0, op->getContext())}))
+      AffineMap::get(1, 0, getAffineConstantExpr(0, op->getContext())))
     return llvm::None;
 
   return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp);

diff  --git a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp
index bc4dca225cbf..5d9034c9b6d5 100644
--- a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp
@@ -129,7 +129,7 @@ static ValueHandle createBinaryIndexHandle(
   if (v1) {
     operands.push_back(v1);
   }
-  auto map = AffineMap::get(numDims, numSymbols, {affCombiner(d0, d1)});
+  auto map = AffineMap::get(numDims, numSymbols, affCombiner(d0, d1));
   // TODO: createOrFold when available.
   Operation *op =
       makeComposedAffineApply(ScopedContext::getBuilder(),

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 7894515f2417..ca35bafcd2c4 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -522,7 +522,8 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
          "Unexpected number of concatenated symbols");
   auto numDims = dimValueToPosition.size();
   auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols();
-  auto auxiliaryMap = AffineMap::get(numDims, numSymbols, auxiliaryExprs);
+  auto auxiliaryMap =
+      AffineMap::get(numDims, numSymbols, auxiliaryExprs, map.getContext());
 
   LLVM_DEBUG(map.print(dbgs() << "\nCompose map: "));
   LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: "));
@@ -2163,19 +2164,13 @@ LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
 
 void AffineParallelOp::build(Builder *builder, OperationState &result,
                              ArrayRef<int64_t> ranges) {
-  // Default initialize empty maps.
-  auto lbMap = AffineMap::get(builder->getContext());
-  auto ubMap = AffineMap::get(builder->getContext());
-  // If there are ranges, set each to [0, N).
-  if (ranges.size()) {
-    SmallVector<AffineExpr, 8> lbExprs(ranges.size(),
-                                       builder->getAffineConstantExpr(0));
-    lbMap = AffineMap::get(0, 0, lbExprs);
-    SmallVector<AffineExpr, 8> ubExprs;
-    for (int64_t range : ranges)
-      ubExprs.push_back(builder->getAffineConstantExpr(range));
-    ubMap = AffineMap::get(0, 0, ubExprs);
-  }
+  SmallVector<AffineExpr, 8> lbExprs(ranges.size(),
+                                     builder->getAffineConstantExpr(0));
+  auto lbMap = AffineMap::get(0, 0, lbExprs, builder->getContext());
+  SmallVector<AffineExpr, 8> ubExprs;
+  for (int64_t range : ranges)
+    ubExprs.push_back(builder->getAffineConstantExpr(range));
+  auto ubMap = AffineMap::get(0, 0, ubExprs, builder->getContext());
   build(builder, result, lbMap, {}, ubMap, {});
 }
 

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineValueMap.cpp b/mlir/lib/Dialect/Affine/IR/AffineValueMap.cpp
index c17f59323a7f..792ca379cef4 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineValueMap.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineValueMap.cpp
@@ -51,8 +51,9 @@ void AffineValueMap::difference(const AffineValueMap &a,
     diffExprs.push_back(normalizer.getAffineMap().getResult(i) -
                         bMap.getResult(i));
 
-  auto diffMap = AffineMap::get(normalizer.getNumDims(),
-                                normalizer.getNumSymbols(), diffExprs);
+  auto diffMap =
+      AffineMap::get(normalizer.getNumDims(), normalizer.getNumSymbols(),
+                     diffExprs, aMap.getContext());
   canonicalizeMapAndOperands(&diffMap, &bOperands);
   diffMap = simplifyAffineMap(diffMap);
   res->reset(diffMap, bOperands);

diff  --git a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
index b15a73720c1b..f6d1a5494be2 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
@@ -143,8 +143,9 @@ constructTiledIndexSetHyperRect(MutableArrayRef<AffineForOp> origLoops,
       boundExprs.push_back(dim + tileSizes[i]);
       boundExprs.append(origUbMap.getResults().begin(),
                         origUbMap.getResults().end());
-      auto ubMap = AffineMap::get(origUbMap.getNumDims() + 1,
-                                  origUbMap.getNumSymbols(), boundExprs);
+      auto ubMap =
+          AffineMap::get(origUbMap.getNumDims() + 1, origUbMap.getNumSymbols(),
+                         boundExprs, b.getContext());
       newLoops[width + i].setUpperBound(/*operands=*/ubOperands, ubMap);
     } else {
       // No need of the min expression.

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2cb7b6c45963..df1a957d344c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -523,8 +523,10 @@ getSymbolLessAffineMaps(ArrayRef<ArrayRef<AffineExpr>> reassociation) {
          "Expected symbol-less expressions");
   SmallVector<AffineMap, 4> maps;
   maps.reserve(reassociation.size());
-  for (auto exprs : reassociation)
-    maps.push_back(AffineMap::get(maxDim + 1, 0, exprs));
+  for (auto exprs : reassociation) {
+    assert(exprs.size() != 0);
+    maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
+  }
   return maps;
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
index 9eb3c329a7bd..529448497728 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -252,7 +252,8 @@ class LinalgScopedEmitter<IndexedValueType, ConvOp> {
       // so having a max op is enough.
       auto maxMap = AffineMap::get(/*dimCount=*/1, 0,
                                    {getAffineDimExpr(/*position=*/0, context),
-                                    getAffineConstantExpr(0, context)});
+                                    getAffineConstantExpr(0, context)},
+                                   context);
       clampedImIdx.push_back(
           affine_max(dim.getType(), maxMap, ValueRange{dim}));
     }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 084142618cc4..f6f69b0fee8f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -294,7 +294,8 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
             /*dimCount=*/3, /*symbolCount=*/0,
             {getAffineDimExpr(/*position=*/0, b.getContext()),
              getAffineDimExpr(/*position=*/1, b.getContext()) -
-                 getAffineDimExpr(/*position=*/2, b.getContext())});
+                 getAffineDimExpr(/*position=*/2, b.getContext())},
+            b.getContext());
         auto d = folded_std_dim(folder, view, r);
         size = folded_affine_min(folder, b.getIndexType(), minMap,
                                  ValueRange{size, d, offset});

diff  --git a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp
index 16b9b223f2eb..d9cf6df642d4 100644
--- a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp
+++ b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp
@@ -66,7 +66,8 @@ void mlir::loop::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
       /*dimCount=*/3, /*symbolCount=*/0,
       {getAffineDimExpr(/*position=*/0, b.getContext()),
        getAffineDimExpr(/*position=*/1, b.getContext()) -
-           getAffineDimExpr(/*position=*/2, b.getContext())});
+           getAffineDimExpr(/*position=*/2, b.getContext())},
+      b.getContext());
 
   // Create the inner loop with adjusted bounds.
   SmallVector<Value, 2> newBounds;

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 7a64f5eb5364..e888c5cdfd2f 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1348,9 +1348,7 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
       auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
       results.push_back(targetExpr);
     }
-    // The (...) -> () affine map has its own factory method.
-    return results.empty() ? AffineMap::get(map.getNumDims() - 1, 0, ctx)
-                           : AffineMap::get(map.getNumDims() - 1, 0, results);
+    return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
   }
 
   // Helper to drop dimension from vector type.

diff  --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
index 398e24bb9e22..1ed89e3f7010 100644
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -176,7 +176,7 @@ static AffineMap makePermutationMap(
            "Vectorization prerequisite violated: at most 1 index may be "
            "invariant wrt a vectorized loop");
   }
-  return AffineMap::get(indices.size(), 0, perm);
+  return AffineMap::get(indices.size(), 0, perm, context);
 }
 
 /// Implementation detail that walks up the parents and records the ones with

diff  --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index be62a164cc8d..e05556a7cbf9 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -104,7 +104,7 @@ AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
   for (auto index : permutation)
     affExprs.push_back(getAffineDimExpr(index, context));
   auto m = std::max_element(permutation.begin(), permutation.end());
-  auto permutationMap = AffineMap::get(*m + 1, 0, affExprs);
+  auto permutationMap = AffineMap::get(*m + 1, 0, affExprs, context);
   assert(permutationMap.isPermutation() && "Invalid permutation vector");
   return permutationMap;
 }
@@ -127,13 +127,16 @@ static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
 template <typename AffineExprContainer>
 static SmallVector<AffineMap, 4>
 inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {
+  assert(!exprsList.empty());
+  assert(!exprsList[0].empty());
+  auto context = exprsList[0][0].getContext();
   int64_t maxDim = -1, maxSym = -1;
   getMaxDimAndSymbol(exprsList, maxDim, maxSym);
   SmallVector<AffineMap, 4> maps;
   maps.reserve(exprsList.size());
   for (const auto &exprs : exprsList)
     maps.push_back(AffineMap::get(/*dimCount=*/maxDim + 1,
-                                  /*symbolCount=*/maxSym + 1, exprs));
+                                  /*symbolCount=*/maxSym + 1, exprs, context));
   return maps;
 }
 
@@ -153,7 +156,7 @@ AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
   dimExprs.reserve(numDims);
   for (unsigned i = 0; i < numDims; ++i)
     dimExprs.push_back(mlir::getAffineDimExpr(i, context));
-  return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs);
+  return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs, context);
 }
 
 MLIRContext *AffineMap::getContext() const { return map->context; }
@@ -255,8 +258,7 @@ AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
     results.push_back(
         expr.replaceDimsAndSymbols(dimReplacements, symReplacements));
 
-  return results.empty() ? get(numResultDims, 0, getContext())
-                         : get(numResultDims, numResultSyms, results);
+  return get(numResultDims, numResultSyms, results, getContext());
 }
 
 AffineMap AffineMap::compose(AffineMap map) {
@@ -280,8 +282,7 @@ AffineMap AffineMap::compose(AffineMap map) {
   exprs.reserve(getResults().size());
   for (auto expr : getResults())
     exprs.push_back(expr.compose(newMap));
-  return exprs.empty() ? AffineMap::get(numDims, 0, map.getContext())
-                       : AffineMap::get(numDims, numSymbols, exprs);
+  return AffineMap::get(numDims, numSymbols, exprs, map.getContext());
 }
 
 bool AffineMap::isProjectedPermutation() {
@@ -312,7 +313,7 @@ AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) {
   for (auto idx : resultPos) {
     exprs.push_back(getResult(idx));
   }
-  return AffineMap::get(getNumDims(), getNumSymbols(), exprs);
+  return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
 }
 
 AffineMap mlir::simplifyAffineMap(AffineMap map) {
@@ -321,7 +322,8 @@ AffineMap mlir::simplifyAffineMap(AffineMap map) {
     exprs.push_back(
         simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols()));
   }
-  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs);
+  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs,
+                        map.getContext());
 }
 
 AffineMap mlir::removeDuplicateExprs(AffineMap map) {
@@ -354,7 +356,7 @@ AffineMap mlir::inversePermutation(AffineMap map) {
       seenExprs.push_back(expr);
   if (seenExprs.size() != map.getNumInputs())
     return AffineMap();
-  return AffineMap::get(map.getNumResults(), 0, seenExprs);
+  return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext());
 }
 
 AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
@@ -369,9 +371,8 @@ AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
     results.append(m.getResults().begin(), m.getResults().end());
     numDims = std::max(m.getNumDims(), numDims);
   }
-  return results.empty() ? AffineMap::get(numDims, /*numSymbols=*/0,
-                                          maps.front().getContext())
-                         : AffineMap::get(numDims, /*numSymbols=*/0, results);
+  return AffineMap::get(numDims, /*numSymbols=*/0, results,
+                        maps.front().getContext());
 }
 
 //===----------------------------------------------------------------------===//
@@ -380,8 +381,7 @@ AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
 
 MutableAffineMap::MutableAffineMap(AffineMap map)
     : numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
-      // A map always has at least 1 result by construction
-      context(map.getResult(0).getContext()) {
+      context(map.getContext()) {
   for (auto result : map.getResults())
     results.push_back(result);
 }
@@ -390,8 +390,7 @@ void MutableAffineMap::reset(AffineMap map) {
   results.clear();
   numDims = map.getNumDims();
   numSymbols = map.getNumSymbols();
-  // A map always has at least 1 result by construction
-  context = map.getResult(0).getContext();
+  context = map.getContext();
   for (auto result : map.getResults())
     results.push_back(result);
 }
@@ -416,5 +415,5 @@ void MutableAffineMap::simplify() {
 }
 
 AffineMap MutableAffineMap::getAffineMap() const {
-  return AffineMap::get(numDims, numSymbols, results);
+  return AffineMap::get(numDims, numSymbols, results, context);
 }

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 40954a69f58f..22abeb5a364f 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -293,12 +293,11 @@ AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); }
 
 AffineMap Builder::getConstantAffineMap(int64_t val) {
   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
-                        {getAffineConstantExpr(val)});
+                        getAffineConstantExpr(val));
 }
 
 AffineMap Builder::getDimIdentityMap() {
-  return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
-                        {getAffineDimExpr(0)});
+  return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, getAffineDimExpr(0));
 }
 
 AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
@@ -306,18 +305,19 @@ AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
   dimExprs.reserve(rank);
   for (unsigned i = 0; i < rank; ++i)
     dimExprs.push_back(getAffineDimExpr(i));
-  return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs);
+  return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs,
+                        context);
 }
 
 AffineMap Builder::getSymbolIdentityMap() {
   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
-                        {getAffineSymbolExpr(0)});
+                        getAffineSymbolExpr(0));
 }
 
 AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
   // expr = d0 + shift.
   auto expr = getAffineDimExpr(0) + shift;
-  return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {expr});
+  return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
 }
 
 AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
@@ -325,7 +325,8 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
   shiftedResults.reserve(map.getNumResults());
   for (auto resultExpr : map.getResults())
     shiftedResults.push_back(resultExpr + shift);
-  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults);
+  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults,
+                        context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index e12926d69b3f..f0f3cc72d03a 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -717,10 +717,8 @@ AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
 }
 
 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
-                         ArrayRef<AffineExpr> results) {
-  // The number of results can't be zero.
-  assert(!results.empty());
-  return getImpl(dimCount, symbolCount, results, results[0].getContext());
+                         AffineExpr result) {
+  return getImpl(dimCount, symbolCount, {result}, result.getContext());
 }
 
 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,

diff  --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 4a285db0be4d..903ae92e6baf 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -723,7 +723,7 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
   if (expr != simplifiedLayoutExpr)
     return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
-        m.getNumDims(), m.getNumSymbols(), {simplifiedLayoutExpr})});
+        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
   return MemRefType::Builder(t).setAffineMaps({});
 }
 

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 9361f226b717..e8825ff0fb27 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -3132,12 +3132,8 @@ AffineParser::parseAffineMapOfSSAIds(AffineMap &map,
                                    /*allowEmptyList=*/true))
     return failure();
   // Parsed a valid affine map.
-  if (exprs.empty())
-    map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands,
-                         getContext());
-  else
-    map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands,
-                         exprs);
+  map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands,
+                       exprs, getContext());
   return success();
 }
 
@@ -3166,11 +3162,8 @@ AffineMap AffineParser::parseAffineMapRange(unsigned numDims,
   if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
     return AffineMap();
 
-  if (exprs.empty())
-    return AffineMap::get(numDims, numSymbols, getContext());
-
   // Parsed a valid affine map.
-  return AffineMap::get(numDims, numSymbols, exprs);
+  return AffineMap::get(numDims, numSymbols, exprs, getContext());
 }
 
 /// Parse an affine constraint.

diff  --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index b5d9f5199bbe..1340c25b817b 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -942,7 +942,8 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   }
   auto indexRemap = zeroOffsetCount == rank
                         ? AffineMap()
-                        : AffineMap::get(outerIVs.size() + rank, 0, remapExprs);
+                        : AffineMap::get(outerIVs.size() + rank, 0, remapExprs,
+                                         forOp.getContext());
   // Replace all users of 'oldMemRef' with 'newMemRef'.
   LogicalResult res =
       replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,

diff  --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index 01aa25ab0a5c..d4a5ba97d6bc 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -98,8 +98,8 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
   // Create 'iv mod 2' value to index the leading dimension.
   auto d0 = bInner.getAffineDimExpr(0);
   int64_t step = forOp.getStep();
-  auto modTwoMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
-                                  {d0.floorDiv(step) % 2});
+  auto modTwoMap =
+      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
   auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
                                                  forOp.getInductionVar());
 

diff  --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 9fe96437e526..3aebc83678f7 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -103,7 +103,8 @@ static void getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
   operands.clear();
   operands.push_back(lb);
   operands.append(bumpValues.begin(), bumpValues.end());
-  map = AffineMap::get(1 + tripCountMap.getNumResults(), 0, newUbExprs);
+  map = AffineMap::get(1 + tripCountMap.getNumResults(), 0, newUbExprs,
+                       b.getContext());
   // Simplify the map + operands.
   fullyComposeAffineMapAndOperands(&map, &operands);
   map = simplifyAffineMap(map);
@@ -485,7 +486,7 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp,
     if (!forOpIV.use_empty()) {
       // iv' = iv + 1/2/3...unrollFactor-1;
       auto d0 = builder.getAffineDimExpr(0);
-      auto bumpMap = AffineMap::get(1, 0, {d0 + i * step});
+      auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
       auto ivUnroll =
           builder.create<AffineApplyOp>(forOp.getLoc(), bumpMap, forOpIV);
       operandMap.map(forOpIV, ivUnroll);
@@ -616,7 +617,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
       if (!forOpIV.use_empty()) {
         // iv' = iv + i, i = 1 to unrollJamFactor-1.
         auto d0 = builder.getAffineDimExpr(0);
-        auto bumpMap = AffineMap::get(1, 0, {d0 + i * step});
+        auto bumpMap = AffineMap::get(1, 0, d0 + i * step);
         auto ivUnroll =
             builder.create<AffineApplyOp>(forOp.getLoc(), bumpMap, forOpIV);
         operandMap.map(forOpIV, ivUnroll);
@@ -859,7 +860,8 @@ static void augmentMapAndBounds(OpBuilder &b, Value iv, AffineMap *map,
   auto bounds = llvm::to_vector<4>(map->getResults());
   bounds.push_back(b.getAffineDimExpr(map->getNumDims()) + offset);
   operands->insert(operands->begin() + map->getNumDims(), iv);
-  *map = AffineMap::get(map->getNumDims() + 1, map->getNumSymbols(), bounds);
+  *map = AffineMap::get(map->getNumDims() + 1, map->getNumSymbols(), bounds,
+                        b.getContext());
   canonicalizeMapAndOperands(map, operands);
 }
 
@@ -1514,7 +1516,7 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef,
     b = forOp.getBodyBuilder();
 
     auto fastBufOffsetMap =
-        AffineMap::get(lbOperands.size(), 0, {fastBufOffsets[d]});
+        AffineMap::get(lbOperands.size(), 0, fastBufOffsets[d]);
     auto offset = b.create<AffineApplyOp>(loc, fastBufOffsetMap, lbOperands);
 
     // Construct the subscript for the fast memref being copied into/from:
@@ -1529,7 +1531,8 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef,
     memIndices.push_back(forOp.getInductionVar());
   }
 
-  auto fastBufMap = AffineMap::get(2 * rank, /*symbolCount=*/0, fastBufExprs);
+  auto fastBufMap =
+      AffineMap::get(2 * rank, /*symbolCount=*/0, fastBufExprs, b.getContext());
   fullyComposeAffineMapAndOperands(&fastBufMap, &fastBufMapOperands);
   fastBufMap = simplifyAffineMap(fastBufMap);
   canonicalizeMapAndOperands(&fastBufMap, &fastBufMapOperands);
@@ -1837,7 +1840,8 @@ static LogicalResult generateCopy(
     auto dimExpr = b.getAffineDimExpr(regionSymbols.size() + i);
     remapExprs.push_back(dimExpr - fastBufOffsets[i]);
   }
-  auto indexRemap = AffineMap::get(regionSymbols.size() + rank, 0, remapExprs);
+  auto indexRemap = AffineMap::get(regionSymbols.size() + rank, 0, remapExprs,
+                                   b.getContext());
 
   // Record the begin since it may be invalidated by memref replacement.
   Block::iterator prevOfBegin;

diff  --git a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
index 49fa339aa88a..9637ba3f4146 100644
--- a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
@@ -271,3 +271,13 @@ func @affine.apply(%N : index) {
   // CHECK-NEXT: addi
   return
 }
+
+// -----
+
+// CHECK-DAG: #[[MAP_0D:.*]] = affine_map<() -> ()>
+
+// CHECK-LABEL: func @simplify_zero_dim_map
+func @simplify_zero_dim_map(%in : memref<f32>) -> f32 {
+  %out = affine.load %in[] : memref<f32>
+  return %out : f32
+}


        


More information about the Mlir-commits mailing list