[Mlir-commits] [mlir] 8b97e17 - [mlir] Simplify code generated by ConvertToLLVMPattern::getStridedElementPtr().
Author: Christian Sigg
Date: 2020-11-18T11:52:09+01:00
New Revision: 8b97e17d161a177ae54c989e6e550930f6a75876
URL: https://github.com/llvm/llvm-project/commit/8b97e17d161a177ae54c989e6e550930f6a75876
DIFF: https://github.com/llvm/llvm-project/commit/8b97e17d161a177ae54c989e6e550930f6a75876.diff
LOG: [mlir] Simplify code generated by ConvertToLLVMPattern::getStridedElementPtr().
Make the interface match that of ConvertToLLVMPattern::getDataPtr() (to be removed in a separate change).
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D91599
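The practical effect on the generated IR: a static offset of zero no longer materializes an llvm.mlir.constant plus llvm.add, and a static unit stride no longer materializes an llvm.mlir.constant plus llvm.mul. As a rough before/after sketch, adapted from the test updates below (SSA value names follow the FileCheck captures in those tests), lowering

  %0 = load %mixed[%i, %j] : memref<42x?xf32>

previously produced

  %off  = llvm.mlir.constant(0 : index) : !llvm.i64
  %st0  = llvm.extractvalue %ld[4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
  %offI = llvm.mul %I, %st0 : !llvm.i64
  %off0 = llvm.add %off, %offI : !llvm.i64
  %st1  = llvm.mlir.constant(1 : index) : !llvm.i64
  %offJ = llvm.mul %J, %st1 : !llvm.i64
  %off1 = llvm.add %off0, %offJ : !llvm.i64
  %addr = llvm.getelementptr %ptr[%off1] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>

and now produces

  %st0  = llvm.extractvalue %ld[4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
  %offI = llvm.mul %I, %st0 : !llvm.i64
  %off1 = llvm.add %offI, %J : !llvm.i64
  %addr = llvm.getelementptr %ptr[%off1] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>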
Added:
Modified:
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index dbf38de56a0a..919a93ac84a2 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -505,18 +505,18 @@ class ConvertToLLVMPattern : public ConversionPattern {
// This is a strided getElementPtr variant that linearizes subscripts as:
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
- Value getStridedElementPtr(Location loc, Type elementTypePtr,
- Value descriptor, ValueRange indices,
- ArrayRef<int64_t> strides, int64_t offset,
+ Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
+ ValueRange indices,
ConversionPatternRewriter &rewriter) const;
- /// Returns if the given memref type is supported.
- bool isSupportedMemRefType(MemRefType type) const;
-
+ // Forwards to getStridedElementPtr. TODO: remove.
Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
ValueRange indices,
ConversionPatternRewriter &rewriter) const;
+ /// Returns if the given memref type is supported.
+ bool isSupportedMemRefType(MemRefType type) const;
+
/// Returns the type of a pointer to an element of the memref.
Type getElementPtrType(MemRefType type) const;
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 870b91034361..17187e933cfa 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1042,39 +1042,45 @@ Value ConvertToLLVMPattern::createIndexConstant(
}
Value ConvertToLLVMPattern::getStridedElementPtr(
- Location loc, Type elementTypePtr, Value descriptor, ValueRange indices,
- ArrayRef<int64_t> strides, int64_t offset,
+ Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
ConversionPatternRewriter &rewriter) const {
- MemRefDescriptor memRefDescriptor(descriptor);
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto successStrides = getStridesAndOffset(type, strides, offset);
+ assert(succeeded(successStrides) && "unexpected non-strided memref");
+ (void)successStrides;
+
+ MemRefDescriptor memRefDescriptor(memRefDesc);
Value base = memRefDescriptor.alignedPtr(rewriter, loc);
- Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
- ? memRefDescriptor.offset(rewriter, loc)
- : createIndexConstant(rewriter, loc, offset);
+
+ Value index;
+ if (offset != 0) // Skip if offset is zero.
+ index = offset == MemRefType::getDynamicStrideOrOffset()
+ ? memRefDescriptor.offset(rewriter, loc)
+ : createIndexConstant(rewriter, loc, offset);
for (int i = 0, e = indices.size(); i < e; ++i) {
- Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
- ? memRefDescriptor.stride(rewriter, loc, i)
- : createIndexConstant(rewriter, loc, strides[i]);
- Value additionalOffset =
- rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
- offsetValue =
- rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset);
+ Value increment = indices[i];
+ if (strides[i] != 1) { // Skip if stride is 1.
+ Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
+ ? memRefDescriptor.stride(rewriter, loc, i)
+ : createIndexConstant(rewriter, loc, strides[i]);
+ increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
+ }
+ index =
+ index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
}
- return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
+
+ LLVM::LLVMType elementPtrType = memRefDescriptor.getElementPtrType();
+ return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index)
+ : base;
}
Value ConvertToLLVMPattern::getDataPtr(
Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
ConversionPatternRewriter &rewriter) const {
- LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementPtrType();
- int64_t offset;
- SmallVector<int64_t, 4> strides;
- auto successStrides = getStridesAndOffset(type, strides, offset);
- assert(succeeded(successStrides) && "unexpected non-strided memref");
- (void)successStrides;
- return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
- offset, rewriter);
+ return getStridedElementPtr(loc, type, memRefDesc, indices, rewriter);
}
// Check if the MemRefType `type` is supported by the lowering. We currently
@@ -3044,8 +3050,9 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
LoadOp::Adaptor transformed(operands);
auto type = loadOp.getMemRefType();
- Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
- transformed.indices(), rewriter);
+ Value dataPtr =
+ getStridedElementPtr(op->getLoc(), type, transformed.memref(),
+ transformed.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
return success();
}
@@ -3062,8 +3069,9 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
auto type = cast<StoreOp>(op).getMemRefType();
StoreOp::Adaptor transformed(operands);
- Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
- transformed.indices(), rewriter);
+ Value dataPtr =
+ getStridedElementPtr(op->getLoc(), type, transformed.memref(),
+ transformed.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
dataPtr);
return success();
@@ -3082,8 +3090,9 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
PrefetchOp::Adaptor transformed(operands);
auto type = prefetchOp.getMemRefType();
- Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
- transformed.indices(), rewriter);
+ Value dataPtr =
+ getStridedElementPtr(op->getLoc(), type, transformed.memref(),
+ transformed.indices(), rewriter);
// Replace with llvm.prefetch.
auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
@@ -3788,8 +3797,9 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
AtomicRMWOp::Adaptor adaptor(operands);
auto resultType = adaptor.value().getType();
auto memRefType = atomicOp.getMemRefType();
- auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(),
- adaptor.indices(), rewriter);
+ auto dataPtr =
+ getStridedElementPtr(op->getLoc(), memRefType, adaptor.memref(),
+ adaptor.indices(), rewriter);
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
op, resultType, *maybeKind, dataPtr, adaptor.value(),
LLVM::AtomicOrdering::acq_rel);
@@ -3854,8 +3864,8 @@ struct GenericAtomicRMWOpLowering
// Compute the loaded value and branch to the loop block.
rewriter.setInsertionPointToEnd(initBlock);
auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
- auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
- adaptor.indices(), rewriter);
+ auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
+ adaptor.indices(), rewriter);
Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 6fe801fd0d94..42c66bcab9ab 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1277,8 +1277,8 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
// addrspacecast shall be used when source/dst memrefs are not on
// address space 0.
// TODO: support alignment when possible.
- Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
- adaptor.indices(), rewriter);
+ Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
+ adaptor.indices(), rewriter);
auto vecTy =
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
Value vectorDataPtr;
diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index 317cf322e3e0..26b8bec1f3fc 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -102,8 +102,8 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
// Note that the dataPtr starts at the offset address specified by
// indices, so no need to calculate offset size in bytes again in
// the MUBUF instruction.
- Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
- adaptor.indices(), rewriter);
+ Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
+ adaptor.indices(), rewriter);
// 1. Create and fill a <4 x i32> dwordConfig with:
// 1st two elements holding the address of dataPtr.
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
index 7db6ea568b51..b2708a562eab 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -182,13 +182,9 @@ func @stdlib_aligned_alloc(%N : index) -> memref<32x18xf32> {
// CHECK: %[[J:.*]]: !llvm.i64)
func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr<float>
%0 = load %mixed[%i, %j] : memref<42x?xf32>
@@ -207,13 +203,9 @@ func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
// CHECK-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64
func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr<float>
%0 = load %dynamic[%i, %j] : memref<?x?xf32>
@@ -232,13 +224,9 @@ func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
// CHECK-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64
func @prefetch(%A : memref<?x?xf32>, %i : index, %j : index) {
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK-NEXT: [[C3:%.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32
@@ -270,13 +258,9 @@ func @prefetch(%A : memref<?x?xf32>, %i : index, %j : index) {
// CHECK-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64
func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %val : f32) {
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm.ptr<float>
store %val, %dynamic[%i, %j] : memref<?x?xf32>
@@ -295,13 +279,9 @@ func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %val : f
// CHECK-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64
func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32) {
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm.ptr<float>
store %val, %mixed[%i, %j] : memref<42x?xf32>
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
index 9c32abc39f14..158fdcba7c92 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
@@ -234,14 +234,10 @@ func @static_dealloc(%static: memref<10x8xf32>) {
// BAREPTR-LABEL: func @zero_d_load(%{{.*}}: !llvm.ptr<float>) -> !llvm.float
func @zero_d_load(%arg0: memref<f32>) -> f32 {
// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64)>
-// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
-// CHECK-NEXT: %{{.*}} = llvm.load %[[addr]] : !llvm.ptr<float>
+// CHECK-NEXT: %{{.*}} = llvm.load %[[ptr]] : !llvm.ptr<float>
// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64)>
-// BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
-// BAREPTR-NEXT: llvm.load %[[addr:.*]] : !llvm.ptr<float>
+// BAREPTR-NEXT: llvm.load %[[ptr:.*]] : !llvm.ptr<float>
%0 = load %arg0[] : memref<f32>
return %0 : f32
}
@@ -257,24 +253,16 @@ func @zero_d_load(%arg0: memref<f32>) -> f32 {
// BAREPTR-SAME: (%[[A:.*]]: !llvm.ptr<float>, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64) {
func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr<float>
// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// BAREPTR-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
// BAREPTR-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// BAREPTR-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// BAREPTR-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64
// BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// BAREPTR-NEXT: llvm.load %[[addr]] : !llvm.ptr<float>
%0 = load %static[%i, %j] : memref<10x42xf32>
@@ -288,14 +276,10 @@ func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
// BAREPTR-SAME: (%[[A:.*]]: !llvm.ptr<float>, %[[val:.*]]: !llvm.float)
func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr<float>, ptr<float>, i64)>
-// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
-// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm.ptr<float>
+// CHECK-NEXT: llvm.store %{{.*}}, %[[ptr]] : !llvm.ptr<float>
// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64)>
-// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
-// BAREPTR-NEXT: llvm.store %[[val]], %[[addr]] : !llvm.ptr<float>
+// BAREPTR-NEXT: llvm.store %[[val]], %[[ptr]] : !llvm.ptr<float>
store %arg1, %arg0[] : memref<f32>
return
}
@@ -318,24 +302,16 @@ func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
// BAREPTR-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64
func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) {
// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm.ptr<float>
// BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// BAREPTR-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64
// BAREPTR-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
-// BAREPTR-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
-// BAREPTR-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
-// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
+// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64
// BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// BAREPTR-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm.ptr<float>
store %val, %static[%i, %j] : memref<10x42xf32>
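For rank-0 memrefs the simplification goes one step further, per the zero_d_load and zero_d_store updates above: with no indices and a zero offset there is no index computation left at all, so the lowering now loads and stores directly through the aligned pointer instead of emitting a getelementptr with a constant-zero index. A minimal sketch of the new output for load %arg0[] : memref<f32>:

  %ptr = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<float>, ptr<float>, i64)>
  llvm.load %ptr : !llvm.ptr<float>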