[Mlir-commits] [mlir] 4d27f06 - [mlir][Tensor] Fix ExtractSliceFromReshape transform edge case

Christopher Bate llvmlistbot at llvm.org
Mon Sep 19 13:02:53 PDT 2022


Author: Christopher Bate
Date: 2022-09-19T14:02:45-06:00
New Revision: 4d27f06f9454a6733c3f801c8b992193702607b3

URL: https://github.com/llvm/llvm-project/commit/4d27f06f9454a6733c3f801c8b992193702607b3
DIFF: https://github.com/llvm/llvm-project/commit/4d27f06f9454a6733c3f801c8b992193702607b3.diff

LOG: [mlir][Tensor] Fix ExtractSliceFromReshape transform edge case

The transformation would fail if none of the sliced dimensions were
linearized by the producing `tensor.collapse_shape`. This is a trivial
edge case, but it was not covered by any test. This change fixes the
issue and adds a test for it.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D134088

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
    mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
    mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
    mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
    mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index e6b6048f8180..f693b3503abf 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -441,14 +441,16 @@ class SliceFromCollapseHelper {
   /// only one tiled dimension (D_0) and `arith.delinearize_index` produces the
   /// multi-index (%3) that would be passed to this function to generate the
   /// parameters for the `tensor.extract_slice` op (%4).
-  SmallVector<Range> getExtractSliceParams(ArrayRef<ValueRange> multiIndices);
+  SmallVector<Range> getExtractSliceParams(MLIRContext *ctx,
+                                           ArrayRef<ValueRange> multiIndices);
 
   /// This function takes indices in the index space of the "tiled dimensions"
   /// described above and returns a set of Range variables that describe how the
   /// slice should be inserted into the destination. In the example above, `%iv`
   /// would be passed to this function to generate the parameters for the
   /// `tensor.insert_slice` op producing %6.
-  SmallVector<Range> getInsertSliceParams(ValueRange tileIndices);
+  SmallVector<Range> getInsertSliceParams(MLIRContext *ctx,
+                                          ValueRange tileIndices);
 
 private:
   SmallVector<ReassociationIndices> reassociationIndices;

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
index 4acd5482e823..dcee9deff5ca 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
@@ -164,13 +164,14 @@ tensor::ExtractSliceFromCollapseHelper::emitLoopNestBody(
     }
   }
 
-  auto extractParams = helper.getExtractSliceParams(multiIndices);
+  SmallVector<Range> extractParams =
+      helper.getExtractSliceParams(builder.getContext(), multiIndices);
 
   Value subTileResult = builder.create<tensor::ExtractSliceOp>(
       loc, collapseShapeOp.getSrc(), extractParams);
 
   SmallVector<Range> insertParams =
-      helper.getInsertSliceParams(tileInductionVars);
+      helper.getInsertSliceParams(builder.getContext(), tileInductionVars);
 
   // Collapse the dimensions of the source slice back down.
   Value collapsedResult = builder.create<tensor::CollapseShapeOp>(

diff  --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 7f5b63814e69..9bca50f64321 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -298,11 +298,8 @@ llvm::SmallBitVector mlir::getLinearizedDimensions(
 }
 
 SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
-    ArrayRef<ValueRange> multiIndices) {
-  assert(!multiIndices.empty() && !multiIndices[0].empty() &&
-         "multiIndices should not be empty");
+    MLIRContext *ctx, ArrayRef<ValueRange> multiIndices) {
   unsigned loopIdx = 0;
-  MLIRContext *ctx = multiIndices[0][0].getContext();
   auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1);
   auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
   SmallVector<Range> offsetsSizesAndStrides;
@@ -339,8 +336,8 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
 }
 
 SmallVector<Range>
-SliceFromCollapseHelper::getInsertSliceParams(ValueRange tileIndices) {
-  MLIRContext *ctx = tileIndices[0].getContext();
+SliceFromCollapseHelper::getInsertSliceParams(MLIRContext *ctx,
+                                              ValueRange tileIndices) {
   auto one = IntegerAttr::get(IndexType::get(ctx), 1);
   auto zero = IntegerAttr::get(IndexType::get(ctx), 0);
   SmallVector<Range> insertParams;

diff  --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
index d8ca129bf59a..02e2502f9ffd 100644
--- a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
+++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
@@ -162,3 +162,18 @@ func.func @extract_slice_non_sliced_linearized_dim(%input: tensor<3x?x?x11x2xf32
   // CHECK: tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0, 0] [1, 1, 1, 11, 2] [1, 1, 1, 1, 1]
   return %slice : tensor<?x22xf32>
 }
+
+// -----
+
+// CHECK: @no_sliced_linearized_dims(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
+func.func @no_sliced_linearized_dims(%input: tensor<30x11x100xf32>, %offt: index, %size: index) -> tensor<330x?xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1], [2]] : tensor<30x11x100xf32> into tensor<330x100xf32>
+  %slice = tensor.extract_slice %collapsed [0, %offt] [330, %size] [1, 1] : tensor<330x100xf32> to tensor<330x?xf32>
+  // CHECK-NOT: scf.for  
+  // CHECK: %[[init:.+]] = linalg.init_tensor [330, %[[arg2]]]
+  // CHECK: %[[e:.+]] = tensor.extract_slice %[[arg0]][0, 0, %[[arg1]]] [30, 11, %[[arg2]]] [1, 1, 1]
+  // CHECK: %[[c:.+]] = tensor.collapse_shape %[[e]] {{\[}}[0, 1], [2]]
+  // CHECK: %[[res:.+]] = tensor.insert_slice %[[c]] into %[[init]]
+  // CHECK: return %[[res]]
+  return %slice : tensor<330x?xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index f5a7f984ab0a..5dd5d763388a 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -151,6 +151,13 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfFor
     auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
     SmallVector<Value> lbs(numTiledDims, zero);
     SmallVector<Value> steps(numTiledDims, one);
+
+    // Below, we pass out the result of the loop body builder lambda via the
+    // `insertResult` variable. In certain cases, no loops will be created, but
+    // the body builder will still execute. In this case, the results will not
+    // be passed to the LoopNest object.
+    // TODO: remove this workaround if `scf::buildLoopNest` behavior is updated.
+    Value insertResult = nullptr;
     scf::LoopNest nest = scf::buildLoopNest(
         rewriter, loc, lbs, helper.getIterationSpaceSizes(), steps, dest,
         [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
@@ -159,11 +166,16 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfFor
               helper.emitLoopNestBody(nestedBuilder, loc, outputIvs);
 
           // Insert the slice into the destination.
-          Value result = nestedBuilder.create<tensor::InsertSliceOp>(
+          insertResult = nestedBuilder.create<tensor::InsertSliceOp>(
               loc, tile, iterArgs[0], insertParams);
-          return {result};
+          return {insertResult};
         });
-    rewriter.replaceOp(op, nest.getResults()[0]);
+
+    if (!nest.loops.empty())
+      rewriter.replaceOp(op, nest.getResults());
+    else
+      rewriter.replaceOp(op, insertResult);
+
     return success();
   }
 };


        


More information about the Mlir-commits mailing list