[Mlir-commits] [mlir] f740bdb - [mlir][Linalg] Modify `InferStaticShapeOfOperands` to work on Linalg Ops.

Mahesh Ravishankar llvmlistbot at llvm.org
Tue Mar 8 10:55:01 PST 2022


Author: Mahesh Ravishankar
Date: 2022-03-08T18:54:45Z
New Revision: f740bdbd2d084bbef52dd08d445497d3ec2ac24e

URL: https://github.com/llvm/llvm-project/commit/f740bdbd2d084bbef52dd08d445497d3ec2ac24e
DIFF: https://github.com/llvm/llvm-project/commit/f740bdbd2d084bbef52dd08d445497d3ec2ac24e.diff

LOG: [mlir][Linalg] Modify `InferStaticShapeOfOperands` to work on Linalg Ops.

Commit rG1a2bb03edab9d7aa31beb587d0c863acc6715d27 introduced a pattern
to convert dynamic dimensions in operands of `GenericOp`s to static
values based on indexing maps and shapes of other operands. The logic
is directly applicable to any `LinalgOp`. Convert that pattern into an
`OpInterfaceRewritePattern` so that it applies to all Linalg ops.

Differential Revision: https://reviews.llvm.org/D120968

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 010695172518c..02ed7555a418c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -862,169 +862,11 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
     return success();
   }
 };
-
-/// For each of the operand in `operands` this function maps the static sizes of
-/// dimensions to their affine dim expressions.
-static void populateMap(GenericOp genericOp, ArrayRef<OpOperand *> operands,
-                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
-  for (OpOperand *opOperand : operands) {
-    if (genericOp.isScalar(opOperand))
-      continue;
-    Value src = opOperand->get();
-    auto sourceType = src.getType().cast<RankedTensorType>();
-    auto sourceMap = genericOp.getTiedIndexingMap(opOperand);
-
-    // Get the `sourceShape` of the `sourceType`. If the operand is a result of
-    // `tensor.cast` operation and source of the cast operation has a static
-    // shape, then assign it to the `sourceShape`.
-    auto *parentOp = src.getDefiningOp();
-    ArrayRef<int64_t> sourceShape = sourceType.getShape();
-    if (parentOp) {
-      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
-        Value castSource = castOp.source();
-        auto castSourceType = castSource.getType().cast<RankedTensorType>();
-        if (castSourceType.hasStaticShape())
-          sourceShape = castSourceType.getShape();
-      }
-    }
-
-    // If the source shape's dimension has a static shape, map the affine dim
-    // expression to the known static size.
-    for (unsigned i = 0; i < sourceShape.size(); i++) {
-      if (sourceType.isDynamicDim(i))
-        continue;
-      if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
-        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
-    }
-  }
-}
-
-/// Creates new operand w.r.t 'opOperand' of `genericOp` with static sizes
-/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
-/// their result types is stored in `resultTypes`. If `opOperand` requires no
-/// change then `changeNeeded` is false and same operand is added in the
-/// `newOperands` list.
-static void createNewOperandWithStaticSizes(
-    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
-    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, GenericOp genericOp,
-    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
-    bool &changeNeeded) {
-  Value src = opOperand->get();
-  newOperands.push_back(src);
-  if (genericOp.isScalar(opOperand))
-    return;
-  auto sourceType = src.getType().cast<RankedTensorType>();
-  Type resultType = sourceType;
-  if (sourceType.hasStaticShape() && genericOp.isOutputTensor(opOperand)) {
-    resultTypes.push_back(resultType);
-    return;
-  }
-  ArrayRef<int64_t> sourceShape = sourceType.getShape();
-  AffineMap sourceMap = genericOp.getTiedIndexingMap(opOperand);
-  SmallVector<int64_t> newShape;
-  // If operand is updated with new shape, `newOperandNeeded` will be
-  // true.
-  bool newOperandNeeded = false;
-  for (unsigned i = 0; i < sourceShape.size(); i++) {
-    int64_t dimShape = sourceShape[i];
-    AffineExpr dimExpr = sourceMap.getResult(i);
-    if (affineExprToSize.find(dimExpr) == affineExprToSize.end() ||
-        !sourceType.isDynamicDim(i)) {
-      newShape.push_back(dimShape);
-      continue;
-    }
-    // Dimension has a dynamic shape and corresponding affine dim
-    // expression is present in the map. So assign the size for the
-    // given affine dim expression to the dimension.
-    newShape.push_back(affineExprToSize[dimExpr]);
-    newOperandNeeded = true;
-  }
-  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
-  if (newOperandNeeded) {
-    changeNeeded = true;
-    // Get the new operand value given its size and element type by
-    // casting it.
-    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
-    unsigned index = opOperand->getOperandNumber();
-    newOperands[index] = newOperand;
-  }
-  if (genericOp.isOutputTensor(opOperand))
-    resultTypes.push_back(resultType);
-}
-
-/// Static shapes for the operands can be inferred if any one of the operands
-/// have a static shape. This can be done by referring to the affine dim
-/// expressions for the operand.
-struct InferStaticShapeOfOperands : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
-      return failure();
-
-    // Maps must be projected permutations.
-    if (llvm::any_of(genericOp.getIndexingMaps(), [](AffineMap map) {
-          return !map.isProjectedPermutation();
-        }))
-      return failure();
-
-    // Maps affine dim expressions to the static size of that dimension.
-    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
-    Location loc = genericOp.getLoc();
-
-    // For each of the affine dim expression, check if the size is known. If
-    // known add that in the map.
-    populateMap(genericOp, genericOp.getInputAndOutputOperands(),
-                affineExprToSize);
-
-    SmallVector<Value> newOperands;
-    SmallVector<Type> resultTypes;
-
-    // `changeNeeded` is `false` if the operands of `genericOp` require no
-    // change in their types.
-    bool changeNeeded = false;
-    newOperands.reserve(genericOp.getNumInputsAndOutputs());
-    resultTypes.reserve(genericOp.getNumOutputs());
-
-    // Iterate over all the operands and update the static sizes.
-    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
-      createNewOperandWithStaticSizes(loc, rewriter, opOperand,
-                                      affineExprToSize, genericOp, newOperands,
-                                      resultTypes, changeNeeded);
-    }
-
-    // If the generic op has all the required static information, no
-    // canonicalization needed.
-    if (!changeNeeded)
-      return failure();
-
-    // Clone op.
-    Operation *newOp =
-        cast<linalg::LinalgOp>(genericOp.getOperation())
-            .clone(rewriter, genericOp->getLoc(), resultTypes, newOperands);
-    SmallVector<Value> replacements;
-    replacements.reserve(newOp->getNumResults());
-    for (auto it : llvm::zip(genericOp->getResults(), newOp->getResults())) {
-      Value newResult = std::get<1>(it);
-      Value oldResult = std::get<0>(it);
-      Type newType = newResult.getType();
-      Type oldType = oldResult.getType();
-      replacements.push_back(
-          (newType != oldType)
-              ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
-              : newResult);
-    }
-    rewriter.replaceOp(genericOp, replacements);
-    return success();
-  }
-};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp,
-              InferStaticShapeOfOperands>(context);
+  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1811,6 +1653,162 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
   }
 };
 
+/// For each of the operand in `operands` this function maps the static sizes of
+/// dimensions to their affine dim expressions.
+static void populateMap(LinalgOp linalgOp, ArrayRef<OpOperand *> operands,
+                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
+  for (OpOperand *opOperand : operands) {
+    if (linalgOp.isScalar(opOperand))
+      continue;
+    Value src = opOperand->get();
+    auto sourceType = src.getType().cast<RankedTensorType>();
+    auto sourceMap = linalgOp.getTiedIndexingMap(opOperand);
+
+    // Get the `sourceShape` of the `sourceType`. If the operand is a result of
+    // `tensor.cast` operation and source of the cast operation has a static
+    // shape, then assign it to the `sourceShape`.
+    auto parentOp = src.getDefiningOp();
+    ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    if (parentOp) {
+      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
+        Value castSource = castOp.source();
+        auto castSourceType = castSource.getType().cast<RankedTensorType>();
+        if (castSourceType.hasStaticShape())
+          sourceShape = castSourceType.getShape();
+      }
+    }
+
+    // If the source shape's dimension has a static shape, map the affine dim
+    // expression to the known static size.
+    for (unsigned i = 0; i < sourceShape.size(); i++) {
+      if (sourceType.isDynamicDim(i))
+        continue;
+      if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
+        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
+    }
+  }
+}
+
+/// Creates new operand w.r.t 'opOperand' of `linalgOp` with static sizes
+/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
+/// their result types is stored in `resultTypes`. If `opOperand` requires no
+/// change then `changeNeeded` is false and same operand is added in the
+/// `newOperands` list.
+static void createNewOperandWithStaticSizes(
+    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
+    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
+    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
+    bool &changeNeeded) {
+  Value src = opOperand->get();
+  newOperands.push_back(src);
+  if (linalgOp.isScalar(opOperand))
+    return;
+  auto sourceType = src.getType().cast<RankedTensorType>();
+  Type resultType = sourceType;
+  if (sourceType.hasStaticShape() && linalgOp.isOutputTensor(opOperand)) {
+    resultTypes.push_back(resultType);
+    return;
+  }
+  ArrayRef<int64_t> sourceShape = sourceType.getShape();
+  AffineMap sourceMap = linalgOp.getTiedIndexingMap(opOperand);
+  SmallVector<int64_t> newShape;
+  // If operand is updated with new shape, `newOperandNeeded` will be
+  // true.
+  bool newOperandNeeded = false;
+  for (unsigned i = 0; i < sourceShape.size(); i++) {
+    int64_t dimShape = sourceShape[i];
+    AffineExpr dimExpr = sourceMap.getResult(i);
+    if (affineExprToSize.find(dimExpr) == affineExprToSize.end() ||
+        !sourceType.isDynamicDim(i)) {
+      newShape.push_back(dimShape);
+      continue;
+    }
+    // Dimension has a dynamic shape and corresponding affine dim
+    // expression is present in the map. So assign the size for the
+    // given affine dim expression to the dimension.
+    newShape.push_back(affineExprToSize[dimExpr]);
+    newOperandNeeded = true;
+  }
+  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
+  if (newOperandNeeded) {
+    changeNeeded = true;
+    // Get the new operand value given its size and element type by
+    // casting it.
+    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
+    unsigned index = opOperand->getOperandNumber();
+    newOperands[index] = newOperand;
+  }
+  if (linalgOp.isOutputTensor(opOperand))
+    resultTypes.push_back(resultType);
+}
+
+/// Static shapes for the operands can be inferred if any one of the operands
+/// have a static shape. This can be done by referring to the affine dim
+/// expressions for the operand.
+struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
+  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(LinalgOp linalgOp,
+                                PatternRewriter &rewriter) const override {
+    if (!linalgOp.hasTensorSemantics())
+      return failure();
+
+    // Maps must be projected permutations.
+    if (llvm::any_of(linalgOp.getIndexingMaps(), [](AffineMap map) {
+          return !map.isProjectedPermutation();
+        }))
+      return failure();
+
+    // Maps affine dim expressions to the static size of that dimension.
+    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
+    Location loc = linalgOp.getLoc();
+
+    // For each of the affine dim expression, check if the size is known. If
+    // known add that in the map.
+    populateMap(linalgOp, linalgOp.getInputAndOutputOperands(),
+                affineExprToSize);
+
+    SmallVector<Value> newOperands;
+    SmallVector<Type> resultTypes;
+
+    // `changeNeeded` is `false` if the operands of `linalgOp` require no
+    // change in their types.
+    bool changeNeeded = false;
+    newOperands.reserve(linalgOp.getNumInputsAndOutputs());
+    resultTypes.reserve(linalgOp.getNumOutputs());
+
+    // Iterate over all the operands and update the static sizes.
+    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
+      createNewOperandWithStaticSizes(loc, rewriter, opOperand,
+                                      affineExprToSize, linalgOp, newOperands,
+                                      resultTypes, changeNeeded);
+    }
+
+    // If the generic op has all the required static information, no
+    // canonicalization needed.
+    if (!changeNeeded)
+      return failure();
+
+    // Clone op.
+    Operation *newOp =
+        linalgOp.clone(rewriter, linalgOp->getLoc(), resultTypes, newOperands);
+    SmallVector<Value> replacements;
+    replacements.reserve(newOp->getNumResults());
+    for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
+      Value newResult = std::get<1>(it);
+      Value oldResult = std::get<0>(it);
+      Type newType = newResult.getType();
+      Type oldType = oldResult.getType();
+      replacements.push_back(
+          (newType != oldType)
+              ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
+              : newResult);
+    }
+    rewriter.replaceOp(linalgOp, replacements);
+    return success();
+  }
+};
+
 } // namespace
 
 #define LINALGOP_FOLDERS(XXX)                                                  \
@@ -1832,7 +1830,8 @@ LINALGOP_FOLDERS(GenericOp)
 void LinalgDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
   results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
-              FoldTensorCastProducerOp>(getContext());
+              FoldTensorCastProducerOp, InferStaticShapeOfOperands>(
+      getContext());
 }
 
 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index b24a3e78e32b1..0dd435c01d818 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -772,9 +772,32 @@ func @fold_linalgop_with_cast_consumer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 //  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
-//       CHECK:  %[[OUT_CAST:.+]] = tensor.cast %[[ARG2]] : tensor<?x?xf32> to tensor<4x8xf32>
+//   CHECK-DAG:  %[[LHS_CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?xf32> to tensor<4x?xf32>
+//   CHECK-DAG:  %[[RHS_CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?x?xf32> to tensor<?x8xf32>
+//   CHECK-DAG:  %[[OUT_CAST:.+]] = tensor.cast %[[ARG2]] : tensor<?x?xf32> to tensor<4x8xf32>
 //       CHECK:  %[[MATMUL:.+]] = linalg.matmul
-//  CHECK-SAME:      ins(%[[ARG0]], %[[ARG1]] :
+//  CHECK-SAME:      ins(%[[LHS_CAST]], %[[RHS_CAST]] :
 //  CHECK-SAME:      outs(%[[OUT_CAST]] :
 //       CHECK:  %[[RESULT_CAST:.+]] = tensor.cast %[[MATMUL]]
 //       CHECK:  return %[[MATMUL]], %[[RESULT_CAST]]
+
+// -----
+
+func @fold_conv_op_with_cast_consumer(%arg0 : tensor<?x?x?x?xf32>,
+    %arg1 : tensor<?x?x?x?xf32>,  %arg2 : tensor<?x?x?x?xf32>) ->
+    (tensor<4x8x12x16xf32>, tensor<?x?x?x?xf32>) {
+  %0 = linalg.conv_2d_nchw_fchw ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+      outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  %1 = tensor.cast %0 : tensor<?x?x?x?xf32> to tensor<4x8x12x16xf32>
+  return %1, %0 : tensor<4x8x12x16xf32>, tensor<?x?x?x?xf32>
+}
+//       CHECK: func @fold_conv_op_with_cast_consumer(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>)
+//       CHECK:  %[[OUT_CAST:.+]] = tensor.cast %[[ARG2]] : tensor<?x?x?x?xf32> to tensor<4x8x12x16xf32>
+//       CHECK:  %[[CONV:.+]] = linalg.conv_2d_nchw_fchw
+//  CHECK-SAME:      ins(%[[ARG0]], %[[ARG1]] :
+//  CHECK-SAME:      outs(%[[OUT_CAST]] :
+//       CHECK:  %[[RESULT_CAST:.+]] = tensor.cast %[[CONV]]
+//       CHECK:  return %[[CONV]], %[[RESULT_CAST]]

diff  --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 716e38a32e03c..23a991bf6b7a4 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -47,6 +47,7 @@ func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tens
 //       CHECK: scf.for %[[I:[0-9a-z]*]]
 //       CHECK:   %[[sizeA0:.*]] = affine.min #[[BOUND2_MAP]](%[[I]])[%[[dA0]]]
 //       CHECK:   %[[stA:.*]] = tensor.extract_slice %[[A]][%[[I]], 0] [%[[sizeA0]], %[[dA1]]] [1, 1]  : tensor<?x?xf32> to tensor<?x?xf32>
+//       CHECK:   %[[castA:.*]] = tensor.cast %[[stA]] : tensor<?x?xf32> to tensor<2x?xf32>
 //       CHECK:   scf.for %[[J:[0-9a-z]*]]
 //  CHECK-NEXT:     scf.for %[[K:[0-9a-z]*]] {{.*}} iter_args(%[[RES:[0-9a-z]*]]
 //   CHECK-DAG:       %[[stB1:.*]] = tensor.extract_slice %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1]  : tensor<?x?xf32> to tensor<4x3xf32>
@@ -57,7 +58,8 @@ func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tens
 //       CHECK:       %[[stB2:.*]] = tensor.extract_slice %[[B]][0, %[[K]]] [%[[dB0]], %[[sizeB1]]] [1, 1]  : tensor<?x?xf32> to tensor<?x?xf32>
 //       CHECK:       %[[stC:.*]] = tensor.extract_slice %[[C]][%[[I]], %[[K]]] [%[[sizeA0]], %[[sizeB1]]] [1, 1]  : tensor<?x?xf32> to tensor<?x?xf32>
 //   CHECK-DAG:       %[[castC:.+]] = tensor.cast %[[stC]] : tensor<?x?xf32> to tensor<2x4xf32>
-//       CHECK:       %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[castC]] : tensor<2x4xf32>)  -> tensor<2x4xf32>
+//   CHECK-DAG:       %[[castB:.+]] = tensor.cast %[[stB2]] : tensor<?x?xf32> to tensor<?x4xf32>
+//       CHECK:       %[[stD:.*]] = linalg.matmul ins(%[[castA]], %[[castB]] : tensor<2x?xf32>, tensor<?x4xf32>) outs(%[[castC]] : tensor<2x4xf32>)  -> tensor<2x4xf32>
 //  CHECK-NEXT:       %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>)  -> tensor<2x3xf32>
 //  CHECK-NEXT:       tensor.insert_slice %[[stG]] into %[[RES]][%[[I]], %[[J]]]
 


        


More information about the Mlir-commits mailing list