[Mlir-commits] [mlir] c698505 - [mlir][linalg] Cleanup LinalgOp usage in drop unit dims.

Tobias Gysi llvmlistbot at llvm.org
Thu Jun 3 05:52:35 PDT 2021


Author: Tobias Gysi
Date: 2021-06-03T12:27:05Z
New Revision: c698505257598d04f8e92a7ee79bfdf7c2cc6020

URL: https://github.com/llvm/llvm-project/commit/c698505257598d04f8e92a7ee79bfdf7c2cc6020
DIFF: https://github.com/llvm/llvm-project/commit/c698505257598d04f8e92a7ee79bfdf7c2cc6020.diff

LOG: [mlir][linalg] Cleanup LinalgOp usage in drop unit dims.

Replace uses of the deprecated Structured Op Interface methods in DropUnitDims.cpp with their current equivalents. This patch builds on https://reviews.llvm.org/D103394.

Differential Revision: https://reviews.llvm.org/D103448

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 5e8820535a41..fb5990786802 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -183,9 +183,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
     AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
     if (!invertedMap)
       return failure();
-    SmallVector<int64_t, 4> dims;
-    for (ShapedType shapedType : genericOp.getShapedOperandTypes())
-      dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
+    SmallVector<int64_t> dims = genericOp.getStaticShape();
 
     // Find all the reduction iterators. Those need some special consideration
     // (see below).
@@ -267,17 +265,18 @@ struct UnitExtentReplacementInfo {
 /// - modified index map that can be used to access the replaced result/operand
 /// - the reassociation that converts from the original tensor type to the
 ///   modified tensor type.
-static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
-                                                    RankedTensorType type,
+static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
+                                                    OpOperand *opOperand,
                                                     MLIRContext *context) {
-  ArrayRef<int64_t> shape = type.getShape();
-  ArrayRef<AffineExpr> exprs = indexMap.getResults();
+  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+  ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
+  ArrayRef<AffineExpr> exprs = indexingMap.getResults();
   SmallVector<AffineExpr, 2> reassociations;
   SmallVector<Attribute, 4> reassociationMaps;
   SmallVector<AffineExpr, 4> newIndexExprs;
   SmallVector<int64_t, 4> newShape;
 
-  int64_t origRank = type.getRank();
+  int64_t origRank = genericOp.getRank(opOperand);
   AffineExpr zeroExpr = getAffineConstantExpr(0, context);
   auto isUnitExtent = [&](int64_t dim) -> bool {
     return shape[dim] == 1 && exprs[dim] == zeroExpr;
@@ -302,8 +301,9 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
     ++dim;
   }
   UnitExtentReplacementInfo info = {
-      RankedTensorType::get(newShape, type.getElementType()),
-      AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
+      RankedTensorType::get(newShape,
+                            getElementTypeOrSelf(opOperand->get().getType())),
+      AffineMap::get(indexingMap.getNumDims(), indexingMap.getNumSymbols(),
                      newIndexExprs, context),
       ArrayAttr::get(context, reassociationMaps)};
   return info;
@@ -335,15 +335,13 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
     SmallVector<ArrayAttr, 4> reassociationMaps;
     SmallVector<ShapedType, 4> newInputOutputTypes;
     bool doCanonicalization = false;
-    for (auto it : llvm::zip(genericOp.getIndexingMaps(),
-                             genericOp.getShapedOperandTypes())) {
-      auto replacementInfo = replaceUnitExtents(
-          std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
-          context);
+
+    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
+      auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
       reassociationMaps.push_back(replacementInfo.reassociation);
       newIndexingMaps.push_back(replacementInfo.indexMap);
       newInputOutputTypes.push_back(replacementInfo.type);
-      doCanonicalization |= replacementInfo.type != std::get<1>(it);
+      doCanonicalization |= replacementInfo.type != opOperand->get().getType();
     }
 
     // If the indexing maps of the result operation are not invertible (i.e. not


        


More information about the Mlir-commits mailing list