[Mlir-commits] [mlir] 5bc4f88 - [mlir] Tighten computation of inferred SubView result type.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Feb 11 14:44:32 PST 2021


Author: Nicolas Vasilache
Date: 2021-02-11T22:38:16Z
New Revision: 5bc4f8846c07bc3b355c8f303416784a10d1a298

URL: https://github.com/llvm/llvm-project/commit/5bc4f8846c07bc3b355c8f303416784a10d1a298
DIFF: https://github.com/llvm/llvm-project/commit/5bc4f8846c07bc3b355c8f303416784a10d1a298.diff

LOG: [mlir] Tighten computation of inferred SubView result type.

The AffineMap in the MemRef type inferred by SubViewOp may retain uncompressed (unused) symbols, which results in spurious type mismatches on those otherwise unused symbols. Make the computation of the AffineMap compress those unused symbols away, which results in better canonical types.
Additionally, improve the error message to report which inferred type was expected.

Differential Revision: https://reviews.llvm.org/D96551

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/include/mlir/IR/AffineExpr.h
    mlir/include/mlir/IR/AffineMap.h
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/IR/AffineExpr.cpp
    mlir/lib/IR/AffineMap.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/test/IR/core-ops.mlir
    mlir/test/IR/invalid-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 30905d3af411..a7142d298b66 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -305,14 +305,14 @@ class DmaWaitOp
 };
 
 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
-/// `originalShape` with some `1` entries erased, return the vector of booleans
-/// that specifies which of the entries of `originalShape` are keep to obtain
+/// `originalShape` with some `1` entries erased, return the set of indices
+/// that specifies which of the entries of `originalShape` are dropped to obtain
 /// `reducedShape`. The returned mask can be applied as a projection to
 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track
 /// which dimensions must be kept when e.g. compute MemRef strides under
 /// rank-reducing operations. Return None if reducedShape cannot be obtained
 /// by dropping only `1` entries in `originalShape`.
-llvm::Optional<SmallVector<bool, 4>>
+llvm::Optional<llvm::SmallDenseSet<unsigned>>
 computeRankReductionMask(ArrayRef<int64_t> originalShape,
                          ArrayRef<int64_t> reducedShape);
 

diff  --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index e71448716930..937efb945e4b 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -127,6 +127,12 @@ class AffineExpr {
   AffineExpr replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
                                    ArrayRef<AffineExpr> symReplacements) const;
 
+  /// Dim-only version of replaceDimsAndSymbols.
+  AffineExpr replaceDims(ArrayRef<AffineExpr> dimReplacements) const;
+
+  /// Symbol-only version of replaceDimsAndSymbols.
+  AffineExpr replaceSymbols(ArrayRef<AffineExpr> symReplacements) const;
+
   /// Sparse replace method. Replace `expr` by `replacement` and return the
   /// modified expression tree.
   AffineExpr replace(AffineExpr expr, AffineExpr replacement) const;

diff  --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 7884cb5ca8bf..86480529aa05 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -18,6 +18,7 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/DenseSet.h"
 
 namespace mlir {
 
@@ -311,6 +312,20 @@ struct MutableAffineMap {
 /// Simplifies an affine map by simplifying its underlying AffineExpr results.
 AffineMap simplifyAffineMap(AffineMap map);
 
+/// Drop the dims that are not used.
+AffineMap compressUnusedDims(AffineMap map);
+
+/// Drop the dims that are listed in `unusedDims`.
+AffineMap compressDims(AffineMap map,
+                       const llvm::SmallDenseSet<unsigned> &unusedDims);
+
+/// Drop the symbols that are not used.
+AffineMap compressUnusedSymbols(AffineMap map);
+
+/// Drop the symbols that are listed in `unusedSymbols`.
+AffineMap compressSymbols(AffineMap map,
+                          const llvm::SmallDenseSet<unsigned> &unusedSymbols);
+
 /// Returns a map with the same dimension and symbol count as `map`, but whose
 /// results are the unique affine expressions of `map`.
 AffineMap removeDuplicateExprs(AffineMap map);
@@ -390,8 +405,11 @@ AffineMap concatAffineMaps(ArrayRef<AffineMap> maps);
 /// 3) map                  : affine_map<(d0, d1, d2) -> (d0, d1)>
 ///    projected_dimensions : {1}
 ///    result               : affine_map<(d0, d1) -> (d0, 0)>
-AffineMap getProjectedMap(AffineMap map,
-                          ArrayRef<unsigned> projectedDimensions);
+///
+/// This function also compresses unused symbols away.
+AffineMap
+getProjectedMap(AffineMap map,
+                const llvm::SmallDenseSet<unsigned> &projectedDimensions);
 
 inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
   map.print(os);
@@ -402,7 +420,8 @@ inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
 namespace llvm {
 
 // AffineExpr hash just like pointers
-template <> struct DenseMapInfo<mlir::AffineMap> {
+template <>
+struct DenseMapInfo<mlir::AffineMap> {
   static mlir::AffineMap getEmptyKey() {
     auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
     return mlir::AffineMap(static_cast<mlir::AffineMap::ImplType *>(pointer));

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 2396214bdaf7..d483ee4a4f2d 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -566,6 +566,10 @@ AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
 /// Return true if the layout for `t` is compatible with strided semantics.
 bool isStrided(MemRefType t);
 
+/// Return the layout map in strided linear layout AffineMap form.
+/// Return null if the layout is not compatible with a strided layout.
+AffineMap getStridedLinearLayoutMap(MemRefType t);
+
 } // end namespace mlir
 
 #endif // MLIR_IR_BUILTINTYPES_H

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index da28ecbfc035..444c94ae6d42 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -3277,7 +3277,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
     auto inferredShape = inferredType.getShape();
     size_t inferredShapeRank = inferredShape.size();
     size_t resultShapeRank = shape.size();
-    SmallVector<bool, 4> mask =
+    llvm::SmallDenseSet<unsigned> unusedDims =
         computeRankReductionMask(inferredShape, shape).getValue();
 
     // Extract strides needed to compute offset.
@@ -3318,7 +3318,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
            "expected sizes and strides of equal length");
     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
          i >= 0 && j >= 0; --i) {
-      if (!mask[i])
+      if (unusedDims.contains(i))
         continue;
 
       // `i` may overflow subViewOp.getMixedSizes because of trailing semantics.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index f6289b9541ce..7e52daef6588 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -536,10 +536,10 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
 /// Prune all dimensions that are of reduction iterator type from `map`.
 static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
                                            AffineMap map) {
-  SmallVector<unsigned, 2> projectedDims;
+  llvm::SmallDenseSet<unsigned> projectedDims;
   for (auto attr : llvm::enumerate(iteratorTypes)) {
     if (!isParallelIterator(attr.value()))
-      projectedDims.push_back(attr.index());
+      projectedDims.insert(attr.index());
   }
   return getProjectedMap(map, projectedDims);
 }

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 9af00be6368e..3ef48ced1b65 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2957,35 +2957,44 @@ void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
 /// For ViewLikeOpInterface.
 Value SubViewOp::getViewSource() { return source(); }
 
-llvm::Optional<SmallVector<bool, 4>>
+/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
+/// `originalShape` with some `1` entries erased, return the set of indices
+/// that specifies which of the entries of `originalShape` are dropped to obtain
+/// `reducedShape`. The returned mask can be applied as a projection to
+/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
+/// which dimensions must be kept when e.g. compute MemRef strides under
+/// rank-reducing operations. Return None if reducedShape cannot be obtained
+/// by dropping only `1` entries in `originalShape`.
+llvm::Optional<llvm::SmallDenseSet<unsigned>>
 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                                ArrayRef<int64_t> reducedShape) {
   size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
-  SmallVector<bool, 4> mask(originalRank);
+  llvm::SmallDenseSet<unsigned> unusedDims;
   unsigned reducedIdx = 0;
   for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
-    // Skip matching dims greedily.
-    mask[originalIdx] =
-        (reducedIdx < reducedRank) &&
-        (originalShape[originalIdx] == reducedShape[reducedIdx]);
-    if (mask[originalIdx])
+    // Greedily insert `originalIdx` if no match.
+    if (reducedIdx < reducedRank &&
+        originalShape[originalIdx] == reducedShape[reducedIdx]) {
       reducedIdx++;
-    // 1 is the only non-matching allowed.
-    else if (originalShape[originalIdx] != 1)
-      return {};
-  }
+      continue;
+    }
 
+    unusedDims.insert(originalIdx);
+    // If no match on `originalIdx`, the `originalShape` at this dimension
+    // must be 1, otherwise we bail.
+    if (originalShape[originalIdx] != 1)
+      return llvm::None;
+  }
+  // The whole reducedShape must be scanned, otherwise we bail.
   if (reducedIdx != reducedRank)
-    return {};
-
-  return mask;
+    return llvm::None;
+  return unusedDims;
 }
 
 enum SubViewVerificationResult {
   Success,
   RankTooLarge,
   SizeMismatch,
-  StrideMismatch,
   ElemTypeMismatch,
   MemSpaceMismatch,
   AffineMapMismatch
@@ -2994,8 +3003,9 @@ enum SubViewVerificationResult {
 /// Checks if `original` Type type can be rank reduced to `reduced` type.
 /// This function is slight variant of `is subsequence` algorithm where
 /// not matching dimension must be 1.
-static SubViewVerificationResult isRankReducedType(Type originalType,
-                                                   Type candidateReducedType) {
+static SubViewVerificationResult
+isRankReducedType(Type originalType, Type candidateReducedType,
+                  std::string *errMsg = nullptr) {
   if (originalType == candidateReducedType)
     return SubViewVerificationResult::Success;
   if (!originalType.isa<RankedTensorType>() && !originalType.isa<MemRefType>())
@@ -3019,13 +3029,17 @@ static SubViewVerificationResult isRankReducedType(Type originalType,
   if (candidateReducedRank > originalRank)
     return SubViewVerificationResult::RankTooLarge;
 
-  auto optionalMask =
+  auto optionalUnusedDimsMask =
       computeRankReductionMask(originalShape, candidateReducedShape);
 
   // Sizes cannot be matched in case empty vector is returned.
-  if (!optionalMask.hasValue())
+  if (!optionalUnusedDimsMask.hasValue())
     return SubViewVerificationResult::SizeMismatch;
 
+  if (originalShapedType.getElementType() !=
+      candidateReducedShapedType.getElementType())
+    return SubViewVerificationResult::ElemTypeMismatch;
+
   // We are done for the tensor case.
   if (originalType.isa<RankedTensorType>())
     return SubViewVerificationResult::Success;
@@ -3033,74 +3047,54 @@ static SubViewVerificationResult isRankReducedType(Type originalType,
   // Strided layout logic is relevant for MemRefType only.
   MemRefType original = originalType.cast<MemRefType>();
   MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
-  MLIRContext *c = original.getContext();
-  int64_t originalOffset, candidateReducedOffset;
-  SmallVector<int64_t, 4> originalStrides, candidateReducedStrides, keepStrides;
-  SmallVector<bool, 4> keepMask = optionalMask.getValue();
-  (void)getStridesAndOffset(original, originalStrides, originalOffset);
-  (void)getStridesAndOffset(candidateReduced, candidateReducedStrides,
-                            candidateReducedOffset);
-
-  // Filter strides based on the mask and check that they are the same
-  // as candidateReduced ones.
-  unsigned candidateReducedIdx = 0;
-  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
-    if (keepMask[originalIdx]) {
-      if (originalStrides[originalIdx] !=
-          candidateReducedStrides[candidateReducedIdx++])
-        return SubViewVerificationResult::StrideMismatch;
-      keepStrides.push_back(originalStrides[originalIdx]);
-    }
-  }
-
-  if (original.getElementType() != candidateReduced.getElementType())
-    return SubViewVerificationResult::ElemTypeMismatch;
-
   if (original.getMemorySpace() != candidateReduced.getMemorySpace())
     return SubViewVerificationResult::MemSpaceMismatch;
 
-  // reducedMap is obtained by projecting away the dimensions inferred from
-  // matching the 1's positions in candidateReducedType.
-  auto reducedMap = makeStridedLinearLayoutMap(keepStrides, originalOffset, c);
-
-  MemRefType expectedReducedType = MemRefType::get(
-      candidateReduced.getShape(), candidateReduced.getElementType(),
-      reducedMap, candidateReduced.getMemorySpace());
-  expectedReducedType = canonicalizeStridedLayout(expectedReducedType);
-
-  if (expectedReducedType != canonicalizeStridedLayout(candidateReduced))
+  llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
+  auto inferredType =
+      getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
+  AffineMap candidateLayout;
+  if (candidateReduced.getAffineMaps().empty())
+    candidateLayout = getStridedLinearLayoutMap(candidateReduced);
+  else
+    candidateLayout = candidateReduced.getAffineMaps().front();
+  if (inferredType != candidateLayout) {
+    if (errMsg) {
+      llvm::raw_string_ostream os(*errMsg);
+      os << "inferred type: " << inferredType;
+    }
     return SubViewVerificationResult::AffineMapMismatch;
-
+  }
   return SubViewVerificationResult::Success;
 }
 
 template <typename OpTy>
 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
-                                            OpTy op, Type expectedType) {
+                                            OpTy op, Type expectedType,
+                                            StringRef errMsg = "") {
   auto memrefType = expectedType.cast<ShapedType>();
   switch (result) {
   case SubViewVerificationResult::Success:
     return success();
   case SubViewVerificationResult::RankTooLarge:
     return op.emitError("expected result rank to be smaller or equal to ")
-           << "the source rank.";
+           << "the source rank. " << errMsg;
   case SubViewVerificationResult::SizeMismatch:
     return op.emitError("expected result type to be ")
            << expectedType
-           << " or a rank-reduced version. (mismatch of result sizes)";
-  case SubViewVerificationResult::StrideMismatch:
-    return op.emitError("expected result type to be ")
-           << expectedType
-           << " or a rank-reduced version. (mismatch of result strides)";
+           << " or a rank-reduced version. (mismatch of result sizes) "
+           << errMsg;
   case SubViewVerificationResult::ElemTypeMismatch:
     return op.emitError("expected result element type to be ")
-           << memrefType.getElementType();
+           << memrefType.getElementType() << errMsg;
   case SubViewVerificationResult::MemSpaceMismatch:
-    return op.emitError("expected result and source memory spaces to match.");
+    return op.emitError("expected result and source memory spaces to match.")
+           << errMsg;
   case SubViewVerificationResult::AffineMapMismatch:
     return op.emitError("expected result type to be ")
            << expectedType
-           << " or a rank-reduced version. (mismatch of result affine map)";
+           << " or a rank-reduced version. (mismatch of result affine map) "
+           << errMsg;
   }
   llvm_unreachable("unexpected subview verification result");
 }
@@ -3126,8 +3120,9 @@ static LogicalResult verify(SubViewOp op) {
       extractFromI64ArrayAttr(op.static_sizes()),
       extractFromI64ArrayAttr(op.static_strides()));
 
-  auto result = isRankReducedType(expectedType, subViewType);
-  return produceSubViewErrorMsg(result, op, expectedType);
+  std::string errMsg;
+  auto result = isRankReducedType(expectedType, subViewType, &errMsg);
+  return produceSubViewErrorMsg(result, op, expectedType, errMsg);
 }
 
 raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {

diff  --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index da1ba6d75398..c31d96e1abdc 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -92,6 +92,15 @@ AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
   llvm_unreachable("Unknown AffineExpr");
 }
 
+AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const {
+  return replaceDimsAndSymbols(dimReplacements, {});
+}
+
+AffineExpr
+AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const {
+  return replaceDimsAndSymbols({}, symReplacements);
+}
+
 /// Replace symbols[0 .. numDims - 1] by symbols[shift .. shift + numDims - 1].
 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift) const {
   SmallVector<AffineExpr, 4> dims;

diff  --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 37fb3b18da74..312e940d20b4 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -420,6 +420,71 @@ AffineMap AffineMap::getMinorSubMap(unsigned numResults) const {
       llvm::seq<unsigned>(getNumResults() - numResults, getNumResults())));
 }
 
+AffineMap mlir::compressDims(AffineMap map,
+                             const llvm::SmallDenseSet<unsigned> &unusedDims) {
+  unsigned numDims = 0;
+  SmallVector<AffineExpr> dimReplacements;
+  dimReplacements.reserve(map.getNumDims());
+  MLIRContext *context = map.getContext();
+  for (unsigned dim = 0, e = map.getNumDims(); dim < e; ++dim) {
+    if (unusedDims.contains(dim))
+      dimReplacements.push_back(getAffineConstantExpr(0, context));
+    else
+      dimReplacements.push_back(getAffineDimExpr(numDims++, context));
+  }
+  SmallVector<AffineExpr> resultExprs;
+  resultExprs.reserve(map.getNumResults());
+  for (auto e : map.getResults())
+    resultExprs.push_back(e.replaceDims(dimReplacements));
+  return AffineMap::get(numDims, map.getNumSymbols(), resultExprs, context);
+}
+
+AffineMap mlir::compressUnusedDims(AffineMap map) {
+  llvm::SmallDenseSet<unsigned> usedDims;
+  map.walkExprs([&](AffineExpr expr) {
+    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
+      usedDims.insert(dimExpr.getPosition());
+  });
+  llvm::SmallDenseSet<unsigned> unusedDims;
+  for (unsigned d = 0, e = map.getNumDims(); d != e; ++d)
+    if (!usedDims.contains(d))
+      unusedDims.insert(d);
+  return compressDims(map, unusedDims);
+}
+
+AffineMap
+mlir::compressSymbols(AffineMap map,
+                      const llvm::SmallDenseSet<unsigned> &unusedSymbols) {
+  unsigned numSymbols = 0;
+  SmallVector<AffineExpr> symReplacements;
+  symReplacements.reserve(map.getNumSymbols());
+  MLIRContext *context = map.getContext();
+  for (unsigned sym = 0, e = map.getNumSymbols(); sym < e; ++sym) {
+    if (unusedSymbols.contains(sym))
+      symReplacements.push_back(getAffineConstantExpr(0, context));
+    else
+      symReplacements.push_back(getAffineSymbolExpr(numSymbols++, context));
+  }
+  SmallVector<AffineExpr> resultExprs;
+  resultExprs.reserve(map.getNumResults());
+  for (auto e : map.getResults())
+    resultExprs.push_back(e.replaceSymbols(symReplacements));
+  return AffineMap::get(map.getNumDims(), numSymbols, resultExprs, context);
+}
+
+AffineMap mlir::compressUnusedSymbols(AffineMap map) {
+  llvm::SmallDenseSet<unsigned> usedSymbols;
+  map.walkExprs([&](AffineExpr expr) {
+    if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
+      usedSymbols.insert(symExpr.getPosition());
+  });
+  llvm::SmallDenseSet<unsigned> unusedSymbols;
+  for (unsigned d = 0, e = map.getNumSymbols(); d != e; ++d)
+    if (!usedSymbols.contains(d))
+      unusedSymbols.insert(d);
+  return compressSymbols(map, unusedSymbols);
+}
+
 AffineMap mlir::simplifyAffineMap(AffineMap map) {
   SmallVector<AffineExpr, 8> exprs;
   for (auto e : map.getResults()) {
@@ -480,20 +545,10 @@ AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
                         maps.front().getContext());
 }
 
-AffineMap mlir::getProjectedMap(AffineMap map,
-                                ArrayRef<unsigned> projectedDimensions) {
-  DenseSet<unsigned> projectedDims(projectedDimensions.begin(),
-                                   projectedDimensions.end());
-  MLIRContext *context = map.getContext();
-  SmallVector<AffineExpr, 4> resultExprs;
-  for (auto dim : enumerate(llvm::seq<unsigned>(0, map.getNumDims()))) {
-    if (!projectedDims.count(dim.value()))
-      resultExprs.push_back(getAffineDimExpr(dim.index(), context));
-    else
-      resultExprs.push_back(getAffineConstantExpr(0, context));
-  }
-  return map.compose(AffineMap::get(
-      map.getNumDims() - projectedDimensions.size(), 0, resultExprs, context));
+AffineMap
+mlir::getProjectedMap(AffineMap map,
+                      const llvm::SmallDenseSet<unsigned> &unusedDims) {
+  return compressUnusedSymbols(compressDims(map, unusedDims));
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index b61073f699ef..289434ad52cd 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -829,7 +829,17 @@ AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
 /// Return true if the layout for `t` is compatible with strided semantics.
 bool mlir::isStrided(MemRefType t) {
   int64_t offset;
-  SmallVector<int64_t, 4> stridesAndOffset;
-  auto res = getStridesAndOffset(t, stridesAndOffset, offset);
+  SmallVector<int64_t, 4> strides;
+  auto res = getStridesAndOffset(t, strides, offset);
   return succeeded(res);
 }
+
+/// Return the layout map in strided linear layout AffineMap form.
+/// Return null if the layout is not compatible with a strided layout.
+AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  if (failed(getStridesAndOffset(t, strides, offset)))
+    return AffineMap();
+  return makeStridedLinearLayoutMap(strides, offset, t.getContext());
+}

diff  --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index ef645030bacf..2f1b7f9760b8 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -812,6 +812,10 @@ func @memref_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
   // CHECK: subview %{{.*}}[%{{.*}}, 1] [1, 1] [1, 1] : memref<5x3xf32> to memref<f32, #[[$SUBVIEW_MAP12]]>
   %28 = subview %24[%arg0, 1] [1, 1] [1, 1] : memref<5x3xf32> to memref<f32, affine_map<()[s0] -> (s0)>>
 
+  // CHECK: subview %{{.*}}[0, %{{.*}}] [%{{.*}}, 1] [1, 1] : memref<?x?xf32> to memref<?xf32, #[[$SUBVIEW_MAP1]]>
+  %a30 = alloc(%arg0, %arg0) : memref<?x?xf32>
+  %30 = subview %a30[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
+
   return
 }
 

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 595a950781b1..b5ee96839fbd 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -970,7 +970,7 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
 
 func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = alloc() : memref<8x16x4xf32>
-  // expected-error@+1 {{expected result type to be 'memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>' or a rank-reduced version. (mismatch of result strides)}}
+  // expected-error@+1 {{expected result type to be 'memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>' or a rank-reduced version. (mismatch of result affine map)}}
   %1 = subview %0[%arg0, %arg1, %arg2][%arg0, %arg1, %arg2][%arg0, %arg1, %arg2]
     : memref<8x16x4xf32> to
       memref<?x?x?xf32, offset: ?, strides: [64, 4, 1]>
@@ -1022,13 +1022,22 @@ func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index)
 // -----
 
 func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
-  // expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result strides)}}
+  // expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result affine map)}}
   %0 = subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32>
   return
 }
 
 // -----
 
+// The affine map affine_map<(d0)[s0, s1, s2] -> (d0 * s1 + s0)> has an extra unused symbol.
+func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
+  // expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result affine map) inferred type: (d0)[s0, s1] -> (d0 * s1 + s0)}}
+  %0 = subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32, affine_map<(d0)[s0, s1, s2] -> (d0 * s1 + s0)>>
+  return
+}
+
+// -----
+
 func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) {
   // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}}
   %0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]>


        


More information about the Mlir-commits mailing list