[llvm-branch-commits] [mlir] ca1bad5 - [MLIR][GPU] Properly model step in parallel loop to gpu conversion.

Stephan Herhut via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Feb 25 05:07:00 PST 2020


Author: Stephan Herhut
Date: 2020-02-25T14:04:39+01:00
New Revision: ca1bad5253a18ea6ab6573abdc26b740ff4593c4

URL: https://github.com/llvm/llvm-project/commit/ca1bad5253a18ea6ab6573abdc26b740ff4593c4
DIFF: https://github.com/llvm/llvm-project/commit/ca1bad5253a18ea6ab6573abdc26b740ff4593c4.diff

LOG: [MLIR][GPU] Properly model step in parallel loop to gpu conversion.

Summary:
The original patch had TODOs to add support for step computations,
which this commit addresses. The computations are expressed using
affine expressions so that the affine canonicalizers can simplify
the full bound and index computations.

Also cleans up the code a little and exposes the pass in the
header file.

Differential Revision: https://reviews.llvm.org/D75052

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h
    mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
    mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
    mlir/test/Conversion/LoopsToGPU/parallel_loop.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h
index ed91f1b4df63..d5f48d29ea6c 100644
--- a/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h
+++ b/mlir/include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h
@@ -15,6 +15,7 @@
 namespace mlir {
 class FuncOp;
 template <typename T> class OpPassBase;
+class Pass;
 
 /// Create a pass that converts loop nests into GPU kernels.  It considers
 /// top-level affine.for and linalg.for operations as roots of loop nests and
@@ -36,6 +37,13 @@ createSimpleLoopsToGPUPass(unsigned numBlockDims, unsigned numThreadDims);
 std::unique_ptr<OpPassBase<FuncOp>>
 createLoopToGPUPass(ArrayRef<int64_t> numWorkGroups,
                     ArrayRef<int64_t> workGroupSize);
+
+/// Creates a pass that converts loop.parallel operations into a gpu.launch
+/// operation. The mapping of loop dimensions to launch dimensions is derived
+/// from mapping attributes. See ParallelToGpuLaunchLowering::matchAndRewrite
+/// for a description of the used attributes.
+std::unique_ptr<Pass> createParallelLoopToGpuPass();
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_LOOPSTOGPU_LOOPSTOGPUPASS_H_

diff  --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
index f28409f23045..5b6b3a2d4f56 100644
--- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
+++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
@@ -531,25 +531,19 @@ static MappingAnnotation extractMappingAnnotation(Attribute attribute) {
 
 /// Tries to derive a static upper bound from the defining operation of
 /// `upperBound`.
-static Value deriveStaticUpperBound(Value upperBound) {
-  Value constantBound = {};
+static Value deriveStaticUpperBound(Value upperBound,
+                                    PatternRewriter &rewriter) {
   if (AffineMinOp minOp =
           dyn_cast_or_null<AffineMinOp>(upperBound.getDefiningOp())) {
-    auto map = minOp.map();
-    auto operands = minOp.operands();
-    for (int sub = 0, e = map.getNumResults(); sub < e; ++sub) {
-      AffineExpr expr = map.getResult(sub);
-      if (AffineDimExpr dimExpr = expr.dyn_cast<AffineDimExpr>()) {
-        auto dimOperand = operands[dimExpr.getPosition()];
-        auto defOp = dimOperand.getDefiningOp();
-        if (ConstantOp constOp = dyn_cast_or_null<ConstantOp>(defOp)) {
-          constantBound = constOp;
-          break;
-        }
+    for (const AffineExpr &result : minOp.map().getResults()) {
+      if (AffineConstantExpr constExpr =
+              result.dyn_cast<AffineConstantExpr>()) {
+        return rewriter.create<ConstantIndexOp>(minOp.getLoc(),
+                                                constExpr.getValue());
       }
     }
   }
-  return constantBound;
+  return {};
 }
 
 /// Modifies the current transformation state to capture the effect of the given
@@ -614,46 +608,62 @@ static LogicalResult processParallelLoop(ParallelOp parallelOp,
 
     if (annotation.processor < gpu::LaunchOp::kNumConfigOperands) {
       // Use the corresponding thread/grid index as replacement for the loop iv.
-      // TODO(herhut): Make the iv calculation depend on lower & upper bound.
       Value operand = launchOp.body().front().getArgument(annotation.processor);
-      Value appliedMap =
-          rewriter.create<AffineApplyOp>(loc, annotation.indexMap, operand);
-      // Add the lower bound, as the maps are 0 based but the loop might not be.
-      // TODO(herhut): Maybe move this explicitly into the maps?
-      newIndex = rewriter.create<AddIOp>(
-          loc, appliedMap, cloningMap.lookupOrDefault(lowerBound));
+      // Take the indexmap and add the lower bound and step computations in.
+      // This computes operand * step + lowerBound.
+      // Use an affine map here so that it composes nicely with the provided
+      // annotation.
+      AffineMap lowerAndStep = AffineMap::get(
+          1, 2,
+          rewriter.getAffineDimExpr(0) * rewriter.getAffineSymbolExpr(0) +
+              rewriter.getAffineSymbolExpr(1));
+      newIndex = rewriter.create<AffineApplyOp>(
+          loc, annotation.indexMap.compose(lowerAndStep),
+          ValueRange{operand, step, lowerBound});
       // If there was also a bound, insert that, too.
       // TODO(herhut): Check that we do not assign bounds twice.
       if (annotation.boundMap) {
         // We pass as the single operand to the bound-map the number of
-        // iterations, which is upperBound - lowerBound. To support inner loops
-        // with dynamic upper bounds (as generated by e.g. tiling), try to
-        // derive a max for the bounds. If the used bound for the hardware id is
-        // inprecise, wrap the contained code into a conditional.
-        // If the lower-bound is constant or defined before the launch, we can
-        // use it in the launch bounds. Otherwise fail.
+        // iterations, which is (upperBound - lowerBound) ceilDiv step. To
+        // support inner loops with dynamic upper bounds (as generated by e.g.
+        // tiling), try to derive a max for the bounds. If the used bound for
+        // the hardware id is imprecise, wrap the contained code into a
+        // conditional. If the lower-bound is constant or defined before the
+        // launch, we can use it in the launch bounds. Otherwise fail.
         if (!launchIndependent(lowerBound) &&
             !isa<ConstantOp>(lowerBound.getDefiningOp()))
           return failure();
+        // The step must also be constant or defined outside of the loop nest.
+        if (!launchIndependent(step) && !isa<ConstantOp>(step.getDefiningOp()))
+          return failure();
         // If the upper-bound is constant or defined before the launch, we can
         // use it in the launch bounds directly. Otherwise try derive a bound.
         bool boundIsPrecise = launchIndependent(upperBound) ||
                               isa<ConstantOp>(upperBound.getDefiningOp());
-        if (!boundIsPrecise) {
-          upperBound = deriveStaticUpperBound(upperBound);
-          if (!upperBound)
-            return failure();
-        }
         {
           PatternRewriter::InsertionGuard guard(rewriter);
           rewriter.setInsertionPoint(launchOp);
-
-          Value iterations = rewriter.create<SubIOp>(
-              loc,
-              ensureLaunchIndependent(cloningMap.lookupOrDefault(upperBound)),
-              ensureLaunchIndependent(cloningMap.lookupOrDefault(lowerBound)));
+          if (!boundIsPrecise) {
+            upperBound = deriveStaticUpperBound(upperBound, rewriter);
+            if (!upperBound)
+              return failure();
+          }
+          // Compute the number of iterations needed. We compute this as the
+          // affine expression (upperBound - lowerBound) ceilDiv step. We use
+          // affine.apply here so that it composes nicely with the provided map.
+          AffineMap stepMap =
+              AffineMap::get(0, 3,
+                             (rewriter.getAffineSymbolExpr(0) -
+                              rewriter.getAffineSymbolExpr(1).ceilDiv(
+                                  rewriter.getAffineSymbolExpr(2))));
           Value launchBound = rewriter.create<AffineApplyOp>(
-              loc, annotation.boundMap, iterations);
+              loc, annotation.boundMap.compose(stepMap),
+              ValueRange{
+                  ensureLaunchIndependent(
+                      cloningMap.lookupOrDefault(upperBound)),
+                  ensureLaunchIndependent(
+                      cloningMap.lookupOrDefault(lowerBound)),
+                  ensureLaunchIndependent(cloningMap.lookupOrDefault(step))});
           launchOp.setOperand(annotation.processor, launchBound);
         }
         if (!boundIsPrecise) {
@@ -747,8 +757,6 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
   bool leftNestingScope = false;
   while (!worklist.empty()) {
     Operation *op = worklist.pop_back_val();
-    launchOp.dump();
-
     // Now walk over the body and clone it.
     // TODO: This is only correct if there either is no further loop.parallel
     //       nested or this code is side-effect free. Otherwise we might need
@@ -787,30 +795,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
   return matchSuccess();
 }
 
-namespace {
-struct ParallelLoopToGpuPass : public OperationPass<ParallelLoopToGpuPass> {
-  void runOnOperation() override;
-};
-} // namespace
-
 void mlir::populateParallelLoopToGPUPatterns(OwningRewritePatternList &patterns,
                                              MLIRContext *ctx) {
   patterns.insert<ParallelToGpuLaunchLowering>(ctx);
 }
-
-void ParallelLoopToGpuPass::runOnOperation() {
-  OwningRewritePatternList patterns;
-  populateParallelLoopToGPUPatterns(patterns, &getContext());
-  ConversionTarget target(getContext());
-  target.addLegalDialect<StandardOpsDialect>();
-  target.addLegalDialect<AffineOpsDialect>();
-  target.addLegalDialect<gpu::GPUDialect>();
-  target.addLegalDialect<loop::LoopOpsDialect>();
-  target.addIllegalOp<loop::ParallelOp>();
-  if (failed(applyPartialConversion(getOperation(), target, patterns)))
-    signalPassFailure();
-}
-
-static PassRegistration<ParallelLoopToGpuPass>
-    pass("convert-parallel-loops-to-gpu", "Convert mapped loop.parallel ops"
-                                          " to gpu launch operations.");

diff  --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
index 73d46e8f14a0..9a703199cba1 100644
--- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
+++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp
@@ -9,9 +9,11 @@
 #include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h"
 #include "mlir/Conversion/LoopsToGPU/LoopsToGPU.h"
 #include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/CommandLine.h"
@@ -115,6 +117,21 @@ struct ImperfectlyNestedForLoopMapper
   SmallVector<int64_t, 3> workGroupSize;
 };
 
+struct ParallelLoopToGpuPass : public OperationPass<ParallelLoopToGpuPass> {
+  void runOnOperation() override {
+    OwningRewritePatternList patterns;
+    populateParallelLoopToGPUPatterns(patterns, &getContext());
+    ConversionTarget target(getContext());
+    target.addLegalDialect<StandardOpsDialect>();
+    target.addLegalDialect<AffineOpsDialect>();
+    target.addLegalDialect<gpu::GPUDialect>();
+    target.addLegalDialect<loop::LoopOpsDialect>();
+    target.addIllegalOp<loop::ParallelOp>();
+    if (failed(applyPartialConversion(getOperation(), target, patterns)))
+      signalPassFailure();
+  }
+};
+
 } // namespace
 
 std::unique_ptr<OpPassBase<FuncOp>>
@@ -130,6 +147,10 @@ mlir::createLoopToGPUPass(ArrayRef<int64_t> numWorkGroups,
                                                           workGroupSize);
 }
 
+std::unique_ptr<Pass> mlir::createParallelLoopToGpuPass() {
+  return std::make_unique<ParallelLoopToGpuPass>();
+}
+
 static PassRegistration<ForLoopMapper>
     registration(PASS_NAME, "Convert top-level loops to GPU kernels", [] {
       return std::make_unique<ForLoopMapper>(clNumBlockDims.getValue(),
@@ -145,3 +166,7 @@ static PassRegistration<ImperfectlyNestedForLoopMapper> loopOpToGPU(
       return std::make_unique<ImperfectlyNestedForLoopMapper>(numWorkGroups,
                                                               workGroupSize);
     });
+
+static PassRegistration<ParallelLoopToGpuPass>
+    pass("convert-parallel-loops-to-gpu", "Convert mapped loop.parallel ops"
+                                          " to gpu launch operations.");

diff  --git a/mlir/test/Conversion/LoopsToGPU/parallel_loop.mlir b/mlir/test/Conversion/LoopsToGPU/parallel_loop.mlir
index 2045f7a08981..b4b91453a60a 100644
--- a/mlir/test/Conversion/LoopsToGPU/parallel_loop.mlir
+++ b/mlir/test/Conversion/LoopsToGPU/parallel_loop.mlir
@@ -15,24 +15,21 @@ func @parallel_loop_bidy_bidx(%arg0 : index, %arg1 : index, %arg2 : index,
   return
 }
 
-// CHECK:       #map0 = affine_map<(d0) -> (d0)>
-// CHECK:       module {
+// CHECK:       #[[MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s0 - s1 ceildiv s2)>
+// CHECK:       #[[MAP1:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
 
+// CHECK:       module {
 // CHECK-LABEL:   func @parallel_loop_bidy_bidx(
-// CHECK-SAME:                        [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index, [[VAL_5:%.*]]: memref<?x?xf32>, [[VAL_6:%.*]]: memref<?x?xf32>) {
+// CHECK-SAME:                                  [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index, [[VAL_5:%.*]]: memref<?x?xf32>, [[VAL_6:%.*]]: memref<?x?xf32>) {
 // CHECK:           [[VAL_7:%.*]] = constant 2 : index
 // CHECK:           [[VAL_8:%.*]] = constant 1 : index
-// CHECK:           [[VAL_9:%.*]] = subi [[VAL_2]], [[VAL_0]] : index
-// CHECK:           [[VAL_10:%.*]] = affine.apply #map0([[VAL_9]])
-// CHECK:           [[VAL_11:%.*]] = subi [[VAL_3]], [[VAL_1]] : index
-// CHECK:           [[VAL_12:%.*]] = affine.apply #map0([[VAL_11]])
-// CHECK:           gpu.launch blocks([[VAL_13:%.*]], [[VAL_14:%.*]], [[VAL_15:%.*]]) in ([[VAL_16:%.*]] = [[VAL_12]], [[VAL_17:%.*]] = [[VAL_10]], [[VAL_18:%.*]] = [[VAL_8]]) threads([[VAL_19:%.*]], [[VAL_20:%.*]], [[VAL_21:%.*]]) in ([[VAL_22:%.*]] = [[VAL_8]], [[VAL_23:%.*]] = [[VAL_8]], [[VAL_24:%.*]] = [[VAL_8]]) {
-// CHECK:             [[VAL_25:%.*]] = affine.apply #map0([[VAL_14]])
-// CHECK:             [[VAL_26:%.*]] = addi [[VAL_25]], [[VAL_0]] : index
-// CHECK:             [[VAL_27:%.*]] = affine.apply #map0([[VAL_13]])
-// CHECK:             [[VAL_28:%.*]] = addi [[VAL_27]], [[VAL_1]] : index
-// CHECK:             [[VAL_29:%.*]] = load [[VAL_5]]{{\[}}[[VAL_26]], [[VAL_28]]] : memref<?x?xf32>
-// CHECK:             store [[VAL_29]], [[VAL_6]]{{\[}}[[VAL_28]], [[VAL_26]]] : memref<?x?xf32>
+// CHECK:           [[VAL_9:%.*]] = affine.apply #[[MAP0]](){{\[}}[[VAL_2]], [[VAL_0]], [[VAL_4]]]
+// CHECK:           [[VAL_10:%.*]] = affine.apply #[[MAP0]](){{\[}}[[VAL_3]], [[VAL_1]], [[VAL_7]]]
+// CHECK:           gpu.launch blocks([[VAL_11:%.*]], [[VAL_12:%.*]], [[VAL_13:%.*]]) in ([[VAL_14:%.*]] = [[VAL_10]], [[VAL_15:%.*]] = [[VAL_9]], [[VAL_16:%.*]] = [[VAL_8]]) threads([[VAL_17:%.*]], [[VAL_18:%.*]], [[VAL_19:%.*]]) in ([[VAL_20:%.*]] = [[VAL_8]], [[VAL_21:%.*]] = [[VAL_8]], [[VAL_22:%.*]] = [[VAL_8]]) {
+// CHECK:             [[VAL_23:%.*]] = affine.apply #[[MAP1]]([[VAL_12]]){{\[}}[[VAL_4]], [[VAL_0]]]
+// CHECK:             [[VAL_24:%.*]] = affine.apply #[[MAP1]]([[VAL_11]]){{\[}}[[VAL_7]], [[VAL_1]]]
+// CHECK:             [[VAL_25:%.*]] = load [[VAL_5]]{{\[}}[[VAL_23]], [[VAL_24]]] : memref<?x?xf32>
+// CHECK:             store [[VAL_25]], [[VAL_6]]{{\[}}[[VAL_24]], [[VAL_23]]] : memref<?x?xf32>
 // CHECK:             gpu.terminator
 // CHECK:           }
 // CHECK:           return
@@ -69,36 +66,29 @@ func @parallel_loop_tiled(%arg0 : index, %arg1 : index, %arg2 : index,
   return
 }
 
-// CHECK:       #map0 = affine_map<(d0) -> (d0)>
-// CHECK:       module {
+// CHECK:       #[[MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s0 - s1 ceildiv s2)>
+// CHECK:       #[[MAP1:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
 
+// CHECK:       module {
 // CHECK-LABEL:   func @parallel_loop_tiled(
-// CHECK-SAME:                              [[VAL_30:%.*]]: index, [[VAL_31:%.*]]: index, [[VAL_32:%.*]]: index, [[VAL_33:%.*]]: index, [[VAL_34:%.*]]: memref<?x?xf32>, [[VAL_35:%.*]]: memref<?x?xf32>) {
-// CHECK:           [[VAL_36:%.*]] = constant 0 : index
-// CHECK:           [[VAL_37:%.*]] = constant 1 : index
-// CHECK:           [[VAL_38:%.*]] = constant 4 : index
-// CHECK:           [[VAL_39:%.*]] = constant 1 : index
-// CHECK:           [[VAL_40:%.*]] = subi [[VAL_32]], [[VAL_30]] : index
-// CHECK:           [[VAL_41:%.*]] = affine.apply #map0([[VAL_40]])
-// CHECK:           [[VAL_42:%.*]] = subi [[VAL_33]], [[VAL_31]] : index
-// CHECK:           [[VAL_43:%.*]] = affine.apply #map0([[VAL_42]])
-// CHECK:           [[VAL_44:%.*]] = subi [[VAL_38]], [[VAL_36]] : index
-// CHECK:           [[VAL_45:%.*]] = affine.apply #map0([[VAL_44]])
-// CHECK:           [[VAL_46:%.*]] = subi [[VAL_38]], [[VAL_36]] : index
-// CHECK:           [[VAL_47:%.*]] = affine.apply #map0([[VAL_46]])
-// CHECK:           gpu.launch blocks([[VAL_48:%.*]], [[VAL_49:%.*]], [[VAL_50:%.*]]) in ([[VAL_51:%.*]] = [[VAL_43]], [[VAL_52:%.*]] = [[VAL_41]], [[VAL_53:%.*]] = [[VAL_39]]) threads([[VAL_54:%.*]], [[VAL_55:%.*]], [[VAL_56:%.*]]) in ([[VAL_57:%.*]] = [[VAL_47]], [[VAL_58:%.*]] = [[VAL_45]], [[VAL_59:%.*]] = [[VAL_39]]) {
-// CHECK:             [[VAL_60:%.*]] = affine.apply #map0([[VAL_49]])
-// CHECK:             [[VAL_61:%.*]] = addi [[VAL_60]], [[VAL_30]] : index
-// CHECK:             [[VAL_62:%.*]] = affine.apply #map0([[VAL_48]])
-// CHECK:             [[VAL_63:%.*]] = addi [[VAL_62]], [[VAL_31]] : index
-// CHECK:             [[VAL_64:%.*]] = affine.apply #map0([[VAL_55]])
-// CHECK:             [[VAL_65:%.*]] = addi [[VAL_64]], [[VAL_36]] : index
-// CHECK:             [[VAL_66:%.*]] = affine.apply #map0([[VAL_54]])
-// CHECK:             [[VAL_67:%.*]] = addi [[VAL_66]], [[VAL_36]] : index
-// CHECK:             [[VAL_68:%.*]] = addi [[VAL_61]], [[VAL_65]] : index
-// CHECK:             [[VAL_69:%.*]] = addi [[VAL_63]], [[VAL_67]] : index
-// CHECK:             [[VAL_70:%.*]] = load [[VAL_34]]{{\[}}[[VAL_68]], [[VAL_69]]] : memref<?x?xf32>
-// CHECK:             store [[VAL_70]], [[VAL_35]]{{\[}}[[VAL_69]], [[VAL_68]]] : memref<?x?xf32>
+// CHECK-SAME:                              [[VAL_26:%.*]]: index, [[VAL_27:%.*]]: index, [[VAL_28:%.*]]: index, [[VAL_29:%.*]]: index, [[VAL_30:%.*]]: memref<?x?xf32>, [[VAL_31:%.*]]: memref<?x?xf32>) {
+// CHECK:           [[VAL_32:%.*]] = constant 0 : index
+// CHECK:           [[VAL_33:%.*]] = constant 1 : index
+// CHECK:           [[VAL_34:%.*]] = constant 4 : index
+// CHECK:           [[VAL_35:%.*]] = constant 1 : index
+// CHECK:           [[VAL_36:%.*]] = affine.apply #[[MAP0]](){{\[}}[[VAL_28]], [[VAL_26]], [[VAL_34]]]
+// CHECK:           [[VAL_37:%.*]] = affine.apply #[[MAP0]](){{\[}}[[VAL_29]], [[VAL_27]], [[VAL_34]]]
+// CHECK:           [[VAL_38:%.*]] = affine.apply #[[MAP0]](){{\[}}[[VAL_34]], [[VAL_32]], [[VAL_33]]]
+// CHECK:           [[VAL_39:%.*]] = affine.apply #[[MAP0]](){{\[}}[[VAL_34]], [[VAL_32]], [[VAL_33]]]
+// CHECK:           gpu.launch blocks([[VAL_40:%.*]], [[VAL_41:%.*]], [[VAL_42:%.*]]) in ([[VAL_43:%.*]] = [[VAL_37]], [[VAL_44:%.*]] = [[VAL_36]], [[VAL_45:%.*]] = [[VAL_35]]) threads([[VAL_46:%.*]], [[VAL_47:%.*]], [[VAL_48:%.*]]) in ([[VAL_49:%.*]] = [[VAL_39]], [[VAL_50:%.*]] = [[VAL_38]], [[VAL_51:%.*]] = [[VAL_35]]) {
+// CHECK:             [[VAL_52:%.*]] = affine.apply #[[MAP1]]([[VAL_41]]){{\[}}[[VAL_34]], [[VAL_26]]]
+// CHECK:             [[VAL_53:%.*]] = affine.apply #[[MAP1]]([[VAL_40]]){{\[}}[[VAL_34]], [[VAL_27]]]
+// CHECK:             [[VAL_54:%.*]] = affine.apply #[[MAP1]]([[VAL_47]]){{\[}}[[VAL_33]], [[VAL_32]]]
+// CHECK:             [[VAL_55:%.*]] = affine.apply #[[MAP1]]([[VAL_46]]){{\[}}[[VAL_33]], [[VAL_32]]]
+// CHECK:             [[VAL_56:%.*]] = addi [[VAL_52]], [[VAL_54]] : index
+// CHECK:             [[VAL_57:%.*]] = addi [[VAL_53]], [[VAL_55]] : index
+// CHECK:             [[VAL_58:%.*]] = load [[VAL_30]]{{\[}}[[VAL_56]], [[VAL_57]]] : memref<?x?xf32>
+// CHECK:             store [[VAL_58]], [[VAL_31]]{{\[}}[[VAL_57]], [[VAL_56]]] : memref<?x?xf32>
 // CHECK:             gpu.terminator
 // CHECK:           }
 // CHECK:           return
@@ -125,21 +115,20 @@ func @parallel_loop_bidy_seq(%arg0 : index, %arg1 : index, %arg2 : index,
   return
 }
 
-// CHECK:       #map0 = affine_map<(d0) -> (d0)>
-// CHECK:       module {
+// CHECK:       #[[MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s0 - s1 ceildiv s2)>
+// CHECK:       #[[MAP1:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
 
+// CHECK:       module {
 // CHECK-LABEL:   func @parallel_loop_bidy_seq(
-// CHECK-SAME:                        [[VAL_71:%.*]]: index, [[VAL_72:%.*]]: index, [[VAL_73:%.*]]: index, [[VAL_74:%.*]]: index, [[VAL_75:%.*]]: index, [[VAL_76:%.*]]: memref<?x?xf32>, [[VAL_77:%.*]]: memref<?x?xf32>) {
-// CHECK:           [[VAL_78:%.*]] = constant 2 : index
-// CHECK:           [[VAL_79:%.*]] = constant 1 : index
-// CHECK:           [[VAL_80:%.*]] = subi [[VAL_73]], [[VAL_71]] : index
-// CHECK:           [[VAL_81:%.*]] = affine.apply #map0([[VAL_80]])
-// CHECK:           gpu.launch blocks([[VAL_82:%.*]], [[VAL_83:%.*]], [[VAL_84:%.*]]) in ([[VAL_85:%.*]] = [[VAL_79]], [[VAL_86:%.*]] = [[VAL_81]], [[VAL_87:%.*]] = [[VAL_79]]) threads([[VAL_88:%.*]], [[VAL_89:%.*]], [[VAL_90:%.*]]) in ([[VAL_91:%.*]] = [[VAL_79]], [[VAL_92:%.*]] = [[VAL_79]], [[VAL_93:%.*]] = [[VAL_79]]) {
-// CHECK:             [[VAL_94:%.*]] = affine.apply #map0([[VAL_83]])
-// CHECK:             [[VAL_95:%.*]] = addi [[VAL_94]], [[VAL_71]] : index
-// CHECK:             loop.for [[VAL_96:%.*]] = [[VAL_72]] to [[VAL_74]] step [[VAL_78]] {
-// CHECK:               [[VAL_97:%.*]] = load [[VAL_76]]{{\[}}[[VAL_95]], [[VAL_96]]] : memref<?x?xf32>
-// CHECK:               store [[VAL_97]], [[VAL_77]]{{\[}}[[VAL_96]], [[VAL_95]]] : memref<?x?xf32>
+// CHECK-SAME:                                 [[VAL_59:%.*]]: index, [[VAL_60:%.*]]: index, [[VAL_61:%.*]]: index, [[VAL_62:%.*]]: index, [[VAL_63:%.*]]: index, [[VAL_64:%.*]]: memref<?x?xf32>, [[VAL_65:%.*]]: memref<?x?xf32>) {
+// CHECK:           [[VAL_66:%.*]] = constant 2 : index
+// CHECK:           [[VAL_67:%.*]] = constant 1 : index
+// CHECK:           [[VAL_68:%.*]] = affine.apply #[[MAP0]](){{\[}}[[VAL_61]], [[VAL_59]], [[VAL_63]]]
+// CHECK:           gpu.launch blocks([[VAL_69:%.*]], [[VAL_70:%.*]], [[VAL_71:%.*]]) in ([[VAL_72:%.*]] = [[VAL_67]], [[VAL_73:%.*]] = [[VAL_68]], [[VAL_74:%.*]] = [[VAL_67]]) threads([[VAL_75:%.*]], [[VAL_76:%.*]], [[VAL_77:%.*]]) in ([[VAL_78:%.*]] = [[VAL_67]], [[VAL_79:%.*]] = [[VAL_67]], [[VAL_80:%.*]] = [[VAL_67]]) {
+// CHECK:             [[VAL_81:%.*]] = affine.apply #[[MAP1]]([[VAL_70]]){{\[}}[[VAL_63]], [[VAL_59]]]
+// CHECK:             loop.for [[VAL_82:%.*]] = [[VAL_60]] to [[VAL_62]] step [[VAL_66]] {
+// CHECK:               [[VAL_83:%.*]] = load [[VAL_64]]{{\[}}[[VAL_81]], [[VAL_82]]] : memref<?x?xf32>
+// CHECK:               store [[VAL_83]], [[VAL_65]]{{\[}}[[VAL_82]], [[VAL_81]]] : memref<?x?xf32>
 // CHECK:             }
 // CHECK:             gpu.terminator
 // CHECK:           }
@@ -177,30 +166,27 @@ func @parallel_loop_tiled_seq(%arg0 : index, %arg1 : index, %arg2 : index,
   return
 }
 
-// CHECK:       #map0 = affine_map<(d0) -> (d0)>
-// CHECK:       module {
+// CHECK:       #[[MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s0 - s1 ceildiv s2)>
+// CHECK:       #[[MAP1:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
 
+// CHECK:       module {
 // CHECK-LABEL:   func @parallel_loop_tiled_seq(
-// CHECK-SAME:                        [[VAL_98:%.*]]: index, [[VAL_99:%.*]]: index, [[VAL_100:%.*]]: index, [[VAL_101:%.*]]: index, [[VAL_102:%.*]]: memref<?x?xf32>, [[VAL_103:%.*]]: memref<?x?xf32>) {
-// CHECK:           [[VAL_104:%.*]] = constant 0 : index
-// CHECK:           [[VAL_105:%.*]] = constant 1 : index
-// CHECK:           [[VAL_106:%.*]] = constant 4 : index
-// CHECK:           [[VAL_107:%.*]] = constant 1 : index
-// CHECK:           [[VAL_108:%.*]] = subi [[VAL_100]], [[VAL_98]] : index
-// CHECK:           [[VAL_109:%.*]] = affine.apply #map0([[VAL_108]])
-// CHECK:           [[VAL_110:%.*]] = subi [[VAL_106]], [[VAL_104]] : index
-// CHECK:           [[VAL_111:%.*]] = affine.apply #map0([[VAL_110]])
-// CHECK:           gpu.launch blocks([[VAL_112:%.*]], [[VAL_113:%.*]], [[VAL_114:%.*]]) in ([[VAL_115:%.*]] = [[VAL_107]], [[VAL_116:%.*]] = [[VAL_109]], [[VAL_117:%.*]] = [[VAL_107]]) threads([[VAL_118:%.*]], [[VAL_119:%.*]], [[VAL_120:%.*]]) in ([[VAL_121:%.*]] = [[VAL_107]], [[VAL_122:%.*]] = [[VAL_111]], [[VAL_123:%.*]] = [[VAL_107]]) {
-// CHECK:             [[VAL_124:%.*]] = affine.apply #map0([[VAL_113]])
-// CHECK:             [[VAL_125:%.*]] = addi [[VAL_124]], [[VAL_98]] : index
-// CHECK:             loop.for [[VAL_126:%.*]] = [[VAL_99]] to [[VAL_101]] step [[VAL_106]] {
-// CHECK:               [[VAL_127:%.*]] = affine.apply #map0([[VAL_119]])
-// CHECK:               [[VAL_128:%.*]] = addi [[VAL_127]], [[VAL_104]] : index
-// CHECK:               loop.for [[VAL_129:%.*]] = [[VAL_104]] to [[VAL_106]] step [[VAL_105]] {
-// CHECK:                 [[VAL_130:%.*]] = addi [[VAL_125]], [[VAL_128]] : index
-// CHECK:                 [[VAL_131:%.*]] = addi [[VAL_126]], [[VAL_129]] : index
-// CHECK:                 [[VAL_132:%.*]] = load [[VAL_102]]{{\[}}[[VAL_130]], [[VAL_131]]] : memref<?x?xf32>
-// CHECK:                 store [[VAL_132]], [[VAL_103]]{{\[}}[[VAL_131]], [[VAL_130]]] : memref<?x?xf32>
+// CHECK-SAME:                                  [[VAL_84:%.*]]: index, [[VAL_85:%.*]]: index, [[VAL_86:%.*]]: index, [[VAL_87:%.*]]: index, [[VAL_88:%.*]]: memref<?x?xf32>, [[VAL_89:%.*]]: memref<?x?xf32>) {
+// CHECK:           [[VAL_90:%.*]] = constant 0 : index
+// CHECK:           [[VAL_91:%.*]] = constant 1 : index
+// CHECK:           [[VAL_92:%.*]] = constant 4 : index
+// CHECK:           [[VAL_93:%.*]] = constant 1 : index
+// CHECK:           [[VAL_94:%.*]] = affine.apply #[[MAP0]](){{\[}}[[VAL_86]], [[VAL_84]], [[VAL_92]]]
+// CHECK:           [[VAL_95:%.*]] = affine.apply #[[MAP0]](){{\[}}[[VAL_92]], [[VAL_90]], [[VAL_91]]]
+// CHECK:           gpu.launch blocks([[VAL_96:%.*]], [[VAL_97:%.*]], [[VAL_98:%.*]]) in ([[VAL_99:%.*]] = [[VAL_93]], [[VAL_100:%.*]] = [[VAL_94]], [[VAL_101:%.*]] = [[VAL_93]]) threads([[VAL_102:%.*]], [[VAL_103:%.*]], [[VAL_104:%.*]]) in ([[VAL_105:%.*]] = [[VAL_93]], [[VAL_106:%.*]] = [[VAL_95]], [[VAL_107:%.*]] = [[VAL_93]]) {
+// CHECK:             [[VAL_108:%.*]] = affine.apply #[[MAP1]]([[VAL_97]]){{\[}}[[VAL_92]], [[VAL_84]]]
+// CHECK:             loop.for [[VAL_109:%.*]] = [[VAL_85]] to [[VAL_87]] step [[VAL_92]] {
+// CHECK:               [[VAL_110:%.*]] = affine.apply #[[MAP1]]([[VAL_103]]){{\[}}[[VAL_91]], [[VAL_90]]]
+// CHECK:               loop.for [[VAL_111:%.*]] = [[VAL_90]] to [[VAL_92]] step [[VAL_91]] {
+// CHECK:                 [[VAL_112:%.*]] = addi [[VAL_108]], [[VAL_110]] : index
+// CHECK:                 [[VAL_113:%.*]] = addi [[VAL_109]], [[VAL_111]] : index
+// CHECK:                 [[VAL_114:%.*]] = load [[VAL_88]]{{\[}}[[VAL_112]], [[VAL_113]]] : memref<?x?xf32>
+// CHECK:                 store [[VAL_114]], [[VAL_89]]{{\[}}[[VAL_113]], [[VAL_112]]] : memref<?x?xf32>
 // CHECK:               }
 // CHECK:             }
 // CHECK:             gpu.terminator
@@ -212,9 +198,9 @@ func @parallel_loop_tiled_seq(%arg0 : index, %arg1 : index, %arg2 : index,
 // -----
 
 #map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
-#map2 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
-#map3 = affine_map<(d0) -> (d0)>
+#map1 = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+#map2 = affine_map<(d0)[s0] -> (3, -d0 + s0)>
+#map3 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
 
 module {
   func @sum(%arg0: memref<?x?xf32, #map0>, %arg1: memref<?x?xf32, #map0>, %arg2: memref<?x?xf32, #map0>) {
@@ -226,96 +212,86 @@ module {
     %1 = dim %arg0, 1 : memref<?x?xf32, #map0>
     loop.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %1) step (%c2, %c3) {
       %2 = dim %arg0, 0 : memref<?x?xf32, #map0>
-      %3 = affine.min #map1(%c2, %2, %arg3)
+      %3 = affine.min #map1(%arg3)[%2]
       %4 = dim %arg0, 1 : memref<?x?xf32, #map0>
-      %5 = affine.min #map1(%c3, %4, %arg4)
-      %6 = std.subview %arg0[%arg3, %arg4][%3, %5][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, #map2>
+      %5 = affine.min #map2(%arg4)[%4]
+      %6 = std.subview %arg0[%arg3, %arg4][%3, %5][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, #map3>
       %7 = dim %arg1, 0 : memref<?x?xf32, #map0>
-      %8 = affine.min #map1(%c2, %7, %arg3)
+      %8 = affine.min #map1(%arg3)[%7]
       %9 = dim %arg1, 1 : memref<?x?xf32, #map0>
-      %10 = affine.min #map1(%c3, %9, %arg4)
-      %11 = std.subview %arg1[%arg3, %arg4][%8, %10][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, #map2>
+      %10 = affine.min #map2(%arg4)[%9]
+      %11 = std.subview %arg1[%arg3, %arg4][%8, %10][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, #map3>
       %12 = dim %arg2, 0 : memref<?x?xf32, #map0>
-      %13 = affine.min #map1(%c2, %12, %arg3)
+      %13 = affine.min #map1(%arg3)[%12]
       %14 = dim %arg2, 1 : memref<?x?xf32, #map0>
-      %15 = affine.min #map1(%c3, %14, %arg4)
-      %16 = std.subview %arg2[%arg3, %arg4][%13, %15][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, #map2>
+      %15 = affine.min #map2(%arg4)[%14]
+      %16 = std.subview %arg2[%arg3, %arg4][%13, %15][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, #map3>
       loop.parallel (%arg5, %arg6) = (%c0, %c0) to (%3, %5) step (%c1, %c1) {
-        %17 = load %6[%arg5, %arg6] : memref<?x?xf32, #map2>
-        %18 = load %11[%arg5, %arg6] : memref<?x?xf32, #map2>
-        %19 = load %16[%arg5, %arg6] : memref<?x?xf32, #map2>
+        %17 = load %6[%arg5, %arg6] : memref<?x?xf32, #map3>
+        %18 = load %11[%arg5, %arg6] : memref<?x?xf32, #map3>
+        %19 = load %16[%arg5, %arg6] : memref<?x?xf32, #map3>
         %20 = addf %17, %18 : f32
-        store %20, %16[%arg5, %arg6] : memref<?x?xf32, #map2>
+        store %20, %16[%arg5, %arg6] : memref<?x?xf32, #map3>
         loop.yield
-      } { mapping = [
-          {processor = 3, map = #map3, bound = #map3},
-          {processor = 4, map = #map3, bound = #map3}
-        ] }
+      } {mapping = [{bound = affine_map<(d0) -> (d0)>, map = affine_map<(d0) -> (d0)>, processor = 3 : i64}, {bound = affine_map<(d0) -> (d0)>, map = affine_map<(d0) -> (d0)>, processor = 4 : i64}]}
       loop.yield
-    } { mapping = [
-        {processor = 0, map = #map3, bound = #map3},
-        {processor = 1, map = #map3, bound = #map3}
-    ] }
+    } {mapping = [{bound = affine_map<(d0) -> (d0)>, map = affine_map<(d0) -> (d0)>, processor = 0 : i64}, {bound = affine_map<(d0) -> (d0)>, map = affine_map<(d0) -> (d0)>, processor = 1 : i64}]}
     return
   }
 }
 
-// CHECK:       #map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK:       #map1 = affine_map<(d0) -> (d0)>
-// CHECK:       #map2 = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
-// CHECK:       #map3 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
-// CHECK:       module {
+// CHECK:       #[[MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK:       #[[MAP1:.*]] = affine_map<()[s0, s1, s2] -> (s0 - s1 ceildiv s2)>
+// CHECK:       #[[MAP2:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK:       #[[MAP3:.*]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+// CHECK:       #[[MAP4:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
+// CHECK:       #[[MAP5:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
 
+// CHECK:       module {
 // CHECK-LABEL:   func @sum(
-// CHECK-SAME:              [[VAL_133:%.*]]: memref<?x?xf32, #map0>, [[VAL_134:%.*]]: memref<?x?xf32, #map0>, [[VAL_135:%.*]]: memref<?x?xf32, #map0>) {
-// CHECK:           [[VAL_136:%.*]] = constant 1 : index
-// CHECK:           [[VAL_137:%.*]] = constant 0 : index
-// CHECK:           [[VAL_138:%.*]] = constant 3 : index
-// CHECK:           [[VAL_139:%.*]] = constant 2 : index
-// CHECK:           [[VAL_140:%.*]] = dim [[VAL_133]], 0 : memref<?x?xf32, #map0>
-// CHECK:           [[VAL_141:%.*]] = dim [[VAL_133]], 1 : memref<?x?xf32, #map0>
-// CHECK:           [[VAL_142:%.*]] = constant 1 : index
-// CHECK:           [[VAL_143:%.*]] = subi [[VAL_140]], [[VAL_137]] : index
-// CHECK:           [[VAL_144:%.*]] = affine.apply #map1([[VAL_143]])
-// CHECK:           [[VAL_145:%.*]] = subi [[VAL_141]], [[VAL_137]] : index
-// CHECK:           [[VAL_146:%.*]] = affine.apply #map1([[VAL_145]])
-// CHECK:           [[VAL_148:%.*]] = subi [[VAL_139]], [[VAL_137]] : index
-// CHECK:           [[VAL_149:%.*]] = affine.apply #map1([[VAL_148]])
-// CHECK:           [[VAL_151:%.*]] = subi [[VAL_138]], [[VAL_137]] : index
-// CHECK:           [[VAL_152:%.*]] = affine.apply #map1([[VAL_151]])
-// CHECK:           gpu.launch blocks([[VAL_153:%.*]], [[VAL_154:%.*]], [[VAL_155:%.*]]) in ([[VAL_156:%.*]] = [[VAL_144]], [[VAL_157:%.*]] = [[VAL_146]], [[VAL_158:%.*]] = [[VAL_142]]) threads([[VAL_159:%.*]], [[VAL_160:%.*]], [[VAL_161:%.*]]) in ([[VAL_162:%.*]] = [[VAL_149]], [[VAL_163:%.*]] = [[VAL_152]], [[VAL_164:%.*]] = [[VAL_142]]) {
-// CHECK:             [[VAL_165:%.*]] = affine.apply #map1([[VAL_153]])
-// CHECK:             [[VAL_166:%.*]] = addi [[VAL_165]], [[VAL_137]] : index
-// CHECK:             [[VAL_167:%.*]] = affine.apply #map1([[VAL_154]])
-// CHECK:             [[VAL_168:%.*]] = addi [[VAL_167]], [[VAL_137]] : index
-// CHECK:             [[VAL_169:%.*]] = dim [[VAL_133]], 0 : memref<?x?xf32, #map0>
-// CHECK:             [[VAL_170:%.*]] = affine.min #map2([[VAL_139]], [[VAL_169]], [[VAL_166]])
-// CHECK:             [[VAL_171:%.*]] = dim [[VAL_133]], 1 : memref<?x?xf32, #map0>
-// CHECK:             [[VAL_172:%.*]] = affine.min #map2([[VAL_138]], [[VAL_171]], [[VAL_168]])
-// CHECK:             [[VAL_173:%.*]] = std.subview [[VAL_133]]{{\[}}[[VAL_166]], [[VAL_168]]]{{\[}}[[VAL_170]], [[VAL_172]]]{{\[}}[[VAL_136]], [[VAL_136]]] : memref<?x?xf32, #map0> to memref<?x?xf32, #map3>
-// CHECK:             [[VAL_174:%.*]] = dim [[VAL_134]], 0 : memref<?x?xf32, #map0>
-// CHECK:             [[VAL_175:%.*]] = affine.min #map2([[VAL_139]], [[VAL_174]], [[VAL_166]])
-// CHECK:             [[VAL_176:%.*]] = dim [[VAL_134]], 1 : memref<?x?xf32, #map0>
-// CHECK:             [[VAL_177:%.*]] = affine.min #map2([[VAL_138]], [[VAL_176]], [[VAL_168]])
-// CHECK:             [[VAL_178:%.*]] = std.subview [[VAL_134]]{{\[}}[[VAL_166]], [[VAL_168]]]{{\[}}[[VAL_175]], [[VAL_177]]]{{\[}}[[VAL_136]], [[VAL_136]]] : memref<?x?xf32, #map0> to memref<?x?xf32, #map3>
-// CHECK:             [[VAL_179:%.*]] = dim [[VAL_135]], 0 : memref<?x?xf32, #map0>
-// CHECK:             [[VAL_180:%.*]] = affine.min #map2([[VAL_139]], [[VAL_179]], [[VAL_166]])
-// CHECK:             [[VAL_181:%.*]] = dim [[VAL_135]], 1 : memref<?x?xf32, #map0>
-// CHECK:             [[VAL_182:%.*]] = affine.min #map2([[VAL_138]], [[VAL_181]], [[VAL_168]])
-// CHECK:             [[VAL_183:%.*]] = std.subview [[VAL_135]]{{\[}}[[VAL_166]], [[VAL_168]]]{{\[}}[[VAL_180]], [[VAL_182]]]{{\[}}[[VAL_136]], [[VAL_136]]] : memref<?x?xf32, #map0> to memref<?x?xf32, #map3>
-// CHECK:             [[VAL_184:%.*]] = affine.apply #map1([[VAL_159]])
-// CHECK:             [[VAL_185:%.*]] = addi [[VAL_184]], [[VAL_137]] : index
-// CHECK:             [[VAL_186:%.*]] = cmpi "slt", [[VAL_185]], [[VAL_170]] : index
-// CHECK:             loop.if [[VAL_186]] {
-// CHECK:               [[VAL_187:%.*]] = affine.apply #map1([[VAL_160]])
-// CHECK:               [[VAL_188:%.*]] = addi [[VAL_187]], [[VAL_137]] : index
-// CHECK:               [[VAL_189:%.*]] = cmpi "slt", [[VAL_188]], [[VAL_172]] : index
-// CHECK:               loop.if [[VAL_189]] {
-// CHECK:                 [[VAL_190:%.*]] = load [[VAL_173]]{{\[}}[[VAL_185]], [[VAL_188]]] : memref<?x?xf32, #map3>
-// CHECK:                 [[VAL_191:%.*]] = load [[VAL_178]]{{\[}}[[VAL_185]], [[VAL_188]]] : memref<?x?xf32, #map3>
-// CHECK:                 [[VAL_192:%.*]] = load [[VAL_183]]{{\[}}[[VAL_185]], [[VAL_188]]] : memref<?x?xf32, #map3>
-// CHECK:                 [[VAL_193:%.*]] = addf [[VAL_190]], [[VAL_191]] : f32
-// CHECK:                 store [[VAL_193]], [[VAL_183]]{{\[}}[[VAL_185]], [[VAL_188]]] : memref<?x?xf32, #map3>
+// CHECK-SAME:              [[VAL_0:%.*]]: memref<?x?xf32, #[[MAP0]]>, [[VAL_1:%.*]]: memref<?x?xf32, #[[MAP0]]>, [[VAL_2:%.*]]: memref<?x?xf32, #[[MAP0]]>) {
+// CHECK:           [[VAL_3:%.*]] = constant 1 : index
+// CHECK:           [[VAL_4:%.*]] = constant 0 : index
+// CHECK:           [[VAL_5:%.*]] = constant 3 : index
+// CHECK:           [[VAL_6:%.*]] = constant 2 : index
+// CHECK:           [[VAL_7:%.*]] = dim [[VAL_0]], 0 : memref<?x?xf32, #[[MAP0]]>
+// CHECK:           [[VAL_8:%.*]] = dim [[VAL_0]], 1 : memref<?x?xf32, #[[MAP0]]>
+// CHECK:           [[VAL_9:%.*]] = constant 1 : index
+// CHECK:           [[VAL_10:%.*]] = affine.apply #[[MAP1]](){{\[}}[[VAL_7]], [[VAL_4]], [[VAL_6]]]
+// CHECK:           [[VAL_11:%.*]] = affine.apply #[[MAP1]](){{\[}}[[VAL_8]], [[VAL_4]], [[VAL_5]]]
+// CHECK:           [[VAL_12:%.*]] = constant 2 : index
+// CHECK:           [[VAL_13:%.*]] = affine.apply #[[MAP1]](){{\[}}[[VAL_12]], [[VAL_4]], [[VAL_3]]]
+// CHECK:           [[VAL_14:%.*]] = constant 3 : index
+// CHECK:           [[VAL_15:%.*]] = affine.apply #[[MAP1]](){{\[}}[[VAL_14]], [[VAL_4]], [[VAL_3]]]
+// CHECK:           gpu.launch blocks([[VAL_16:%.*]], [[VAL_17:%.*]], [[VAL_18:%.*]]) in ([[VAL_19:%.*]] = [[VAL_10]], [[VAL_20:%.*]] = [[VAL_11]], [[VAL_21:%.*]] = [[VAL_9]]) threads([[VAL_22:%.*]], [[VAL_23:%.*]], [[VAL_24:%.*]]) in ([[VAL_25:%.*]] = [[VAL_13]], [[VAL_26:%.*]] = [[VAL_15]], [[VAL_27:%.*]] = [[VAL_9]]) {
+// CHECK:             [[VAL_28:%.*]] = affine.apply #[[MAP2]]([[VAL_16]]){{\[}}[[VAL_6]], [[VAL_4]]]
+// CHECK:             [[VAL_29:%.*]] = affine.apply #[[MAP2]]([[VAL_17]]){{\[}}[[VAL_5]], [[VAL_4]]]
+// CHECK:             [[VAL_30:%.*]] = dim [[VAL_0]], 0 : memref<?x?xf32, #[[MAP0]]>
+// CHECK:             [[VAL_31:%.*]] = affine.min #[[MAP3]]([[VAL_28]]){{\[}}[[VAL_30]]]
+// CHECK:             [[VAL_32:%.*]] = dim [[VAL_0]], 1 : memref<?x?xf32, #[[MAP0]]>
+// CHECK:             [[VAL_33:%.*]] = affine.min #[[MAP4]]([[VAL_29]]){{\[}}[[VAL_32]]]
+// CHECK:             [[VAL_34:%.*]] = std.subview [[VAL_0]]{{\[}}[[VAL_28]], [[VAL_29]]]{{\[}}[[VAL_31]], [[VAL_33]]]{{\[}}[[VAL_3]], [[VAL_3]]] : memref<?x?xf32, #[[MAP0]]> to memref<?x?xf32, #[[MAP5]]>
+// CHECK:             [[VAL_35:%.*]] = dim [[VAL_1]], 0 : memref<?x?xf32, #[[MAP0]]>
+// CHECK:             [[VAL_36:%.*]] = affine.min #[[MAP3]]([[VAL_28]]){{\[}}[[VAL_35]]]
+// CHECK:             [[VAL_37:%.*]] = dim [[VAL_1]], 1 : memref<?x?xf32, #[[MAP0]]>
+// CHECK:             [[VAL_38:%.*]] = affine.min #[[MAP4]]([[VAL_29]]){{\[}}[[VAL_37]]]
+// CHECK:             [[VAL_39:%.*]] = std.subview [[VAL_1]]{{\[}}[[VAL_28]], [[VAL_29]]]{{\[}}[[VAL_36]], [[VAL_38]]]{{\[}}[[VAL_3]], [[VAL_3]]] : memref<?x?xf32, #[[MAP0]]> to memref<?x?xf32, #[[MAP5]]>
+// CHECK:             [[VAL_40:%.*]] = dim [[VAL_2]], 0 : memref<?x?xf32, #[[MAP0]]>
+// CHECK:             [[VAL_41:%.*]] = affine.min #[[MAP3]]([[VAL_28]]){{\[}}[[VAL_40]]]
+// CHECK:             [[VAL_42:%.*]] = dim [[VAL_2]], 1 : memref<?x?xf32, #[[MAP0]]>
+// CHECK:             [[VAL_43:%.*]] = affine.min #[[MAP4]]([[VAL_29]]){{\[}}[[VAL_42]]]
+// CHECK:             [[VAL_44:%.*]] = std.subview [[VAL_2]]{{\[}}[[VAL_28]], [[VAL_29]]]{{\[}}[[VAL_41]], [[VAL_43]]]{{\[}}[[VAL_3]], [[VAL_3]]] : memref<?x?xf32, #[[MAP0]]> to memref<?x?xf32, #[[MAP5]]>
+// CHECK:             [[VAL_45:%.*]] = affine.apply #[[MAP2]]([[VAL_22]]){{\[}}[[VAL_3]], [[VAL_4]]]
+// CHECK:             [[VAL_46:%.*]] = cmpi "slt", [[VAL_45]], [[VAL_31]] : index
+// CHECK:             loop.if [[VAL_46]] {
+// CHECK:               [[VAL_47:%.*]] = affine.apply #[[MAP2]]([[VAL_23]]){{\[}}[[VAL_3]], [[VAL_4]]]
+// CHECK:               [[VAL_48:%.*]] = cmpi "slt", [[VAL_47]], [[VAL_33]] : index
+// CHECK:               loop.if [[VAL_48]] {
+// CHECK:                 [[VAL_49:%.*]] = load [[VAL_34]]{{\[}}[[VAL_45]], [[VAL_47]]] : memref<?x?xf32, #[[MAP5]]>
+// CHECK:                 [[VAL_50:%.*]] = load [[VAL_39]]{{\[}}[[VAL_45]], [[VAL_47]]] : memref<?x?xf32, #[[MAP5]]>
+// CHECK:                 [[VAL_51:%.*]] = load [[VAL_44]]{{\[}}[[VAL_45]], [[VAL_47]]] : memref<?x?xf32, #[[MAP5]]>
+// CHECK:                 [[VAL_52:%.*]] = addf [[VAL_49]], [[VAL_50]] : f32
+// CHECK:                 store [[VAL_52]], [[VAL_44]]{{\[}}[[VAL_45]], [[VAL_47]]] : memref<?x?xf32, #[[MAP5]]>
 // CHECK:               }
 // CHECK:             }
 // CHECK:             gpu.terminator
@@ -323,4 +299,3 @@ module {
 // CHECK:           return
 // CHECK:         }
 // CHECK:       }
-


        


More information about the llvm-branch-commits mailing list