[Mlir-commits] [mlir] ced0dd8 - [MLIR] Guard DMA-specific logic with DMA option

Tim Shen llvmlistbot at llvm.org
Wed Mar 11 11:23:33 PDT 2020


Author: Tim Shen
Date: 2020-03-11T11:23:13-07:00
New Revision: ced0dd8e5104eff47cefdf3701d4980546d9afc5

URL: https://github.com/llvm/llvm-project/commit/ced0dd8e5104eff47cefdf3701d4980546d9afc5
DIFF: https://github.com/llvm/llvm-project/commit/ced0dd8e5104eff47cefdf3701d4980546d9afc5.diff

LOG: [MLIR] Guard DMA-specific logic with DMA option

Differential Revision: https://reviews.llvm.org/D75963

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/LoopUtils.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index c2cd2333f531..1c9ac5e84754 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -1411,22 +1411,24 @@ static LogicalResult generateCopy(
   auto numElementsSSA =
       top.create<ConstantIndexOp>(loc, numElements.getValue());
 
-  SmallVector<StrideInfo, 4> strideInfos;
-  getMultiLevelStrides(region, fastBufferShape, &strideInfos);
-
-  // TODO(bondhugula): use all stride levels once DmaStartOp is extended for
-  // multi-level strides.
-  if (strideInfos.size() > 1) {
-    LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n");
-    return failure();
-  }
+  Value dmaStride = nullptr;
+  Value numEltPerDmaStride = nullptr;
+  if (copyOptions.generateDma) {
+    SmallVector<StrideInfo, 4> dmaStrideInfos;
+    getMultiLevelStrides(region, fastBufferShape, &dmaStrideInfos);
+
+    // TODO(bondhugula): use all stride levels once DmaStartOp is extended for
+    // multi-level strides.
+    if (dmaStrideInfos.size() > 1) {
+      LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n");
+      return failure();
+    }
 
-  Value stride = nullptr;
-  Value numEltPerStride = nullptr;
-  if (!strideInfos.empty()) {
-    stride = top.create<ConstantIndexOp>(loc, strideInfos[0].stride);
-    numEltPerStride =
-        top.create<ConstantIndexOp>(loc, strideInfos[0].numEltPerStride);
+    if (!dmaStrideInfos.empty()) {
+      dmaStride = top.create<ConstantIndexOp>(loc, dmaStrideInfos[0].stride);
+      numEltPerDmaStride =
+          top.create<ConstantIndexOp>(loc, dmaStrideInfos[0].numEltPerStride);
+    }
   }
 
   // Record the last operation where we want the memref replacement to end. We
@@ -1469,13 +1471,13 @@ static LogicalResult generateCopy(
       b.create<AffineDmaStartOp>(loc, memref, memAffineMap, memIndices,
                                  fastMemRef, bufAffineMap, bufIndices,
                                  tagMemRef, tagAffineMap, tagIndices,
-                                 numElementsSSA, stride, numEltPerStride);
+                                 numElementsSSA, dmaStride, numEltPerDmaStride);
     } else {
       // DMA non-blocking write from fast buffer to the original memref.
       auto op = b.create<AffineDmaStartOp>(
           loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap,
           memIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA,
-          stride, numEltPerStride);
+          dmaStride, numEltPerDmaStride);
       // Since new ops may be appended at 'end' (for outgoing DMAs), adjust the
       // end to mark end of block range being processed.
       if (isCopyOutAtEndOfBlock)


        


More information about the Mlir-commits mailing list