[Mlir-commits] [mlir] 8b97e17 - [mlir] Simplify code generated by ConvertToLLVMPattern::getStridedElementPtr().

Christian Sigg llvmlistbot at llvm.org
Wed Nov 18 02:52:20 PST 2020


Author: Christian Sigg
Date: 2020-11-18T11:52:09+01:00
New Revision: 8b97e17d161a177ae54c989e6e550930f6a75876

URL: https://github.com/llvm/llvm-project/commit/8b97e17d161a177ae54c989e6e550930f6a75876
DIFF: https://github.com/llvm/llvm-project/commit/8b97e17d161a177ae54c989e6e550930f6a75876.diff

LOG: [mlir] Simplify code generated by ConvertToLLVMPattern::getStridedElementPtr().

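Make the interface match that of ConvertToLLVMPattern::getDataPtr(), which is to be removed in a separate change. The generated code now omits a zero base offset, the multiplication by unit strides, and the getelementptr itself when the offset is zero and there are no subscripts.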

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D91599

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
    mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
    mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index dbf38de56a0a..919a93ac84a2 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -505,18 +505,18 @@ class ConvertToLLVMPattern : public ConversionPattern {
 
   // This is a strided getElementPtr variant that linearizes subscripts as:
   //   `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
-  Value getStridedElementPtr(Location loc, Type elementTypePtr,
-                             Value descriptor, ValueRange indices,
-                             ArrayRef<int64_t> strides, int64_t offset,
+  Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
+                             ValueRange indices,
                              ConversionPatternRewriter &rewriter) const;
 
-  /// Returns if the givem memref type is supported.
-  bool isSupportedMemRefType(MemRefType type) const;
-
+  // Forwards to getStridedElementPtr. TODO: remove.
   Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
                    ValueRange indices,
                    ConversionPatternRewriter &rewriter) const;
 
+  /// Returns whether the given memref type is supported.
+  bool isSupportedMemRefType(MemRefType type) const;
+
   /// Returns the type of a pointer to an element of the memref.
   Type getElementPtrType(MemRefType type) const;
 

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 870b91034361..17187e933cfa 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1042,39 +1042,45 @@ Value ConvertToLLVMPattern::createIndexConstant(
 }
 
 Value ConvertToLLVMPattern::getStridedElementPtr(
-    Location loc, Type elementTypePtr, Value descriptor, ValueRange indices,
-    ArrayRef<int64_t> strides, int64_t offset,
+    Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
     ConversionPatternRewriter &rewriter) const {
-  MemRefDescriptor memRefDescriptor(descriptor);
 
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto successStrides = getStridesAndOffset(type, strides, offset);
+  assert(succeeded(successStrides) && "unexpected non-strided memref");
+  (void)successStrides;
+
+  MemRefDescriptor memRefDescriptor(memRefDesc);
   Value base = memRefDescriptor.alignedPtr(rewriter, loc);
-  Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
-                          ? memRefDescriptor.offset(rewriter, loc)
-                          : createIndexConstant(rewriter, loc, offset);
+
+  Value index;
+  if (offset != 0) // Skip if offset is zero.
+    index = offset == MemRefType::getDynamicStrideOrOffset()
+                ? memRefDescriptor.offset(rewriter, loc)
+                : createIndexConstant(rewriter, loc, offset);
 
   for (int i = 0, e = indices.size(); i < e; ++i) {
-    Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
-                       ? memRefDescriptor.stride(rewriter, loc, i)
-                       : createIndexConstant(rewriter, loc, strides[i]);
-    Value additionalOffset =
-        rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
-    offsetValue =
-        rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset);
+    Value increment = indices[i];
+    if (strides[i] != 1) { // Skip if stride is 1.
+      Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
+                         ? memRefDescriptor.stride(rewriter, loc, i)
+                         : createIndexConstant(rewriter, loc, strides[i]);
+      increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
+    }
+    index =
+        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
   }
-  return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
+
+  LLVM::LLVMType elementPtrType = memRefDescriptor.getElementPtrType();
+  return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index)
+               : base;
 }
 
 Value ConvertToLLVMPattern::getDataPtr(
     Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
     ConversionPatternRewriter &rewriter) const {
-  LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementPtrType();
-  int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  auto successStrides = getStridesAndOffset(type, strides, offset);
-  assert(succeeded(successStrides) && "unexpected non-strided memref");
-  (void)successStrides;
-  return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
-                              offset, rewriter);
+  return getStridedElementPtr(loc, type, memRefDesc, indices, rewriter);
 }
 
 // Check if the MemRefType `type` is supported by the lowering. We currently
@@ -3044,8 +3050,9 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
     LoadOp::Adaptor transformed(operands);
     auto type = loadOp.getMemRefType();
 
-    Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
-                               transformed.indices(), rewriter);
+    Value dataPtr =
+        getStridedElementPtr(op->getLoc(), type, transformed.memref(),
+                             transformed.indices(), rewriter);
     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
     return success();
   }
@@ -3062,8 +3069,9 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
     auto type = cast<StoreOp>(op).getMemRefType();
     StoreOp::Adaptor transformed(operands);
 
-    Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
-                               transformed.indices(), rewriter);
+    Value dataPtr =
+        getStridedElementPtr(op->getLoc(), type, transformed.memref(),
+                             transformed.indices(), rewriter);
     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
                                                dataPtr);
     return success();
@@ -3082,8 +3090,9 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
     PrefetchOp::Adaptor transformed(operands);
     auto type = prefetchOp.getMemRefType();
 
-    Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
-                               transformed.indices(), rewriter);
+    Value dataPtr =
+        getStridedElementPtr(op->getLoc(), type, transformed.memref(),
+                             transformed.indices(), rewriter);
 
     // Replace with llvm.prefetch.
     auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
@@ -3788,8 +3797,9 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
     AtomicRMWOp::Adaptor adaptor(operands);
     auto resultType = adaptor.value().getType();
     auto memRefType = atomicOp.getMemRefType();
-    auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(),
-                              adaptor.indices(), rewriter);
+    auto dataPtr =
+        getStridedElementPtr(op->getLoc(), memRefType, adaptor.memref(),
+                             adaptor.indices(), rewriter);
     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
         op, resultType, *maybeKind, dataPtr, adaptor.value(),
         LLVM::AtomicOrdering::acq_rel);
@@ -3854,8 +3864,8 @@ struct GenericAtomicRMWOpLowering
     // Compute the loaded value and branch to the loop block.
     rewriter.setInsertionPointToEnd(initBlock);
     auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
-    auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
-                              adaptor.indices(), rewriter);
+    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
+                                        adaptor.indices(), rewriter);
     Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
 

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 6fe801fd0d94..42c66bcab9ab 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1277,8 +1277,8 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
     //    addrspacecast shall be used when source/dst memrefs are not on
     //    address space 0.
     // TODO: support alignment when possible.
-    Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
-                               adaptor.indices(), rewriter);
+    Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
+                                         adaptor.indices(), rewriter);
     auto vecTy =
         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
     Value vectorDataPtr;

diff  --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index 317cf322e3e0..26b8bec1f3fc 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -102,8 +102,8 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
     // Note that the dataPtr starts at the offset address specified by
     // indices, so no need to calculate offset size in bytes again in
     // the MUBUF instruction.
-    Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
-                               adaptor.indices(), rewriter);
+    Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
+                                         adaptor.indices(), rewriter);
 
     // 1. Create and fill a <4 x i32> dwordConfig with:
     //    1st two elements holding the address of dataPtr.

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
index 7db6ea568b51..b2708a562eab 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -182,13 +182,9 @@ func @stdlib_aligned_alloc(%N : index) -> memref<32x18xf32> {
 // CHECK:         %[[J:.*]]: !llvm.i64)
 func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
 //       CHECK:  %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK-NEXT:  %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 //  CHECK-NEXT:  %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 //  CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-//  CHECK-NEXT:  %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-//  CHECK-NEXT:  %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-//  CHECK-NEXT:  %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64
 //  CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
 //  CHECK-NEXT:  llvm.load %[[addr]] : !llvm.ptr<float>
   %0 = load %mixed[%i, %j] : memref<42x?xf32>
@@ -207,13 +203,9 @@ func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
 // CHECK-SAME:         %[[J:[a-zA-Z0-9]*]]: !llvm.i64
 func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
 //       CHECK:  %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK-NEXT:  %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 //  CHECK-NEXT:  %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 //  CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-//  CHECK-NEXT:  %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-//  CHECK-NEXT:  %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-//  CHECK-NEXT:  %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64
 //  CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
 //  CHECK-NEXT:  llvm.load %[[addr]] : !llvm.ptr<float>
   %0 = load %dynamic[%i, %j] : memref<?x?xf32>
@@ -232,13 +224,9 @@ func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
 // CHECK-SAME:         %[[J:[a-zA-Z0-9]*]]: !llvm.i64
 func @prefetch(%A : memref<?x?xf32>, %i : index, %j : index) {
 //      CHECK:  %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT:  %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 // CHECK-NEXT:  %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// CHECK-NEXT:  %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// CHECK-NEXT:  %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT:  %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64
 // CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
 // CHECK-NEXT:  [[C1:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
 // CHECK-NEXT:  [[C3:%.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32
@@ -270,13 +258,9 @@ func @prefetch(%A : memref<?x?xf32>, %i : index, %j : index) {
 // CHECK-SAME:         %[[J:[a-zA-Z0-9]*]]: !llvm.i64
 func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %val : f32) {
 //       CHECK:  %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK-NEXT:  %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 //  CHECK-NEXT:  %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 //  CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-//  CHECK-NEXT:  %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-//  CHECK-NEXT:  %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-//  CHECK-NEXT:  %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64
 //  CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
 //  CHECK-NEXT:  llvm.store %{{.*}}, %[[addr]] : !llvm.ptr<float>
   store %val, %dynamic[%i, %j] : memref<?x?xf32>
@@ -295,13 +279,9 @@ func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %val : f
 // CHECK-SAME:         %[[J:[a-zA-Z0-9]*]]: !llvm.i64
 func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32) {
 //       CHECK:  %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK-NEXT:  %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 //  CHECK-NEXT:  %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 //  CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-//  CHECK-NEXT:  %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-//  CHECK-NEXT:  %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-//  CHECK-NEXT:  %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64
 //  CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
 //  CHECK-NEXT:  llvm.store %{{.*}}, %[[addr]] : !llvm.ptr<float>
   store %val, %mixed[%i, %j] : memref<42x?xf32>

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
index 9c32abc39f14..158fdcba7c92 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
@@ -234,14 +234,10 @@ func @static_dealloc(%static: memref<10x8xf32>) {
 // BAREPTR-LABEL: func @zero_d_load(%{{.*}}: !llvm.ptr<float>) -> !llvm.float
 func @zero_d_load(%arg0: memref<f32>) -> f32 {
 //      CHECK:  %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64)>
-// CHECK-NEXT:  %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
-// CHECK-NEXT:  %{{.*}} = llvm.load %[[addr]] : !llvm.ptr<float>
+// CHECK-NEXT:  %{{.*}} = llvm.load %[[ptr]] : !llvm.ptr<float>
 
 // BAREPTR:      %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64)>
-// BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
-// BAREPTR-NEXT: llvm.load %[[addr:.*]] : !llvm.ptr<float>
+// BAREPTR-NEXT: llvm.load %[[ptr]] : !llvm.ptr<float>
   %0 = load %arg0[] : memref<f32>
   return %0 : f32
 }
@@ -257,24 +253,16 @@ func @zero_d_load(%arg0: memref<f32>) -> f32 {
 // BAREPTR-SAME: (%[[A:.*]]: !llvm.ptr<float>, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64) {
 func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
 //       CHECK:  %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK-NEXT:  %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 //  CHECK-NEXT:  %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
 //  CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-//  CHECK-NEXT:  %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-//  CHECK-NEXT:  %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-//  CHECK-NEXT:  %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64
 //  CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
 // CHECK-NEXT:  llvm.load %[[addr]] : !llvm.ptr<float>
 
 // BAREPTR:      %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 // BAREPTR-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
 // BAREPTR-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// BAREPTR-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// BAREPTR-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64
 // BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
 // BAREPTR-NEXT: llvm.load %[[addr]] : !llvm.ptr<float>
   %0 = load %static[%i, %j] : memref<10x42xf32>
@@ -288,14 +276,10 @@ func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
 // BAREPTR-SAME: (%[[A:.*]]: !llvm.ptr<float>, %[[val:.*]]: !llvm.float)
 func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
 //      CHECK:  %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64)>
-// CHECK-NEXT:  %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
-// CHECK-NEXT:  llvm.store %{{.*}}, %[[addr]] : !llvm.ptr<float>
+// CHECK-NEXT:  llvm.store %{{.*}}, %[[ptr]] : !llvm.ptr<float>
 
 // BAREPTR:      %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64)>
-// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
-// BAREPTR-NEXT: llvm.store %[[val]], %[[addr]] : !llvm.ptr<float>
+// BAREPTR-NEXT: llvm.store %[[val]], %[[ptr]] : !llvm.ptr<float>
   store %arg1, %arg0[] : memref<f32>
   return
 }
@@ -318,24 +302,16 @@ func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
 // BAREPTR-SAME:         %[[J:[a-zA-Z0-9]*]]: !llvm.i64
 func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) {
 //       CHECK:  %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-//  CHECK-NEXT:  %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 //  CHECK-NEXT:  %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
 //  CHECK-NEXT:  %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-//  CHECK-NEXT:  %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-//  CHECK-NEXT:  %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-//  CHECK-NEXT:  %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+//  CHECK-NEXT:  %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64
 //  CHECK-NEXT:  %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
 //  CHECK-NEXT:  llvm.store %{{.*}}, %[[addr]] : !llvm.ptr<float>
 
 // BAREPTR:      %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 // BAREPTR-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
 // BAREPTR-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// BAREPTR-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// BAREPTR-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64
 // BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
 // BAREPTR-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm.ptr<float>
   store %val, %static[%i, %j] : memref<10x42xf32>


        


More information about the Mlir-commits mailing list