[Mlir-commits] [mlir] f2676b1 - [mlir][tensor] Add canonicalization for tensor.cast from extract_slice

Thomas Raoux llvmlistbot at llvm.org
Thu May 19 10:35:47 PDT 2022


Author: Thomas Raoux
Date: 2022-05-19T17:34:59Z
New Revision: f2676b151d6feb16d6f259ea8bac8ef16cc08f99

URL: https://github.com/llvm/llvm-project/commit/f2676b151d6feb16d6f259ea8bac8ef16cc08f99
DIFF: https://github.com/llvm/llvm-project/commit/f2676b151d6feb16d6f259ea8bac8ef16cc08f99.diff

LOG: [mlir][tensor] Add canonicalization for tensor.cast from extract_slice

Propagate static size information into extract_slice producer if
possible.

Differential Revision: https://reviews.llvm.org/D125972

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 932973a13c217..fff1563051786 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -229,11 +229,58 @@ struct ChainedTensorCast : public OpRewritePattern<CastOp> {
   }
 };
 
+/// Fold tensor.cast into tensor.extract_slice producer.
+/// Example:
+/// ```
+///  %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
+///    tensor<128x512xf32> to tensor<?x512xf32>
+///  %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
+/// ```
+/// ->
+/// ```
+/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
+///   tensor<128x512xf32> to tensor<16x512xf32>
+/// ```
+struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
+  using OpRewritePattern<CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CastOp tensorCast,
+                                PatternRewriter &rewriter) const final {
+    auto extractOperand =
+        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
+
+    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
+        tensorCast.getType().getShape() ==
+            tensorCast.source().getType().cast<RankedTensorType>().getShape())
+      return failure(); // No producer, not foldable, or same shape.
+
+    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
+    auto dimMask = computeRankReductionMask(
+        extractFromI64ArrayAttr(extractOperand.static_sizes()),
+        extractOperand.getType().getShape());
+    size_t dimIndex = 0; // Position in the cast result shape.
+    for (size_t i = 0, e = sizes.size(); i < e; i++) {
+      if (dimMask && dimMask->count(i))
+        continue; // Skip sizes in the rank-reduction mask.
+      int64_t dim = tensorCast.getType().getShape()[dimIndex++];
+      if (ShapedType::isDynamic(dim))
+        continue; // Cast keeps this dim dynamic.
+      sizes[i] = rewriter.getIndexAttr(dim); // Use the cast's static size.
+    }
+
+    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
+        tensorCast, tensorCast.getType().cast<RankedTensorType>(),
+        extractOperand.source(), extractOperand.getMixedOffsets(), sizes,
+        extractOperand.getMixedStrides());
+    return success();
+  }
+};
+
 } // namespace
 
 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<ChainedTensorCast>(context);
+  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 6f7c5caf8ccb0..6c0bbc3c1cf95 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -30,9 +30,6 @@ func.func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2:
   return %3 : tensor<?x?xf32>
 }
 
-//       CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
-//       CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
-
 //       CHECK: func @matmul_tensors(
 //  CHECK-SAME: %[[A:[0-9a-z]*]]: tensor<?x?xf32>
 //  CHECK-SAME: %[[B:[0-9a-z]*]]: tensor<?x?xf32>
@@ -40,26 +37,20 @@ func.func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2:
 
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-DAG: %[[dA0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?xf32>
 //   CHECK-DAG: %[[dA1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
 //   CHECK-DAG: %[[dB0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?xf32>
 //   CHECK-DAG: %[[dB1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xf32>
 //       CHECK: scf.for %[[I:[0-9a-z]*]]
-//       CHECK:   %[[sizeA0:.*]] = affine.min #[[BOUND2_MAP]](%[[I]])[%[[dA0]]]
-//       CHECK:   %[[stA:.*]] = tensor.extract_slice %[[A]][%[[I]], 0] [%[[sizeA0]], %[[dA1]]] [1, 1]  : tensor<?x?xf32> to tensor<?x?xf32>
-//       CHECK:   %[[castA:.*]] = tensor.cast %[[stA]] : tensor<?x?xf32> to tensor<2x?xf32>
+//       CHECK:   %[[stA:.*]] = tensor.extract_slice %[[A]][%[[I]], 0] [2, %[[dA1]]] [1, 1]  : tensor<?x?xf32> to tensor<2x?xf32>
 //       CHECK:   scf.for %[[J:[0-9a-z]*]]
 //  CHECK-NEXT:     scf.for %[[K:[0-9a-z]*]] {{.*}} iter_args(%[[RES:[0-9a-z]*]]
 //   CHECK-DAG:       %[[stB1:.*]] = tensor.extract_slice %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1]  : tensor<?x?xf32> to tensor<4x3xf32>
 //   CHECK-DAG:       %[[stF:.*]] = tensor.extract_slice %[[RES]][%[[I]], %[[J]]] [2, 3] [1, 1]  : tensor<?x?xf32> to tensor<2x3xf32>
 //
 // slices of the producing matmul.
-//       CHECK:       %[[sizeB1:.*]] = affine.min #[[BOUND4_MAP]](%[[K]])[%[[dB1]]]
-//       CHECK:       %[[stB2:.*]] = tensor.extract_slice %[[B]][0, %[[K]]] [%[[dB0]], %[[sizeB1]]] [1, 1]  : tensor<?x?xf32> to tensor<?x?xf32>
-//       CHECK:       %[[stC:.*]] = tensor.extract_slice %[[C]][%[[I]], %[[K]]] [%[[sizeA0]], %[[sizeB1]]] [1, 1]  : tensor<?x?xf32> to tensor<?x?xf32>
-//   CHECK-DAG:       %[[castC:.+]] = tensor.cast %[[stC]] : tensor<?x?xf32> to tensor<2x4xf32>
-//   CHECK-DAG:       %[[castB:.+]] = tensor.cast %[[stB2]] : tensor<?x?xf32> to tensor<?x4xf32>
-//       CHECK:       %[[stD:.*]] = linalg.matmul ins(%[[castA]], %[[castB]] : tensor<2x?xf32>, tensor<?x4xf32>) outs(%[[castC]] : tensor<2x4xf32>)  -> tensor<2x4xf32>
+//   CHECK-DAG:       %[[stB2:.*]] = tensor.extract_slice %[[B]][0, %[[K]]] [%[[dB0]], 4] [1, 1]  : tensor<?x?xf32> to tensor<?x4xf32>
+//   CHECK-DAG:       %[[stC:.*]] = tensor.extract_slice %[[C]][%[[I]], %[[K]]] [2, 4] [1, 1]  : tensor<?x?xf32> to tensor<2x4xf32>
+//       CHECK:       %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor<?x4xf32>) outs(%[[stC]] : tensor<2x4xf32>)  -> tensor<2x4xf32>
 //  CHECK-NEXT:       %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>)  -> tensor<2x3xf32>
 //  CHECK-NEXT:       tensor.insert_slice %[[stG]] into %[[RES]][%[[I]], %[[J]]]
 

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index a0bbcc61e4cf7..ee5653af3e62f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1401,3 +1401,27 @@ func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %
   // CHECK: return %[[RES]] : tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @cast_extract_slice
+func.func @cast_extract_slice(%arg0 : tensor<128x512xf32>, %s : index, %o : index)
+    -> tensor<16x512xf32> {
+// CHECK: %[[E:.*]] = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [16, 512] [1, 1] : tensor<128x512xf32> to tensor<16x512xf32>
+  %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] : tensor<128x512xf32> to tensor<?x512xf32>
+  %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
+// CHECK: return %[[E]] : tensor<16x512xf32>
+  return %1 : tensor<16x512xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @cast_extract_slice_rank_reduce
+func.func @cast_extract_slice_rank_reduce(%arg0 : tensor<128x512xf32>, %s : index, %o : index)
+    -> tensor<16xf32> {
+// CHECK: %[[E:.*]]  = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [16, 1] [1, 1] : tensor<128x512xf32> to tensor<16xf32>
+  %0 = tensor.extract_slice %arg0[%o, 0] [%s, 1] [1, 1] : tensor<128x512xf32> to tensor<?xf32>
+  %1 = tensor.cast %0 : tensor<?xf32> to tensor<16xf32>
+// CHECK: return %[[E]] : tensor<16xf32>
+  return %1 : tensor<16xf32>
+}


        


More information about the Mlir-commits mailing list