[Mlir-commits] [mlir] 1a2bb03 - [MLIR][LINALG] Add canonicalization pattern in `linalg.generic` op for static shape inference.

Prateek Gupta llvmlistbot at llvm.org
Sun Feb 20 23:51:40 PST 2022


Author: Prateek Gupta
Date: 2022-02-21T07:51:13Z
New Revision: 1a2bb03edab9d7aa31beb587d0c863acc6715d27

URL: https://github.com/llvm/llvm-project/commit/1a2bb03edab9d7aa31beb587d0c863acc6715d27
DIFF: https://github.com/llvm/llvm-project/commit/1a2bb03edab9d7aa31beb587d0c863acc6715d27.diff

LOG: [MLIR][LINALG] Add canonicalization pattern in `linalg.generic` op for static shape inference.

This commit adds a canonicalization pattern to the `linalg.generic` op
for static shape inference. If any of the inputs or outputs has a
static shape or is cast from a tensor of static shape, then the
shapes of all the inputs and outputs can be inferred by using the
affine map of the static shape input/output.

Signed-Off-By: Prateek Gupta <prateek at nod-labs.com>

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D118929

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/reshape_fusion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6b0d22c8f939e..319a1c318ff8a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -841,11 +841,169 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
     return success();
   }
 };
+
+/// Records, for every non-scalar operand in `operands`, the statically known
+/// size of each dimension, keyed by the affine dim expression indexing it.
+static void populateMap(GenericOp genericOp, ArrayRef<OpOperand *> operands,
+                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
+  for (OpOperand *opOperand : operands) {
+    if (genericOp.isScalar(opOperand))
+      continue;
+    Value src = opOperand->get();
+    auto sourceType = src.getType().cast<RankedTensorType>();
+    AffineMap sourceMap = genericOp.getTiedIndexingMap(opOperand);
+
+    // Start from the operand's own shape. If the operand is produced by a
+    // `tensor.cast` whose source is a ranked tensor with a fully static shape,
+    // use that shape instead: it holds for the operand too and may be more
+    // precise. The source of a `tensor.cast` may be unranked, so it must be
+    // queried with `dyn_cast` rather than an asserting `cast`.
+    ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    if (auto castOp = src.getDefiningOp<tensor::CastOp>()) {
+      auto castSourceType =
+          castOp.source().getType().dyn_cast<RankedTensorType>();
+      if (castSourceType && castSourceType.hasStaticShape())
+        sourceShape = castSourceType.getShape();
+    }
+
+    // Map each affine dim expression to the static size found in
+    // `sourceShape`. The check is made on `sourceShape`, not `sourceType`, so
+    // that sizes only known through the cast source are recorded as well.
+    for (unsigned i = 0, e = sourceShape.size(); i < e; ++i) {
+      if (ShapedType::isDynamic(sourceShape[i]))
+        continue;
+      if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
+        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
+    }
+  }
+}
+
+/// Appends to `newOperands` the value `genericOp` should use in place of
+/// `opOperand`. Dynamic dimensions whose affine dim expression has an entry in
+/// `affineExprToSize` are made static by inserting a `tensor.cast`; otherwise
+/// the original value is reused. For output tensors the (possibly refined)
+/// type is also appended to `resultTypes`. `changeNeeded` is set to true when
+/// a cast is created and is left untouched otherwise.
+static void createNewOperandWithStaticSizes(
+    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
+    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, GenericOp genericOp,
+    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
+    bool &changeNeeded) {
+  Value src = opOperand->get();
+  // Default to the unchanged operand; it is overwritten below if refined.
+  newOperands.push_back(src);
+  if (genericOp.isScalar(opOperand))
+    return;
+
+  auto sourceType = src.getType().cast<RankedTensorType>();
+  const bool isOutput = genericOp.isOutputTensor(opOperand);
+  // A fully static output tensor has nothing left to infer.
+  if (isOutput && sourceType.hasStaticShape()) {
+    resultTypes.push_back(sourceType);
+    return;
+  }
+
+  // Rebuild the shape, substituting every dynamic dimension whose size is
+  // known through its indexing expression.
+  AffineMap sourceMap = genericOp.getTiedIndexingMap(opOperand);
+  SmallVector<int64_t> shape;
+  bool refined = false;
+  for (auto en : llvm::enumerate(sourceType.getShape())) {
+    int64_t dimSize = en.value();
+    if (sourceType.isDynamicDim(en.index())) {
+      auto it = affineExprToSize.find(sourceMap.getResult(en.index()));
+      if (it != affineExprToSize.end()) {
+        dimSize = it->second;
+        refined = true;
+      }
+    }
+    shape.push_back(dimSize);
+  }
+  auto resultType = RankedTensorType::get(shape, sourceType.getElementType());
+
+  if (refined) {
+    changeNeeded = true;
+    // Cast the operand to the refined type and use it in place of `src`.
+    Value casted = rewriter.create<tensor::CastOp>(loc, resultType, src);
+    newOperands[opOperand->getOperandNumber()] = casted;
+  }
+  if (isOutput)
+    resultTypes.push_back(resultType);
+}
+
+/// Canonicalization that infers static shapes for the operands of a
+/// `linalg.generic` op: a dynamic dimension is made static when another
+/// operand indexed by the same affine dim expression has a known static size.
+struct InferStaticShapeOfOperands : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+
+    // Maps must be projected permutations.
+    if (llvm::any_of(genericOp.getIndexingMaps(), [](AffineMap map) {
+          return !map.isProjectedPermutation();
+        }))
+      return failure();
+
+    // Maps affine dim expressions to the static size of that dimension.
+    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
+    Location loc = genericOp.getLoc();
+
+    // For each of the affine dim expression, check if the size is known. If
+    // known add that in the map.
+    populateMap(genericOp, genericOp.getInputAndOutputOperands(),
+                affineExprToSize);
+
+    SmallVector<Value> newOperands;
+    SmallVector<Type> resultTypes;
+
+    // `changeNeeded` stays `false` if none of the operands of `genericOp`
+    // needs its type refined.
+    bool changeNeeded = false;
+    newOperands.reserve(genericOp.getNumInputsAndOutputs());
+    resultTypes.reserve(genericOp.getNumOutputs());
+
+    // Iterate over all the operands and update the static sizes.
+    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
+      createNewOperandWithStaticSizes(loc, rewriter, opOperand,
+                                      affineExprToSize, genericOp, newOperands,
+                                      resultTypes, changeNeeded);
+    }
+
+    // If the generic op has all the required static information, no
+    // canonicalization needed.
+    if (!changeNeeded)
+      return failure();
+
+    // Clone the op with the refined operands and result types.
+    Operation *newOp =
+        cast<linalg::LinalgOp>(genericOp.getOperation())
+            .clone(rewriter, genericOp->getLoc(), resultTypes, newOperands);
+    SmallVector<Value> replacements;
+    replacements.reserve(newOp->getNumResults());
+    for (auto it : llvm::zip(genericOp->getResults(), newOp->getResults())) {
+      Value newResult = std::get<1>(it);
+      Type newType = newResult.getType();
+      Type oldType = std::get<0>(it).getType();
+      // Users still expect the old type: cast refined results back to it.
+      replacements.push_back(
+          (newType != oldType)
+              ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
+              : newResult);
+    }
+    rewriter.replaceOp(genericOp, replacements);
+    return success();
+  }
+};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
+  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp,
+              InferStaticShapeOfOperands>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index cbc8e4a50de5f..8a3f201f7cc26 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -650,3 +650,133 @@ func @no_fold_pad_fill_value_mismatch() -> tensor<412x276xf32> {
   } : tensor<400x273xf32> to tensor<412x276xf32>
   return %pad : tensor<412x276xf32>
 }
+
+// -----
+
+// Tests below verify whether static information is propagated through all the operands of generic op.
+// 1. If one of the inputs of generic op has static info and it has no cast source.
+// 2. If one of the inputs of generic op has static info and it is coming from tensor.cast operation.
+// 3. If one of the outputs of generic op has static info and it is coming from tensor.cast operation.
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_input_without_cast
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+func @static_input_without_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<?x?x?xf32>)
+    outs(%3 : tensor<?x?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+    %9 = arith.addf %arg2, %arg3 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<?x?x?xf32>)
+  %5 = tensor.cast %4 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %5 : tensor<2x3x4xf32>
+    //  CHECK:      %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+    //  CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+    //  CHECK-SAME: ins(%[[ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+    //  CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_input_with_cast
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+func @static_input_with_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x?x?xf32>
+  %5 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %4 : tensor<2x3x4xf32>, tensor<2x?x?xf32>)
+    outs(%3 : tensor<?x?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+    %9 = arith.addf %arg2, %arg3 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<?x?x?xf32>)
+  %6 = tensor.cast %5 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %6: tensor<2x3x4xf32>
+    //  CHECK:      %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+    //  CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+    //  CHECK-SAME: ins(%[[ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+    //  CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @static_output_with_cast
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>, %[[ARG2:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+func @static_output_with_cast(%arg0 : tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg2, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg2, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg2, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = tensor.cast %3 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  %5 = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x?x?xf32>
+  %6 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %5 : tensor<?x?x?xf32>, tensor<2x?x?xf32>)
+    outs(%4 : tensor<2x3x4xf32>) {
+  ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+    %9 = arith.addf %arg3, %arg4 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<2x3x4xf32>)
+  return %6: tensor<2x3x4xf32>
+    //  CHECK:      %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+    //  CHECK-NEXT: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+    //  CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+    //  CHECK-SAME: ins(%[[CAST_ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+    //  CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}
+
+// -----
+
+// This test checks the folding of tensor.cast operation when the source value of cast
+// has more static information than the destination value.
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @cast_source
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+func @cast_source(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+  %4 = tensor.cast %arg0 : tensor<2x3x4xf32> to tensor<2x?x?xf32>
+  %5 = tensor.cast %arg1 : tensor<2x3x4xf32> to tensor<2x?x?xf32>
+  %6 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%4, %5 : tensor<2x?x?xf32>, tensor<2x?x?xf32>)
+    outs(%3 : tensor<?x?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+    %9 = arith.addf %arg2, %arg3 : f32
+    linalg.yield %9 : f32
+  } -> (tensor<?x?x?xf32>)
+  %7 = tensor.cast %6 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+  return %7: tensor<2x3x4xf32>
+    //  CHECK:      %[[GENERIC_OP:.*]] = linalg.generic
+    //  CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+    //  CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+}

diff  --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 12d64651cdf36..5aebfcadc33e7 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -533,27 +533,28 @@ func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
 
 // -----
 
-func @no_fuse_mismatched_dynamism(%arg0: tensor<1x1xi64>, %arg1: tensor<?xi64>) -> tensor<1xi64> {
-  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x1xi64> into tensor<1xi64>
-  %1 = linalg.init_tensor [1] : tensor<1xi64>
+func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor<?xi64>) -> tensor<2xi64> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x1xi64> into tensor<2xi64>
+  %1 = linalg.init_tensor [2] : tensor<2xi64>
   %2 = linalg.generic
     {indexing_maps = [affine_map<(d0) -> (d0)>,
                       affine_map<(d0) -> (d0)>,
                       affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
-    ins(%0, %arg1 : tensor<1xi64>, tensor<?xi64>)
-    outs(%1 : tensor<1xi64>) {
+    ins(%0, %arg1 : tensor<2xi64>, tensor<?xi64>)
+    outs(%1 : tensor<2xi64>) {
   ^bb0(%arg4: i64, %arg5: i64, %arg6: i64):  
     %3 = arith.addi %arg4, %arg5 : i64
     linalg.yield %3 : i64
-  } -> tensor<1xi64>
-  return %2 : tensor<1xi64>
+  } -> tensor<2xi64>
+  return %2 : tensor<2xi64>
 }
 
 //      CHECK: func @no_fuse_mismatched_dynamism
-// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x1xi64>
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<2x1xi64>
 // CHECK-SAME:     %[[ARG1:.+]]: tensor<?xi64>
 //      CHECK:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
+//      CHECK:   %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?xi64> to tensor<2xi64>
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME:       ins(%[[RESHAPE]], %[[ARG1]] : tensor<1xi64>, tensor<?xi64>)
+// CHECK-SAME:       ins(%[[RESHAPE]], %[[CAST]] : tensor<2xi64>, tensor<2xi64>)
 //      CHECK:   return %[[GENERIC]]


        


More information about the Mlir-commits mailing list