[Mlir-commits] [mlir] 620e2bb - [mlir][LLVM] NFC - Remove createIndexConstant method

Nicolas Vasilache llvmlistbot at llvm.org
Wed Aug 2 00:24:34 PDT 2023


Author: Nicolas Vasilache
Date: 2023-08-02T07:24:29Z
New Revision: 620e2bb20cb7f9e59a7c30eab0737e34eb26ed2d

URL: https://github.com/llvm/llvm-project/commit/620e2bb20cb7f9e59a7c30eab0737e34eb26ed2d
DIFF: https://github.com/llvm/llvm-project/commit/620e2bb20cb7f9e59a7c30eab0737e34eb26ed2d.diff

LOG: [mlir][LLVM] NFC - Remove createIndexConstant method

This revision removes the createIndexConstant method, which implicitly created constants of the
type returned by getIndexType(), and updates all uses to the more explicit createIndexAttrConstant,
which requires an explicit Type parameter.
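
For illustration, here is a minimal sketch of a typical call-site migration (C++; rewriter, loc
and the constant value 1 are placeholders, not code taken verbatim from this revision):

    // Before: the pattern picked the index type implicitly (via getIndexType()).
    Value cst = createIndexConstant(rewriter, loc, /*value=*/1);

    // After: the caller states which integer type models `index`, typically the
    // one returned by the pattern's getIndexType().
    Type indexType = getIndexType();
    Value cst = createIndexAttrConstant(rewriter, loc, indexType, /*value=*/1);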

This is an NFC step towards disentangling index type conversion in LLVM lowering.

Selecting which index type to use requires finer granularity than the existing implementations
provide: they all rely on pass-level flags and end up with mismatches, especially on GPUs with
multiple address spaces of different capacities.

This revision also includes an NFC fix to MemRefToLLVM.cpp that prevents a crash in cases where
an integer memory space cannot be derived for a MemRef.

Differential Revision: https://reviews.llvm.org/D156854

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
    mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
    mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/Conversion/LLVMCommon/Pattern.cpp
    mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 7f08ec87023053..0aee13818df4d5 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -65,10 +65,6 @@ class ConvertToLLVMPattern : public ConversionPattern {
   static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
                                        Type resultType, int64_t value);
 
-  /// Create an LLVM dialect operation defining the given index constant.
-  Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
-                            uint64_t value) const;
-
   // This is a strided getElementPtr variant that linearizes subscripts as:
   //   `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
   Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
@@ -155,9 +151,9 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
                ConversionPatternRewriter &rewriter) const final {
     if constexpr (SourceOp::hasProperties())
       return rewrite(cast<SourceOp>(op),
-              OpAdaptor(operands, op->getDiscardableAttrDictionary(),
-                        cast<SourceOp>(op).getProperties()),
-              rewriter);
+                     OpAdaptor(operands, op->getDiscardableAttrDictionary(),
+                               cast<SourceOp>(op).getProperties()),
+                     rewriter);
     rewrite(cast<SourceOp>(op),
             OpAdaptor(operands, op->getDiscardableAttrDictionary()), rewriter);
   }

diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
index 770f319c82ffbf..495c4d63986f80 100644
--- a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
+++ b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
@@ -15,7 +15,7 @@ namespace mlir {
 
 /// Lowering for memory allocation ops.
 struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
-  using ConvertToLLVMPattern::createIndexConstant;
+  using ConvertToLLVMPattern::createIndexAttrConstant;
   using ConvertToLLVMPattern::getIndexType;
   using ConvertToLLVMPattern::getVoidPtrType;
 
@@ -43,7 +43,9 @@ struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
     MemRefType memRefType = op.getType();
     Value alignment;
     if (auto alignmentAttr = op.getAlignment()) {
-      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
+      Type indexType = getIndexType();
+      alignment =
+          createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
     } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
       // In the case where no alignment is specified, we may want to override
       // `malloc's` behavior. `malloc` typically aligns at the size of the

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 1de583c1780932..ecd4cbb25f2d5c 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -168,7 +168,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     Value lowHalf = rewriter.create<LLVM::TruncOp>(loc, llvmI32, ptrAsInt);
     resource = rewriter.create<LLVM::InsertElementOp>(
         loc, llvm4xI32, resource, lowHalf,
-        this->createIndexConstant(rewriter, loc, 0));
+        this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 0));
 
     // Bits 48-63 are used both for the stride of the buffer and (on gfx10) for
     // enabling swizzling. Prevent the high bits of pointers from accidentally
@@ -180,7 +180,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
         createI32Constant(rewriter, loc, 0x0000ffff));
     resource = rewriter.create<LLVM::InsertElementOp>(
         loc, llvm4xI32, resource, highHalfTruncated,
-        this->createIndexConstant(rewriter, loc, 1));
+        this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 1));
 
     Value numRecords;
     if (memrefType.hasStaticShape()) {
@@ -202,7 +202,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     }
     resource = rewriter.create<LLVM::InsertElementOp>(
         loc, llvm4xI32, resource, numRecords,
-        this->createIndexConstant(rewriter, loc, 2));
+        this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 2));
 
     // Final word:
     // bits 0-11: dst sel, ignored by these intrinsics
@@ -227,7 +227,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     Value word3Const = createI32Constant(rewriter, loc, word3);
     resource = rewriter.create<LLVM::InsertElementOp>(
         loc, llvm4xI32, resource, word3Const,
-        this->createIndexConstant(rewriter, loc, 3));
+        this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 3));
     args.push_back(resource);
 
     // Indexing (voffset)

diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 1838c8e6050c34..9993c093badc16 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -67,9 +67,10 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
 protected:
   Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                        MemRefType type, MemRefDescriptor desc) const {
+    Type indexType = ConvertToLLVMPattern::getIndexType();
     return type.hasStaticShape()
-               ? ConvertToLLVMPattern::createIndexConstant(
-                     rewriter, loc, type.getNumElements())
+               ? ConvertToLLVMPattern::createIndexAttrConstant(
+                     rewriter, loc, indexType, type.getNumElements())
                // For identity maps (verified by caller), the number of
                // elements is stride[0] * size[0].
                : rewriter.create<LLVM::MulOp>(loc,

diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index ea31bdd2b53909..1699172eb9dab3 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -60,11 +60,6 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
                                           builder.getIndexAttr(value));
 }
 
-Value ConvertToLLVMPattern::createIndexConstant(
-    ConversionPatternRewriter &builder, Location loc, uint64_t value) const {
-  return createIndexAttrConstant(builder, loc, getIndexType(), value);
-}
-
 Value ConvertToLLVMPattern::getStridedElementPtr(
     Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
     ConversionPatternRewriter &rewriter) const {
@@ -79,13 +74,15 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
   Value base =
       memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);
 
+  Type indexType = getIndexType();
   Value index;
   for (int i = 0, e = indices.size(); i < e; ++i) {
     Value increment = indices[i];
     if (strides[i] != 1) { // Skip if stride is 1.
-      Value stride = ShapedType::isDynamic(strides[i])
-                         ? memRefDescriptor.stride(rewriter, loc, i)
-                         : createIndexConstant(rewriter, loc, strides[i]);
+      Value stride =
+          ShapedType::isDynamic(strides[i])
+              ? memRefDescriptor.stride(rewriter, loc, i)
+              : createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
       increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
     }
     index =
@@ -130,15 +127,17 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
 
   sizes.reserve(memRefType.getRank());
   unsigned dynamicIndex = 0;
+  Type indexType = getIndexType();
   for (int64_t size : memRefType.getShape()) {
-    sizes.push_back(size == ShapedType::kDynamic
-                        ? dynamicSizes[dynamicIndex++]
-                        : createIndexConstant(rewriter, loc, size));
+    sizes.push_back(
+        size == ShapedType::kDynamic
+            ? dynamicSizes[dynamicIndex++]
+            : createIndexAttrConstant(rewriter, loc, indexType, size));
   }
 
   // Strides: iterate sizes in reverse order and multiply.
   int64_t stride = 1;
-  Value runningStride = createIndexConstant(rewriter, loc, 1);
+  Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1);
   strides.resize(memRefType.getRank());
   for (auto i = memRefType.getRank(); i-- > 0;) {
     strides[i] = runningStride;
@@ -158,7 +157,7 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
       runningStride =
           rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
     else
-      runningStride = createIndexConstant(rewriter, loc, stride);
+      runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride);
   }
   if (sizeInBytes) {
     // Buffer size in bytes.
@@ -195,22 +194,25 @@ Value ConvertToLLVMPattern::getNumElements(
              static_cast<ssize_t>(dynamicSizes.size()) &&
          "dynamicSizes size doesn't match dynamic sizes count in memref shape");
 
+  Type indexType = getIndexType();
   Value numElements = memRefType.getRank() == 0
-                          ? createIndexConstant(rewriter, loc, 1)
+                          ? createIndexAttrConstant(rewriter, loc, indexType, 1)
                           : nullptr;
   unsigned dynamicIndex = 0;
 
   // Compute the total number of memref elements.
   for (int64_t staticSize : memRefType.getShape()) {
     if (numElements) {
-      Value size = staticSize == ShapedType::kDynamic
-                       ? dynamicSizes[dynamicIndex++]
-                       : createIndexConstant(rewriter, loc, staticSize);
+      Value size =
+          staticSize == ShapedType::kDynamic
+              ? dynamicSizes[dynamicIndex++]
+              : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
       numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
     } else {
-      numElements = staticSize == ShapedType::kDynamic
-                        ? dynamicSizes[dynamicIndex++]
-                        : createIndexConstant(rewriter, loc, staticSize);
+      numElements =
+          staticSize == ShapedType::kDynamic
+              ? dynamicSizes[dynamicIndex++]
+              : createIndexAttrConstant(rewriter, loc, indexType, staticSize);
     }
   }
   return numElements;
@@ -231,8 +233,9 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
   memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
 
   // Field 3: Offset in aligned pointer.
-  memRefDescriptor.setOffset(rewriter, loc,
-                             createIndexConstant(rewriter, loc, 0));
+  Type indexType = getIndexType();
+  memRefDescriptor.setOffset(
+      rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0));
 
   // Fields 4: Sizes.
   for (const auto &en : llvm::enumerate(sizes))

diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index b762758a16eab2..715d00f2e215ac 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -138,7 +138,8 @@ bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
 Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
     ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
     Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
-  Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
+  Value allocAlignment =
+      createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
 
   MemRefType memRefType = getMemRefResultType(op);
   // Function aligned_alloc requires size to be a multiple of alignment; we pad

diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 89811c11680fa3..d69ee3ff82220e 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -160,11 +160,12 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
     auto computeNumElements =
         [&](MemRefType type, function_ref<Value()> getDynamicSize) -> Value {
       // Compute number of elements.
+      Type indexType = ConvertToLLVMPattern::getIndexType();
       Value numElements =
           type.isDynamicDim(0)
               ? getDynamicSize()
-              : createIndexConstant(rewriter, loc, type.getDimSize(0));
-      Type indexType = getIndexType();
+              : createIndexAttrConstant(rewriter, loc, indexType,
+                                        type.getDimSize(0));
       if (numElements.getType() != indexType)
         numElements = typeConverter->materializeTargetConversion(
             rewriter, loc, indexType, numElements);
@@ -482,7 +483,8 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
     // The size value that we have to extract can be obtained using GEPop with
     // `dimOp.index() + 1` index argument.
     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
-        loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex());
+        loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
+        adaptor.getIndex());
     Value sizePtr = rewriter.create<LLVM::GEPOp>(
         loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
         idxPlusOne);
@@ -508,6 +510,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
 
     // Take advantage if index is constant.
     MemRefType memRefType = cast<MemRefType>(operandType);
+    Type indexType = getIndexType();
     if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
       int64_t i = *index;
       if (i >= 0 && i < memRefType.getRank()) {
@@ -518,7 +521,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
         }
         // Use constant for static size.
         int64_t dimSize = memRefType.getDimSize(i);
-        return createIndexConstant(rewriter, loc, dimSize);
+        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
       }
     }
     Value index = adaptor.getIndex();
@@ -717,7 +720,11 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
 
     // This is called after a type conversion, which would have failed if this
     // call fails.
-    unsigned memSpace = *getTypeConverter()->getMemRefAddressSpace(type);
+    std::optional<unsigned> maybeAddressSpace =
+        getTypeConverter()->getMemRefAddressSpace(type);
+    if (!maybeAddressSpace)
+      return std::make_tuple(Value(), Value());
+    unsigned memSpace = *maybeAddressSpace;
 
     Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
     Type resTy = getTypeConverter()->getPointerType(arrayTy, memSpace);
@@ -826,8 +833,10 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
       return success();
     }
     if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
-      rewriter.replaceOp(
-          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
+      Type indexType = getIndexType();
+      rewriter.replaceOp(op,
+                         {createIndexAttrConstant(rewriter, loc, indexType,
+                                                  rankedMemRefType.getRank())});
       return success();
     }
     return failure();
@@ -1351,29 +1360,31 @@ struct MemRefReshapeOpLowering
       assert(targetMemRefType.getLayout().isIdentity() &&
              "Identity layout map is a precondition of a valid reshape op");
 
+      Type indexType = getIndexType();
       Value stride = nullptr;
       int64_t targetRank = targetMemRefType.getRank();
       for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
         if (!ShapedType::isDynamic(strides[i])) {
           // If the stride for this dimension is dynamic, then use the product
           // of the sizes of the inner dimensions.
-          stride = createIndexConstant(rewriter, loc, strides[i]);
+          stride =
+              createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
         } else if (!stride) {
           // `stride` is null only in the first iteration of the loop.  However,
           // since the target memref has an identity layout, we can safely set
           // the innermost stride to 1.
-          stride = createIndexConstant(rewriter, loc, 1);
+          stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
         }
 
         Value dimSize;
         // If the size of this dimension is dynamic, then load it at runtime
         // from the shape operand.
         if (!targetMemRefType.isDynamicDim(i)) {
-          dimSize = createIndexConstant(rewriter, loc,
-                                        targetMemRefType.getDimSize(i));
+          dimSize = createIndexAttrConstant(rewriter, loc, indexType,
+                                            targetMemRefType.getDimSize(i));
         } else {
           Value shapeOp = reshapeOp.getShape();
-          Value index = createIndexConstant(rewriter, loc, i);
+          Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
           dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
           Type indexType = getIndexType();
           if (dimSize.getType() != indexType)
@@ -1444,7 +1455,7 @@ struct MemRefReshapeOpLowering
     Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
         rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
     Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
-    Value oneIndex = createIndexConstant(rewriter, loc, 1);
+    Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
     Value resultRankMinusOne =
         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
 
@@ -1466,7 +1477,7 @@ struct MemRefReshapeOpLowering
     Value indexArg = condBlock->getArgument(0);
     Value strideArg = condBlock->getArgument(1);
 
-    Value zeroIndex = createIndexConstant(rewriter, loc, 0);
+    Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
     Value pred = rewriter.create<LLVM::ICmpOp>(
         loc, IntegerType::get(rewriter.getContext(), 1),
         LLVM::ICmpPredicate::sge, indexArg, zeroIndex);
@@ -1604,11 +1615,11 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
   // Build and return the value for the idx^th shape dimension, either by
   // returning the constant shape dimension or counting the proper dynamic size.
   Value getSize(ConversionPatternRewriter &rewriter, Location loc,
-                ArrayRef<int64_t> shape, ValueRange dynamicSizes,
-                unsigned idx) const {
+                ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
+                Type indexType) const {
     assert(idx < shape.size());
     if (!ShapedType::isDynamic(shape[idx]))
-      return createIndexConstant(rewriter, loc, shape[idx]);
+      return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
     // Count the number of dynamic dims in range [0, idx]
     unsigned nDynamic =
         llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
@@ -1621,16 +1632,16 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
   // result returned by this function.
   Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<int64_t> strides, Value nextSize,
-                  Value runningStride, unsigned idx) const {
+                  Value runningStride, unsigned idx, Type indexType) const {
     assert(idx < strides.size());
     if (!ShapedType::isDynamic(strides[idx]))
-      return createIndexConstant(rewriter, loc, strides[idx]);
+      return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
     if (nextSize)
       return runningStride
                  ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                  : nextSize;
     assert(!runningStride);
-    return createIndexConstant(rewriter, loc, 1);
+    return createIndexAttrConstant(rewriter, loc, indexType, 1);
   }
 
   LogicalResult
@@ -1697,11 +1708,13 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
 
     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
 
-    // Field 3: The offset in the resulting type must be 0. This is because of
-    // the type change: an offset on srcType* may not be expressible as an
-    // offset on dstType*.
-    targetMemRef.setOffset(rewriter, loc,
-                           createIndexConstant(rewriter, loc, offset));
+    Type indexType = getIndexType();
+    // Field 3: The offset in the resulting type must be 0. This is
+    // because of the type change: an offset on srcType* may not be
+    // expressible as an offset on dstType*.
+    targetMemRef.setOffset(
+        rewriter, loc,
+        createIndexAttrConstant(rewriter, loc, indexType, offset));
 
     // Early exit for 0-D corner case.
     if (viewMemRefType.getRank() == 0)
@@ -1712,10 +1725,11 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
       // Update size.
       Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
-                           adaptor.getSizes(), i);
+                           adaptor.getSizes(), i, indexType);
       targetMemRef.setSize(rewriter, loc, i, size);
       // Update stride.
-      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
+      stride =
+          getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
       targetMemRef.setStride(rewriter, loc, i, stride);
       nextSize = size;
     }


        

