[Mlir-commits] [mlir] a73455a - [mlir][affine] Fix crash in AffineApplyOp canonicalization

Matthias Springer llvmlistbot at llvm.org
Mon Dec 12 00:11:28 PST 2022


Author: Matthias Springer
Date: 2022-12-12T09:11:17+01:00
New Revision: a73455ac8e9bd02f843916c12c261cfd9ccb06a9

URL: https://github.com/llvm/llvm-project/commit/a73455ac8e9bd02f843916c12c261cfd9ccb06a9
DIFF: https://github.com/llvm/llvm-project/commit/a73455ac8e9bd02f843916c12c261cfd9ccb06a9.diff

LOG: [mlir][affine] Fix crash in AffineApplyOp canonicalization

The newly added test case used to crash with a failed assertion:
```
AffineExpr.cpp:659 in AffineExpr simplifyMul(AffineExpr, AffineExpr): lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant()
```

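This was caused by composing two affine maps, which created a multiplication of two non-symbolic (dimension) expressions. Such a product is not a valid affine expression. The fix canonicalizes the map of the `affine.apply` being composed (promoting dims to symbols where possible) before performing the replacement, so that no invalid map is generated.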

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/test/Dialect/Affine/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index ffa354f7b52bd..7cafc2749af9c 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -763,6 +763,7 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
                                      unsigned dimOrSymbolPosition,
                                      SmallVectorImpl<Value> &dims,
                                      SmallVectorImpl<Value> &syms) {
+  MLIRContext *ctx = map->getContext();
   bool isDimReplacement = (dimOrSymbolPosition < dims.size());
   unsigned pos = isDimReplacement ? dimOrSymbolPosition
                                   : dimOrSymbolPosition - dims.size();
@@ -781,20 +782,24 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
   // Compute the map, dims and symbols coming from the AffineApplyOp.
   AffineMap composeMap = affineApply.getAffineMap();
   assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
-  AffineExpr composeExpr =
+  SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
+                                     affineApply.getMapOperands().end());
+  // Canonicalize the map to promote dims to symbols when possible. This is to
+  // avoid generating invalid maps.
+  canonicalizeMapAndOperands(&composeMap, &composeOperands);
+  AffineExpr replacementExpr =
       composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
   ValueRange composeDims =
-      affineApply.getMapOperands().take_front(composeMap.getNumDims());
+      ArrayRef<Value>(composeOperands).take_front(composeMap.getNumDims());
   ValueRange composeSyms =
-      affineApply.getMapOperands().take_back(composeMap.getNumSymbols());
-
-  // Append the dims and symbols where relevant and perform the replacement.
-  MLIRContext *ctx = map->getContext();
+      ArrayRef<Value>(composeOperands).take_back(composeMap.getNumSymbols());
   AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
                                           : getAffineSymbolExpr(pos, ctx);
+
+  // Append the dims and symbols where relevant and perform the replacement.
   dims.append(composeDims.begin(), composeDims.end());
   syms.append(composeSyms.begin(), composeSyms.end());
-  *map = map->replace(toReplace, composeExpr, dims.size(), syms.size());
+  *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
 
   return success();
 }

diff  --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index c25ed37edef62..e47cdde4cf25d 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -canonicalize | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -canonicalize="top-down=0" | FileCheck %s --check-prefix=CHECK-BOTTOM-UP
 
 // -----
 
@@ -1200,3 +1201,21 @@ func.func @simplify_with_operands(%N: index, %A: memref<?x32xf32>) {
 
   return
 }
+
+// -----
+
+//           CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * ((-s0 + 40961) ceildiv 512))>
+// CHECK-BOTTOM-UP: #[[$map:.*]] = affine_map<()[s0] -> (s0 * ((-s0 + 40961) ceildiv 512))>
+//           CHECK-LABEL: func @regression_do_not_perform_invalid_replacements
+// CHECK-BOTTOM-UP-LABEL: func @regression_do_not_perform_invalid_replacements
+func.func @regression_do_not_perform_invalid_replacements(%arg0: index) {
+  // Dim must be promoted to sym before combining both maps.
+  //           CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%{{.*}}]
+  // CHECK-BOTTOM-UP: %[[apply:.*]] = affine.apply #[[$map]]()[%{{.*}}]
+  %0 = affine.apply affine_map<(d0) -> (-d0 + 40961)>(%arg0)
+  %1 = affine.apply affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 512))>(%arg0)[%0]
+  //           CHECK: "test.foo"(%[[apply]])
+  // CHECK-BOTTOM-UP: "test.foo"(%[[apply]])
+  "test.foo"(%1) : (index) -> ()
+  return
+}


        


More information about the Mlir-commits mailing list