[Mlir-commits] [mlir] 995e9d8 - [MLIR] Fix getCommonBlock utility in affine analysis

Uday Bondhugula llvmlistbot at llvm.org
Fri Jul 29 18:45:03 PDT 2022


Author: Uday Bondhugula
Date: 2022-07-30T07:14:54+05:30
New Revision: 995e9d84f8f90dd237871d15cf7237866902e5b2

URL: https://github.com/llvm/llvm-project/commit/995e9d84f8f90dd237871d15cf7237866902e5b2
DIFF: https://github.com/llvm/llvm-project/commit/995e9d84f8f90dd237871d15cf7237866902e5b2.diff

LOG: [MLIR] Fix getCommonBlock utility in affine analysis

Fix the hardcoded check for `FuncOp` in the `getCommonBlock` utility: the
check should have been for any op that starts an affine scope (i.e., an op
with the `AffineScope` trait). The incorrect block returned in turn caused
dependence analysis to produce incorrect results.

This change allows affine store-load forwarding to work correctly inside
any op that starts an affine scope.

Reviewed By: ftynse, dcaballe

Differential Revision: https://reviews.llvm.org/D130749

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
    mlir/test/Dialect/Affine/scalrep.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index 3a0a11535605b..f861abe6b331a 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -322,9 +322,8 @@ static Block *getCommonBlock(const MemRefAccess &srcAccess,
 
   if (numCommonLoops == 0) {
     Block *block = srcAccess.opInst->getBlock();
-    while (!llvm::isa<func::FuncOp>(block->getParentOp())) {
+    while (!block->getParentOp()->hasTrait<OpTrait::AffineScope>())
       block = block->getParentOp()->getBlock();
-    }
     return block;
   }
   Value commonForIV = srcDomain.getValue(numCommonLoops - 1);

diff  --git a/mlir/test/Dialect/Affine/scalrep.mlir b/mlir/test/Dialect/Affine/scalrep.mlir
index 4c09e3e27fbde..062a8d55327dc 100644
--- a/mlir/test/Dialect/Affine/scalrep.mlir
+++ b/mlir/test/Dialect/Affine/scalrep.mlir
@@ -701,3 +701,46 @@ func.func @with_inner_ops(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: i1)
 // CHECK:      } else {
 // CHECK:        scf.yield %[[pi]] : f64
 // CHECK:      }
+
+// Check that scalar replacement works correctly when affine memory ops are in
+// an op that starts an affine scope (test.affine_scope) inside an scf.for.
+
+// CHECK-LABEL: func @affine_store_load_in_scope
+func.func @affine_store_load_in_scope(%memref: memref<1x4094x510x1xf32>, %memref_2: memref<4x4x1x64xf32>, %memref_0: memref<1x2046x254x1x64xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c64 = arith.constant 64 : index
+  %c768 = arith.constant 768 : index
+  scf.for %i = %c0 to %c768 step %c1 {
+    %9 = arith.remsi %i, %c64 : index
+    %10 = arith.divsi %i, %c64 : index
+    %11 = arith.remsi %10, %c2 : index
+    %12 = arith.divsi %10, %c2 : index
+    test.affine_scope {
+      %14 = arith.muli %12, %c2 : index
+      %15 = arith.addi %c2, %14 : index
+      %16 = arith.addi %15, %c0 : index
+      %18 = arith.muli %11, %c2 : index
+      %19 = arith.addi %c2, %18 : index
+      %20 = affine.load %memref[0, symbol(%16), symbol(%19), 0] : memref<1x4094x510x1xf32>
+      %21 = affine.load %memref_2[0, 0, 0, symbol(%9)] : memref<4x4x1x64xf32>
+      %24 = affine.load %memref_0[0, symbol(%12), symbol(%11), 0, symbol(%9)] : memref<1x2046x254x1x64xf32>
+      %25 = arith.mulf %20, %21 : f32
+      %26 = arith.addf %24, %25 : f32
+      // CHECK: %[[A:.*]] = arith.addf
+      affine.store %26, %memref_0[0, symbol(%12), symbol(%11), 0, symbol(%9)] : memref<1x2046x254x1x64xf32>
+      %27 = arith.addi %19, %c1 : index
+      %28 = affine.load %memref[0, symbol(%16), symbol(%27), 0] : memref<1x4094x510x1xf32>
+      %29 = affine.load %memref_2[0, 1, 0, symbol(%9)] : memref<4x4x1x64xf32>
+      %30 = affine.load %memref_0[0, symbol(%12), symbol(%11), 0, symbol(%9)] : memref<1x2046x254x1x64xf32>
+      %31 = arith.mulf %28, %29 : f32
+      %32 = arith.addf %30, %31 : f32
+      // The addf above will get the forwarded value from the store to
+      // %memref_0 above, which is being loaded into %30.
+      // CHECK: arith.addf %[[A]],
+      "terminate"() : () -> ()
+    }
+  }
+  return
+}


        


More information about the Mlir-commits mailing list