[Mlir-commits] [mlir] 50000ab - [mlir] Use affine.apply when distributing to processors

Lei Zhang llvmlistbot at llvm.org
Tue Mar 9 05:40:38 PST 2021


Author: Lei Zhang
Date: 2021-03-09T08:37:20-05:00
New Revision: 50000abe3cb25f45ec0f293a66a81499726943de

URL: https://github.com/llvm/llvm-project/commit/50000abe3cb25f45ec0f293a66a81499726943de
DIFF: https://github.com/llvm/llvm-project/commit/50000abe3cb25f45ec0f293a66a81499726943de.diff

LOG: [mlir] Use affine.apply when distributing to processors

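Emit affine.apply ops instead of std addi/muli ops when computing the
distributed lower bound and step in mapLoopToProcessorIds. This makes it
easy to compose the distribution computation with other affine computations.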

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D98171

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/LoopUtils.cpp
    mlir/test/Dialect/Linalg/tile-and-distribute.mlir
    mlir/test/Transforms/parametric-mapping.mlir
    mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp
    mlir/test/lib/Transforms/TestLoopMapping.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 71a0fc8e5d89..3fa04514e81f 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -2152,17 +2152,28 @@ void mlir::mapLoopToProcessorIds(scf::ForOp forOp, ArrayRef<Value> processorId,
 
   OpBuilder b(forOp);
   Location loc(forOp.getLoc());
-  Value mul = processorId.front();
-  for (unsigned i = 1, e = processorId.size(); i < e; ++i)
-    mul = b.create<AddIOp>(loc, b.create<MulIOp>(loc, mul, numProcessors[i]),
-                           processorId[i]);
-  Value lb = b.create<AddIOp>(loc, forOp.lowerBound(),
-                              b.create<MulIOp>(loc, forOp.step(), mul));
+  AffineExpr lhs, rhs;
+  bindSymbols(forOp.getContext(), lhs, rhs);
+  auto mulMap = AffineMap::get(0, 2, lhs * rhs);
+  auto addMap = AffineMap::get(0, 2, lhs + rhs);
+
+  Value linearIndex = processorId.front();
+  for (unsigned i = 1, e = processorId.size(); i < e; ++i) {
+    auto mulApplyOp = b.create<AffineApplyOp>(
+        loc, mulMap, ValueRange{linearIndex, numProcessors[i]});
+    linearIndex = b.create<AffineApplyOp>(
+        loc, addMap, ValueRange{mulApplyOp, processorId[i]});
+  }
+
+  auto mulApplyOp = b.create<AffineApplyOp>(
+      loc, mulMap, ValueRange{linearIndex, forOp.step()});
+  Value lb = b.create<AffineApplyOp>(
+      loc, addMap, ValueRange{mulApplyOp, forOp.lowerBound()});
   forOp.setLowerBound(lb);
 
   Value step = forOp.step();
   for (auto numProcs : numProcessors)
-    step = b.create<MulIOp>(loc, step, numProcs);
+    step = b.create<AffineApplyOp>(loc, mulMap, ValueRange{numProcs, step});
   forOp.setStep(step);
 }
 

diff  --git a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
index 94c0e546db01..d566701d7bb6 100644
--- a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
@@ -175,23 +175,28 @@ func @gemm6(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
 
 // -----
 
-// CHECK-LABEL: func @matmul_tensors(
+//      CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
+//      CHECK: #[[ADDMAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//      CHECK: func @matmul_tensors(
 // CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
 // CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
 // CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
 func @matmul_tensors(
   %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
     -> tensor<?x?xf32> {
-//      CHECK: %[[C8:.*]] = constant 8 : index
+//  CHECK-DAG: %[[C8:.*]] = constant 8 : index
+//  CHECK-DAG: %[[C0:.*]] = constant 0 : index
 //      CHECK: %[[BIDY:.*]] = "gpu.block_id"() {dimension = "y"}
 //      CHECK: %[[NBLOCKSY:.*]] = "gpu.grid_dim"() {dimension = "y"}
 //      CHECK: %[[BIDX:.*]] = "gpu.block_id"() {dimension = "x"}
 //      CHECK: %[[NBLOCKSX:.*]] = "gpu.grid_dim"() {dimension = "x"}
-//      CHECK: %[[LBY:.*]] = muli %[[BIDY]], %[[C8]] : index
-//      CHECK: %[[STEPY:.*]] = muli %[[NBLOCKSY]], %[[C8]] : index
+//      CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDY]], %[[C8]]]
+//      CHECK: %[[LBY:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]]
+//      CHECK: %[[STEPY:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSY]], %[[C8]]]
 //      CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xf32>) {
-//      CHECK: %[[LBX:.*]] = muli %[[BIDX]], %[[C8]] : index
-//      CHECK: %[[STEPX:.*]] = muli %[[NBLOCKSX]], %[[C8]] : index
+//      CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDX]], %[[C8]]]
+//      CHECK: %[[LBX:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]]
+//      CHECK: %[[STEPX:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSX]], %[[C8]]]
 //      CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
 //      CHECK:     %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xf32>) {
 //      CHECK:       %[[sTA:.*]] = subtensor %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>

diff  --git a/mlir/test/Transforms/parametric-mapping.mlir b/mlir/test/Transforms/parametric-mapping.mlir
index 2ad24e1ae6b8..6988d038f6d3 100644
--- a/mlir/test/Transforms/parametric-mapping.mlir
+++ b/mlir/test/Transforms/parametric-mapping.mlir
@@ -1,21 +1,25 @@
 // RUN: mlir-opt -allow-unregistered-dialect -test-mapping-to-processing-elements %s | FileCheck %s
 
-// CHECK-LABEL: @map1d
-//       CHECK: (%[[lb:.*]]: index, %[[ub:.*]]: index, %[[step:.*]]: index) {
+// CHECK: #[[mul_map:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK: #[[add_map:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+
+//      CHECK: func @map1d
+// CHECK-SAME: (%[[lb:.*]]: index, %[[ub:.*]]: index, %[[step:.*]]: index)
 func @map1d(%lb: index, %ub: index, %step: index) {
   // CHECK: %[[threads:.*]]:2 = "new_processor_id_and_range"() : () -> (index, index)
   %0:2 = "new_processor_id_and_range"() : () -> (index, index)
 
-  // CHECK: %[[thread_offset:.*]] = muli %[[step]], %[[threads]]#0
-  // CHECK: %[[new_lb:.*]] = addi %[[lb]], %[[thread_offset]]
-  // CHECK: %[[new_step:.*]] = muli %[[step]], %[[threads]]#1
+  // CHECK: %[[thread_offset:.+]] = affine.apply #[[mul_map]]()[%[[threads]]#0, %[[step]]]
+  // CHECK: %[[new_lb:.+]] = affine.apply #[[add_map]]()[%[[thread_offset]], %[[lb]]]
+  // CHECK: %[[new_step:.+]] = affine.apply #[[mul_map]]()[%[[threads]]#1, %[[step]]]
+
   // CHECK: scf.for %{{.*}} = %[[new_lb]] to %[[ub]] step %[[new_step]] {
   scf.for %i = %lb to %ub step %step {}
   return
 }
 
-// CHECK-LABEL: @map2d
-//       CHECK: (%[[lb:.*]]: index, %[[ub:.*]]: index, %[[step:.*]]: index) {
+//      CHECK: func @map2d
+// CHECK-SAME: (%[[lb:.*]]: index, %[[ub:.*]]: index, %[[step:.*]]: index)
 func @map2d(%lb : index, %ub : index, %step : index) {
   // CHECK: %[[blocks:.*]]:2 = "new_processor_id_and_range"() : () -> (index, index)
   %0:2 = "new_processor_id_and_range"() : () -> (index, index)
@@ -24,24 +28,25 @@ func @map2d(%lb : index, %ub : index, %step : index) {
   %1:2 = "new_processor_id_and_range"() : () -> (index, index)
 
   // blockIdx.x * blockDim.x
-  // CHECK: %[[bidxXbdimx:.*]] = muli %[[blocks]]#0, %[[threads]]#1 : index
+  // CHECK: %[[bidxXbdimx:.+]] = affine.apply #[[mul_map]]()[%[[blocks]]#0, %[[threads]]#1]
   //
   // threadIdx.x + blockIdx.x * blockDim.x
-  // CHECK: %[[tidxpbidxXbdimx:.*]] = addi %[[bidxXbdimx]], %[[threads]]#0 : index
+  // CHECK: %[[tidxpbidxXbdimx:.+]] = affine.apply #[[add_map]]()[%[[bidxXbdimx]], %[[threads]]#0]
   //
   // thread_offset = step * (threadIdx.x + blockIdx.x * blockDim.x)
-  // CHECK: %[[thread_offset:.*]] = muli %[[step]], %[[tidxpbidxXbdimx]] : index
+  // CHECK: %[[thread_offset:.+]] = affine.apply #[[mul_map]]()[%[[tidxpbidxXbdimx]], %[[step]]]
   //
   // new_lb = lb + thread_offset
-  // CHECK: %[[new_lb:.*]] = addi %[[lb]], %[[thread_offset]] : index
+  // CHECK: %[[new_lb:.+]] = affine.apply #[[add_map]]()[%[[thread_offset]], %[[lb]]]
   //
   // stepXgdimx = step * gridDim.x
-  // CHECK: %[[stepXgdimx:.*]] = muli %[[step]], %[[blocks]]#1 : index
+  // CHECK: %[[stepXgdimx:.+]] = affine.apply #[[mul_map]]()[%[[blocks]]#1, %[[step]]]
   //
   // new_step = step * gridDim.x * blockDim.x
-  // CHECK: %[[new_step:.*]] = muli %[[stepXgdimx]], %[[threads]]#1 : index
+  // CHECK: %[[new_step:.+]] = affine.apply #[[mul_map]]()[%[[threads]]#1, %[[stepXgdimx]]]
   //
   // CHECK: scf.for %{{.*}} = %[[new_lb]] to %[[ub]] step %[[new_step]] {
+
   scf.for %i = %lb to %ub step %step {}
   return
 }

diff  --git a/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp b/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp
index c8c959047250..ac53d97a3d3a 100644
--- a/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp
+++ b/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/GPU/MemoryPromotion.h"
 #include "mlir/Dialect/SCF/SCF.h"
@@ -30,7 +31,7 @@ class TestGpuMemoryPromotionPass
     : public PassWrapper<TestGpuMemoryPromotionPass,
                          OperationPass<gpu::GPUFuncOp>> {
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<StandardOpsDialect, scf::SCFDialect>();
+    registry.insert<AffineDialect, StandardOpsDialect, scf::SCFDialect>();
   }
 
   void runOnOperation() override {

diff  --git a/mlir/test/lib/Transforms/TestLoopMapping.cpp b/mlir/test/lib/Transforms/TestLoopMapping.cpp
index 591fac32698f..20ec5a11a1c3 100644
--- a/mlir/test/lib/Transforms/TestLoopMapping.cpp
+++ b/mlir/test/lib/Transforms/TestLoopMapping.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Pass/Pass.h"
@@ -27,6 +28,10 @@ class TestLoopMappingPass
 public:
   explicit TestLoopMappingPass() {}
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, scf::SCFDialect>();
+  }
+
   void runOnFunction() override {
     FuncOp func = getFunction();
 


        


More information about the Mlir-commits mailing list