[Mlir-commits] [mlir] 62570b7 - [mlir][linalg] Fix crash in vectorizer when expanding affine apply

Thomas Raoux llvmlistbot at llvm.org
Fri Feb 3 00:17:03 PST 2023


Author: Thomas Raoux
Date: 2023-02-03T08:16:49Z
New Revision: 62570b722fa36fddde0d24bf06a245efadda66f5

URL: https://github.com/llvm/llvm-project/commit/62570b722fa36fddde0d24bf06a245efadda66f5
DIFF: https://github.com/llvm/llvm-project/commit/62570b722fa36fddde0d24bf06a245efadda66f5.diff

LOG: [mlir][linalg] Fix crash in vectorizer when expanding affine apply

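Fix the insertion point used when expanding affine.apply ops (new ops are
now created right before each affine.apply), and handle affine maps that
have symbols by passing dimension and symbol operands separately. Also add
a missing precondition to dynamic-shape vectorization: ops with index
semantics are rejected, since index vectorization assumes static shapes.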

Differential Revision: https://reviews.llvm.org/D143243

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3173a447b5756..05a21102c2a2b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -960,6 +960,10 @@ static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
   if (!isa<linalg::GenericOp>(op))
     return failure();
 
+  // TODO: Index vectorization assumes static shape.
+  if (op.hasIndexSemantics())
+    return failure();
+
   // TODO: 0-d vectors are not supported yet.
   if (llvm::any_of(op.getIndexingMapsArray(), [](AffineMap map) {
         return map.isEmpty() || map.getResults().empty();
@@ -1052,15 +1056,15 @@ mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
 
 /// Converts affine.apply Ops to arithmetic operations.
 static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
-  auto &newIP = linalgOp.getBlock()->front();
   OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPointAfter(&newIP);
   auto toReplace = linalgOp.getBlock()->getOps<AffineApplyOp>();
 
   for (auto op : make_early_inc_range(toReplace)) {
-    auto expanded =
-        expandAffineExpr(rewriter, op->getLoc(), op.getAffineMap().getResult(0),
-                         op.getOperands(), ValueRange{});
+    rewriter.setInsertionPoint(op);
+    auto expanded = expandAffineExpr(
+        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
+        op.getOperands().take_front(op.getAffineMap().getNumDims()),
+        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
     rewriter.replaceOp(op, expanded);
   }
 }

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index f966b0e241159..a6c5602b9468b 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -301,7 +301,8 @@ func.func @vectorize_affine_apply(%arg0: tensor<32xf32>, %arg3: index) -> tensor
   ^bb0(%arg1: f32, %arg2: i32):
     %2 = linalg.index 0 : index
     %12 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg3)
-    %3 = arith.index_cast %12 : index to i32
+    %13 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%12)[%arg3]
+    %3 = arith.index_cast %13 : index to i32
     linalg.yield %3 : i32
   } -> tensor<32xi32>
   return %1 : tensor<32xi32>
@@ -315,7 +316,9 @@ func.func @vectorize_affine_apply(%arg0: tensor<32xf32>, %arg3: index) -> tensor
 // CHECK:   %[[EMPTY:.*]] = tensor.empty() : tensor<32xi32>
 // CHECK:   %[[BCAST:.*]] = vector.broadcast %[[ARG1]] : index to vector<32xindex>
 // CHECK:   %[[ADDI:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<32xindex>
-// CHECK:   %[[CAST:.*]] = arith.index_cast %[[ADDI]] : vector<32xindex> to vector<32xi32>
+// CHECK:   %[[BCAST2:.*]] = vector.broadcast %[[ARG1]] : index to vector<32xindex>
+// CHECK:   %[[ADDI2:.*]] = arith.addi %[[ADDI]], %[[BCAST2]] : vector<32xindex>
+// CHECK:   %[[CAST:.*]] = arith.index_cast %[[ADDI2]] : vector<32xindex> to vector<32xi32>
 // CHECK:   vector.transfer_write %[[CAST]], %[[EMPTY]][%[[C0:.*]]] {in_bounds = [true]} : vector<32xi32>, tensor<32xi32>
 
 transform.sequence failures(propagate) {


        


More information about the Mlir-commits mailing list