[Mlir-commits] [mlir] 9c27fa3 - [mlir][linalg] Prepare fusion on tensors for scalar operands.

Tobias Gysi llvmlistbot at llvm.org
Wed Jun 9 00:11:13 PDT 2021


Author: Tobias Gysi
Date: 2021-06-09T07:09:46Z
New Revision: 9c27fa3821dc5c04f5710e64411815893de160ce

URL: https://github.com/llvm/llvm-project/commit/9c27fa3821dc5c04f5710e64411815893de160ce
DIFF: https://github.com/llvm/llvm-project/commit/9c27fa3821dc5c04f5710e64411815893de160ce.diff

LOG: [mlir][linalg] Prepare fusion on tensors for scalar operands.

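Adapt fusion on tensors to support structured ops taking scalar operands. During fusion by reshape expansion, only tensor input operands are reshaped; all other input operands are forwarded unchanged. The producer-reshape folding pattern now considers only tensor input operands.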

Differential Revision: https://reviews.llvm.org/D103889

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index d4dbb5aeb7c27..f65a0fa1772d4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -701,24 +701,27 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
       }));
 
   SmallVector<Value> expandedOpOperands;
+  expandedOpOperands.reserve(genericOp.getNumInputs());
   for (OpOperand *opOperand : genericOp.getInputOperands()) {
     if (opOperand == fusableOpOperand) {
       expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
                                                : collapsingReshapeOp.src());
       continue;
     }
-    AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
-    RankedTensorType expandedOperandType =
-        getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
-                        indexingMap, expansionInfo);
-    if (expandedOperandType != opOperand->get().getType()) {
-      // Reshape the operand to get the right type.
-      SmallVector<ReassociationIndices> reassociation =
-          getReassociationForExpansion(indexingMap, expansionInfo);
-      expandedOpOperands.push_back(rewriter.create<TensorExpandShapeOp>(
-          genericOp.getLoc(), expandedOperandType, opOperand->get(),
-          reassociation));
-      continue;
+    if (genericOp.isInputTensor(opOperand)) {
+      AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+      RankedTensorType expandedOperandType =
+          getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
+                          indexingMap, expansionInfo);
+      if (expandedOperandType != opOperand->get().getType()) {
+        // Reshape the operand to get the right type.
+        SmallVector<ReassociationIndices> reassociation =
+            getReassociationForExpansion(indexingMap, expansionInfo);
+        expandedOpOperands.push_back(rewriter.create<TensorExpandShapeOp>(
+            genericOp.getLoc(), expandedOperandType, opOperand->get(),
+            reassociation));
+        continue;
+      }
     }
     expandedOpOperands.push_back(opOperand->get());
   }
@@ -1035,7 +1038,7 @@ class FoldWithProducerReshapeOpByExpansion
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    for (OpOperand *opOperand : genericOp.getInputOperands()) {
+    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
       TensorCollapseShapeOp reshapeOp =
           opOperand->get().getDefiningOp<TensorCollapseShapeOp>();
       if (!reshapeOp)


        


More information about the Mlir-commits mailing list