[Mlir-commits] [mlir] cafaa35 - [mlir] Make it possible to directly supply constant values to LLVM GEPOp

Alex Zinenko llvmlistbot at llvm.org
Fri Jan 7 00:56:10 PST 2022


Author: Alex Zinenko
Date: 2022-01-07T09:56:01+01:00
New Revision: cafaa3503643c047b670dcc387e22096930d2d6c

URL: https://github.com/llvm/llvm-project/commit/cafaa3503643c047b670dcc387e22096930d2d6c
DIFF: https://github.com/llvm/llvm-project/commit/cafaa3503643c047b670dcc387e22096930d2d6c.diff

LOG: [mlir] Make it possible to directly supply constant values to LLVM GEPOp

In LLVM IR, the GEP indices that correspond to structures are required to be
i32 constants. MLIR models constants simply as values defined by special
operations, and nothing verifies that the structure indices in GEP are in fact
defined by such constant operations. Furthermore, some common transformations such as control flow
simplification may lead to the operands becoming non-constant. Make it possible
to directly supply constant values to LLVM GEPOp to guarantee they remain
constant until the translation to LLVM IR. Supplying struct indices as constants
is not yet a requirement and the verifier is not modified; that requirement will
be introduced separately.

Reviewed By: wsmoses

Differential Revision: https://reviews.llvm.org/D116757

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/Conversion/LLVMCommon/Pattern.cpp
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
    mlir/test/Dialect/LLVMIR/roundtrip.mlir
    mlir/test/Target/LLVMIR/llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 95551730ddc80..dd4def5449f32 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -315,17 +315,39 @@ def LLVM_AllocaOp : LLVM_Op<"alloca">, MemoryOpWithAlignmentBase {
   let printer = [{ printAllocaOp(p, *this); }];
 }
 
-def LLVM_GEPOp
-    : LLVM_Op<"getelementptr", [NoSideEffect]>,
-      LLVM_Builder<
-          "$res = builder.CreateGEP("
-          " $base->getType()->getPointerElementType(), $base, $indices);"> {
+def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> {
   let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$base,
-                   Variadic<LLVM_ScalarOrVectorOf<AnyInteger>>:$indices);
+                   Variadic<LLVM_ScalarOrVectorOf<AnyInteger>>:$indices,
+                   I32ElementsAttr:$structIndices);
   let results = (outs LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$res);
-  let builders = [LLVM_OneResultOpBuilder];
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices,
+               CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+    OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices,
+               "ArrayRef<int32_t>":$structIndices,
+               CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
+  ];
+  let llvmBuilder = [{
+    SmallVector<llvm::Value *> indices;
+    indices.reserve($structIndices.size());
+    unsigned operandIdx = 0;
+    for (int32_t structIndex : $structIndices.getValues<int32_t>()) {
+      if (structIndex == GEPOp::kDynamicIndex)
+        indices.push_back($indices[operandIdx++]);
+      else
+        indices.push_back(builder.getInt32(structIndex));
+    }
+    $res = builder.CreateGEP(
+      $base->getType()->getPointerElementType(), $base, indices);
+  }];
   let assemblyFormat = [{
-    $base `[` $indices `]` attr-dict `:` functional-type(operands, results)
+    $base `[` custom<GEPIndices>($indices, $structIndices) `]` attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    constexpr static int kDynamicIndex = std::numeric_limits<int32_t>::min();
   }];
   let hasFolder = 1;
 }

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index f7f8b6b142357..ab498da6e8ccb 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -790,8 +790,8 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
 
   Type elementPtrType = getElementPtrType(memRefType);
   Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
-  Value gepPtr = rewriter.create<LLVM::GEPOp>(
-      loc, elementPtrType, ArrayRef<Value>{nullPtr, numElements});
+  Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
+                                              ArrayRef<Value>{numElements});
   auto sizeBytes =
       rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
 

diff  --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 0003bd859e479..41e8cefd712ad 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -162,8 +162,8 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
   // Buffer size in bytes.
   Type elementPtrType = getElementPtrType(memRefType);
   Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
-  Value gepPtr = rewriter.create<LLVM::GEPOp>(
-      loc, elementPtrType, ArrayRef<Value>{nullPtr, runningStride});
+  Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
+                                              ArrayRef<Value>{runningStride});
   sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
 }
 
@@ -178,8 +178,8 @@ Value ConvertToLLVMPattern::getSizeInBytes(
       LLVM::LLVMPointerType::get(typeConverter->convertType(type));
   auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
   auto gep = rewriter.create<LLVM::GEPOp>(
-      loc, convertedPtrType,
-      ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
+      loc, convertedPtrType, nullPtr,
+      ArrayRef<Value>{createIndexConstant(rewriter, loc, 1)});
   return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
 }
 

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 9142be183174e..d8fa9654664e2 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -497,10 +497,11 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
     Type elementType = typeConverter->convertType(type.getElementType());
     Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
 
-    SmallVector<Value, 4> operands = {addressOf};
+    SmallVector<Value> operands;
     operands.insert(operands.end(), type.getRank() + 1,
                     createIndexConstant(rewriter, loc, 0));
-    auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, operands);
+    auto gep =
+        rewriter.create<LLVM::GEPOp>(loc, elementPtrType, addressOf, operands);
 
     // We do not expect the memref obtained using `memref.get_global` to be
     // ever deallocated. Set the allocated pointer to be known bad value to

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index e65c14e38268c..995d2eaa89a37 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -356,6 +356,67 @@ SwitchOp::getMutableSuccessorOperands(unsigned index) {
                     : getCaseOperandsMutable(index - 1);
 }
 
+//===----------------------------------------------------------------------===//
+// Code for LLVM::GEPOp.
+//===----------------------------------------------------------------------===//
+
+void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
+                  Value basePtr, ValueRange operands,
+                  ArrayRef<NamedAttribute> attributes) {
+  build(builder, result, resultType, basePtr, operands,
+        SmallVector<int32_t>(operands.size(), LLVM::GEPOp::kDynamicIndex),
+        attributes);
+}
+
+void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
+                  Value basePtr, ValueRange indices,
+                  ArrayRef<int32_t> structIndices,
+                  ArrayRef<NamedAttribute> attributes) {
+  result.addTypes(resultType);
+  result.addAttributes(attributes);
+  result.addAttribute("structIndices", builder.getI32TensorAttr(structIndices));
+  result.addOperands(basePtr);
+  result.addOperands(indices);
+}
+
+static ParseResult
+parseGEPIndices(OpAsmParser &parser,
+                SmallVectorImpl<OpAsmParser::OperandType> &indices,
+                DenseIntElementsAttr &structIndices) {
+  SmallVector<int32_t> constantIndices;
+  do {
+    int32_t constantIndex;
+    OptionalParseResult parsedInteger =
+        parser.parseOptionalInteger(constantIndex);
+    if (parsedInteger.hasValue()) {
+      if (failed(parsedInteger.getValue()))
+        return failure();
+      constantIndices.push_back(constantIndex);
+      continue;
+    }
+
+    constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
+    if (failed(parser.parseOperand(indices.emplace_back())))
+      return failure();
+  } while (succeeded(parser.parseOptionalComma()));
+
+  structIndices = parser.getBuilder().getI32TensorAttr(constantIndices);
+  return success();
+}
+
+static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
+                            OperandRange indices,
+                            DenseIntElementsAttr structIndices) {
+  unsigned operandIdx = 0;
+  llvm::interleaveComma(structIndices.getValues<int32_t>(), printer,
+                        [&](int32_t cst) {
+                          if (cst == LLVM::GEPOp::kDynamicIndex)
+                            printer.printOperand(indices[operandIdx++]);
+                          else
+                            printer << cst;
+                        });
+}
+
 //===----------------------------------------------------------------------===//
 // Builder, printer and parser for for LLVM::LoadOp.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index 3e06f9caf7b10..f7ad3383323b2 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -760,7 +760,8 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
     Type type = processType(inst->getType());
     if (!type)
       return failure();
-    v = b.create<GEPOp>(loc, type, ops);
+    v = b.create<GEPOp>(loc, type, ops[0],
+                        llvm::makeArrayRef(ops).drop_front());
     return success();
   }
   }

diff  --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index b6a09d6ff09e9..d4172e1a072d4 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -170,6 +170,16 @@ func @ops(%arg0: i32, %arg1: f32,
   llvm.return
 }
 
+// CHECK-LABEL: @gep
+llvm.func @gep(%ptr: !llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, %idx: i64,
+               %ptr2: !llvm.ptr<struct<(array<10xf32>)>>) {
+  // CHECK: llvm.getelementptr %{{.*}}[%{{.*}}, 1, 0] : (!llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, i64) -> !llvm.ptr<i32>
+  llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, i64) -> !llvm.ptr<i32>
+  // CHECK: llvm.getelementptr %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr<struct<(array<10 x f32>)>>, i64, i64) -> !llvm.ptr<f32>
+  llvm.getelementptr %ptr2[%idx, 0, %idx] : (!llvm.ptr<struct<(array<10 x f32>)>>, i64, i64) -> !llvm.ptr<f32>
+  llvm.return
+}
+
 // An larger self-contained function.
 // CHECK-LABEL: llvm.func @foo(%{{.*}}: i32) -> !llvm.struct<(i32, f64, i32)> {
 llvm.func @foo(%arg0: i32) -> !llvm.struct<(i32, f64, i32)> {

diff  --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 54dfd519d81c7..04a65f845148d 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -975,6 +975,16 @@ llvm.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32) -> !llvm.struct<(
   llvm.return %10 : !llvm.struct<(f32, i32)>
 }
 
+// CHECK-LABEL: @gep
+llvm.func @gep(%ptr: !llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, %idx: i64,
+               %ptr2: !llvm.ptr<struct<(array<10xf32>)>>) {
+  // CHECK: = getelementptr { i32, { i32, float } }, { i32, { i32, float } }* %{{.*}}, i64 %{{.*}}, i32 1, i32 0
+  llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, i64) -> !llvm.ptr<i32>
+  // CHECK: = getelementptr { [10 x float] }, { [10 x float] }* %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
+  llvm.getelementptr %ptr2[%idx, 0, %idx] : (!llvm.ptr<struct<(array<10xf32>)>>, i64, i64) -> !llvm.ptr<f32>
+  llvm.return
+}
+
 //
 // Indirect function calls
 //


        


More information about the Mlir-commits mailing list