[Mlir-commits] [mlir] f6b4e08 - [mlir][linalg] Prepare drop unit dims for scalar operands.

Tobias Gysi llvmlistbot at llvm.org
Fri Jun 11 06:19:42 PDT 2021


Author: Tobias Gysi
Date: 2021-06-11T13:18:06Z
New Revision: f6b4e081dc9cf74fb5c22439f552fa035f2c2651

URL: https://github.com/llvm/llvm-project/commit/f6b4e081dc9cf74fb5c22439f552fa035f2c2651
DIFF: https://github.com/llvm/llvm-project/commit/f6b4e081dc9cf74fb5c22439f552fa035f2c2651.diff

LOG: [mlir][linalg] Prepare drop unit dims for scalar operands.

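Adapt the drop-unit-dims transformation to structured ops that take scalar operands. The replacement info now stores a generic `Type` instead of a `RankedTensorType`, so a scalar operand keeps its scalar type rather than being rewrapped as a tensor.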

Differential Revision: https://reviews.llvm.org/D103890

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index fb5990786802a..102dbdb4e2c36 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -249,7 +249,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
 };
 
 struct UnitExtentReplacementInfo {
-  RankedTensorType type;
+  Type type;
   AffineMap indexMap;
   ArrayAttr reassociation;
 };
@@ -271,10 +271,10 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
   AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
   ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
   ArrayRef<AffineExpr> exprs = indexingMap.getResults();
-  SmallVector<AffineExpr, 2> reassociations;
-  SmallVector<Attribute, 4> reassociationMaps;
-  SmallVector<AffineExpr, 4> newIndexExprs;
-  SmallVector<int64_t, 4> newShape;
+  SmallVector<AffineExpr> reassociations;
+  SmallVector<Attribute> reassociationMaps;
+  SmallVector<AffineExpr> newIndexExprs;
+  SmallVector<int64_t> newShape;
 
   int64_t origRank = genericOp.getRank(opOperand);
   AffineExpr zeroExpr = getAffineConstantExpr(0, context);
@@ -282,7 +282,7 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
     return shape[dim] == 1 && exprs[dim] == zeroExpr;
   };
 
-  unsigned dim = 0;
+  int64_t dim = 0;
   // Fold dimensions that are unit-extent at the beginning of the tensor.
   while (dim < origRank && isUnitExtent(dim))
     reassociations.push_back(getAffineDimExpr(dim++, context));
@@ -300,12 +300,16 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
     reassociations.clear();
     ++dim;
   }
-  UnitExtentReplacementInfo info = {
-      RankedTensorType::get(newShape,
-                            getElementTypeOrSelf(opOperand->get().getType())),
-      AffineMap::get(indexingMap.getNumDims(), indexingMap.getNumSymbols(),
-                     newIndexExprs, context),
-      ArrayAttr::get(context, reassociationMaps)};
+  // Compute the tensor or scalar replacement type.
+  Type elementType = getElementTypeOrSelf(opOperand->get().getType());
+  Type replacementType = elementType == opOperand->get().getType()
+                             ? elementType
+                             : RankedTensorType::get(newShape, elementType);
+  UnitExtentReplacementInfo info = {replacementType,
+                                    AffineMap::get(indexingMap.getNumDims(),
+                                                   indexingMap.getNumSymbols(),
+                                                   newIndexExprs, context),
+                                    ArrayAttr::get(context, reassociationMaps)};
   return info;
 }
 
@@ -331,13 +335,14 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
     MLIRContext *context = rewriter.getContext();
     Location loc = genericOp.getLoc();
 
-    SmallVector<AffineMap, 4> newIndexingMaps;
-    SmallVector<ArrayAttr, 4> reassociationMaps;
-    SmallVector<ShapedType, 4> newInputOutputTypes;
+    SmallVector<AffineMap> newIndexingMaps;
+    SmallVector<ArrayAttr> reassociationMaps;
+    SmallVector<Type> newInputOutputTypes;
     bool doCanonicalization = false;
 
     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
-      auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
+      UnitExtentReplacementInfo replacementInfo =
+          replaceUnitExtents(genericOp, opOperand, context);
       reassociationMaps.push_back(replacementInfo.reassociation);
       newIndexingMaps.push_back(replacementInfo.indexMap);
       newInputOutputTypes.push_back(replacementInfo.type);


        


More information about the Mlir-commits mailing list