[Mlir-commits] [mlir] 4747e1b - [mlir][Linalg] Fix tensor.extract_slice(linalg.init_tensor) canonicalization for rank-reducing extract

Nicolas Vasilache llvmlistbot at llvm.org
Thu Jul 8 11:14:09 PDT 2021


Author: Nicolas Vasilache
Date: 2021-07-08T18:13:51Z
New Revision: 4747e1b83ba0117a88551a358f8960060ffa7558

URL: https://github.com/llvm/llvm-project/commit/4747e1b83ba0117a88551a358f8960060ffa7558
DIFF: https://github.com/llvm/llvm-project/commit/4747e1b83ba0117a88551a358f8960060ffa7558.diff

LOG: [mlir][Linalg] Fix tensor.extract_slice(linalg.init_tensor) canonicalization for rank-reducing extract

Differential Revision: https://reviews.llvm.org/D105636

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ea12a312d9c0..4ef942d54776 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -772,11 +772,11 @@ struct FoldInitTensorWithExtractSliceOp
                                 PatternRewriter &rewriter) const override {
     if (!sliceOp.source().getDefiningOp<linalg::InitTensorOp>())
       return failure();
+    // ExtractSliceOp may be rank-reducing; its dynamic sizes must be preserved
+    // as well as its result type.
     rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
         sliceOp, sliceOp.sizes(),
-        llvm::to_vector<4>(llvm::map_range(
-            sliceOp.static_sizes(),
-            [](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); })),
+        sliceOp.result().getType().cast<RankedTensorType>().getShape(),
         sliceOp.getSourceType().getElementType());
     return success();
   }

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 864c79a35756..5f7ad8ddfde2 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -890,3 +890,15 @@ func @init_canonicalize(%i : index) {
 
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing_init_extract
+func @rank_reducing_init_extract(%sz : index, %idx : index) -> tensor<2xf32> {
+  // CHECK: linalg.init_tensor [2] : tensor<2xf32>
+  %a = linalg.init_tensor [%sz, 2] : tensor<?x2xf32>
+
+  // CHECK-NOT: extract
+  %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
+  return %r: tensor<2xf32>
+}
