[Mlir-commits] [mlir] 5a8a159 - Add verifier for insert/extract element/value on type match between container and inserted/extracted value, and fix vector.shuffle lowering
Mehdi Amini
llvmlistbot at llvm.org
Wed Jul 21 15:35:03 PDT 2021
Author: Mehdi Amini
Date: 2021-07-21T22:28:59Z
New Revision: 5a8a159bf5279455b4688f80a6864ca8f37f4b4e
URL: https://github.com/llvm/llvm-project/commit/5a8a159bf5279455b4688f80a6864ca8f37f4b4e
DIFF: https://github.com/llvm/llvm-project/commit/5a8a159bf5279455b4688f80a6864ca8f37f4b4e.diff
LOG: Add verifier for insert/extract element/value on type match between container and inserted/extracted value, and fix vector.shuffle lowering
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D106398
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Dialect/LLVMIR/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 7d8c8a5a9600..6d09c50c92a8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -522,6 +522,7 @@ def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> {
let builders = [
OpBuilder<(ins "Value":$vector, "Value":$position,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
+ let verifier = [{ return ::verify(*this); }]; // verify(ExtractElementOp) in LLVMDialect.cpp
let parser = [{ return parseExtractElementOp(parser, result); }];
let printer = [{ printExtractElementOp(p, *this); }];
}
@@ -532,6 +533,7 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> {
$res = builder.CreateExtractValue($container, extractPosition($position));
}];
let builders = [LLVM_OneResultOpBuilder];
+ let verifier = [{ return ::verify(*this); }]; // verify(ExtractValueOp) in LLVMDialect.cpp
let parser = [{ return parseExtractValueOp(parser, result); }];
let printer = [{ printExtractValueOp(p, *this); }];
let hasFolder = 1;
@@ -544,6 +546,7 @@ def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect]> {
$res = builder.CreateInsertElement($vector, $value, $position);
}];
let builders = [LLVM_OneResultOpBuilder];
+ let verifier = [{ return ::verify(*this); }]; // verify(InsertElementOp) in LLVMDialect.cpp
let parser = [{ return parseInsertElementOp(parser, result); }];
let printer = [{ printInsertElementOp(p, *this); }];
}
@@ -560,6 +563,7 @@ def LLVM_InsertValueOp : LLVM_Op<"insertvalue", [NoSideEffect]> {
[{
build($_builder, $_state, container.getType(), container, value, position);
}]>];
+ let verifier = [{ return ::verify(*this); }]; // verify(InsertValueOp) in LLVMDialect.cpp
let parser = [{ return parseInsertValueOp(parser, result); }];
let printer = [{ printInsertValueOp(p, *this); }];
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 53ce5ca3d452..3b474113b51e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -541,6 +541,12 @@ class VectorShuffleOpConversion
}
// For all other cases, insert the individual values individually.
+ // Each mask entry moves a single element, so extract using the element type.
+ Type eltType;
+ if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
+ eltType = arrayType.getElementType();
+ else
+ eltType = llvmType.cast<VectorType>().getElementType();
Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
int64_t insPos = 0;
for (auto en : llvm::enumerate(maskArrayAttr)) {
@@ -551,7 +557,7 @@ class VectorShuffleOpConversion
value = adaptor.v2();
}
Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
- llvmType, rank, extPos);
+ eltType, rank, extPos);
insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
llvmType, rank, insPos++);
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 4a8e0b9db0ee..5b9625b62a02 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -970,6 +970,22 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
return success();
}
+// Verifies that the vector operand has an LLVM dialect-compatible vector type
+// and that the result type is exactly that vector's element type.
+static LogicalResult verify(ExtractElementOp op) {
+ Type vectorType = op.vector().getType();
+ if (!LLVM::isCompatibleVectorType(vectorType))
+ return op->emitOpError("expected LLVM dialect-compatible vector type for "
+ "operand #1, got ")
+ << vectorType;
+ Type valueType = LLVM::getVectorElementType(vectorType);
+ if (valueType != op.res().getType())
+ return op.emitOpError() << "Type mismatch: extracting from " << vectorType
+ << " should produce " << valueType
+ << " but this op returns " << op.res().getType();
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::ExtractValueOp.
//===----------------------------------------------------------------------===//
@@ -1024,6 +1038,55 @@ static Type getInsertExtractValueElementType(OpAsmParser &parser,
return llvmType;
}
+// Extract the type at `position` in the wrapped LLVM IR aggregate type
+// `containerType`. Returns null on failure, after emitting an error on `op`
+// when the container is not LLVM-compatible, a position entry is not an
+// integer, a position is out of bounds, or a non-aggregate type is indexed.
+static Type getInsertExtractValueElementType(Type containerType,
+ ArrayAttr positionAttr,
+ Operation *op) {
+ Type llvmType = containerType;
+ if (!isCompatibleType(containerType)) {
+ op->emitOpError("expected LLVM IR Dialect type, got ") << containerType;
+ return {};
+ }
+
+ // Infer the element type by stepping into the type one position at a time:
+ // arrays yield their element type, structures yield the body element at that
+ // position. Each position is bounds-checked before it is used to index.
+ for (Attribute subAttr : positionAttr) {
+ auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
+ if (!positionElementAttr) {
+ op->emitOpError("expected an array of integer literals, got: ")
+ << subAttr;
+ return {};
+ }
+ // Keep all 64 bits: narrowing to `int` could wrap an out-of-range position
+ // back into the valid range and slip past the bounds checks below.
+ int64_t position = positionElementAttr.getInt();
+ if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
+ if (position < 0 ||
+ static_cast<uint64_t>(position) >= arrayType.getNumElements()) {
+ op->emitOpError("position out of bounds: ") << position;
+ return {};
+ }
+ llvmType = arrayType.getElementType();
+ } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
+ if (position < 0 ||
+ static_cast<uint64_t>(position) >= structType.getBody().size()) {
+ op->emitOpError("position out of bounds: ") << position;
+ return {};
+ }
+ llvmType = structType.getBody()[position];
+ } else {
+ op->emitOpError("expected LLVM IR structure/array type, got: ")
+ << llvmType;
+ return {};
+ }
+ }
+ return llvmType;
+}
+
// <operation> ::= `llvm.extractvalue` ssa-use
// `[` integer-literal (`,` integer-literal)* `]`
// attribute-dict? `:` type
@@ -1062,6 +1122,24 @@ OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) {
return {};
}
+// Verifies that the result type equals the type stored at `position` inside
+// the container operand.
+static LogicalResult verify(ExtractValueOp op) {
+ Type containerType = op.container().getType();
+ Type expectedType = getInsertExtractValueElementType(containerType,
+ op.positionAttr(), op);
+ // The helper has already emitted a diagnostic when it returns null.
+ if (!expectedType)
+ return failure();
+
+ Type resultType = op.res().getType();
+ if (resultType == expectedType)
+ return success();
+ return op.emitOpError() << "Type mismatch: extracting from " << containerType
+ << " should produce " << expectedType
+ << " but this op returns " << resultType;
+}
+
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::InsertElementOp.
//===----------------------------------------------------------------------===//
@@ -1104,6 +1178,24 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser,
return success();
}
+// Verifies that the vector operand has an LLVM dialect-compatible vector type
+// and that the inserted value has exactly that vector's element type.
+static LogicalResult verify(InsertElementOp op) {
+ Type vectorType = op.vector().getType();
+ // Check compatibility before querying the element type, as the
+ // ExtractElementOp verifier does.
+ if (!LLVM::isCompatibleVectorType(vectorType))
+ return op->emitOpError("expected LLVM dialect-compatible vector type for "
+ "operand #1, got ")
+ << vectorType;
+ Type valueType = LLVM::getVectorElementType(vectorType);
+ if (valueType != op.value().getType())
+ return op.emitOpError()
+ << "Type mismatch: cannot insert " << op.value().getType()
+ << " into " << vectorType;
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::InsertValueOp.
//===----------------------------------------------------------------------===//
@@ -1147,6 +1229,23 @@ static ParseResult parseInsertValueOp(OpAsmParser &parser,
return success();
}
+// Verifies that the inserted value's type equals the type stored at
+// `position` inside the container operand.
+static LogicalResult verify(InsertValueOp op) {
+ Type containerType = op.container().getType();
+ Type expectedType = getInsertExtractValueElementType(containerType,
+ op.positionAttr(), op);
+ // The helper has already emitted a diagnostic when it returns null.
+ if (!expectedType)
+ return failure();
+
+ Type insertedType = op.value().getType();
+ if (insertedType == expectedType)
+ return success();
+ return op.emitOpError() << "Type mismatch: cannot insert " << insertedType
+ << " into " << containerType;
+}
+
//===----------------------------------------------------------------------===//
// Printing, parsing and verification for LLVM::ReturnOp.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index a28218bdeddb..19b80b86c43f 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -355,6 +355,26 @@ func @insertvalue_wrong_nesting() {
llvm.insertvalue %a, %b[0,0] : !llvm.struct<(i32)>
}
+// -----
+
+// The inserted value must have the element type, not the container type.
+func @insertvalue_invalid_type(%a : !llvm.array<1 x i32>) -> !llvm.array<1 x i32> {
+ // expected-error@+1 {{'llvm.insertvalue' op Type mismatch: cannot insert '!llvm.array<1 x i32>' into '!llvm.array<1 x i32>'}}
+ %b = "llvm.insertvalue"(%a, %a) {position = [0]} : (!llvm.array<1 x i32>, !llvm.array<1 x i32>) -> !llvm.array<1 x i32>
+ return %b : !llvm.array<1 x i32>
+}
+
+// -----
+
+// The result must have the type found at the extracted position.
+func @extractvalue_invalid_type(%a : !llvm.array<4 x vector<8xf32>>) -> !llvm.array<4 x vector<8xf32>> {
+ // expected-error@+1 {{'llvm.extractvalue' op Type mismatch: extracting from '!llvm.array<4 x vector<8xf32>>' should produce 'vector<8xf32>' but this op returns '!llvm.array<4 x vector<8xf32>>'}}
+ %b = "llvm.extractvalue"(%a) {position = [1]}
+ : (!llvm.array<4 x vector<8xf32>>) -> !llvm.array<4 x vector<8xf32>>
+ return %b : !llvm.array<4 x vector<8xf32>>
+}
+
+
// -----
func @extractvalue_non_llvm_type(%a : i32, %b : tensor<*xi32>) {
@@ -422,6 +440,24 @@ func @invalid_vector_type_3(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
// -----
+// The extracted result must be the vector's element type.
+func @invalid_vector_type_4(%a : vector<4xf32>, %idx : i32) -> vector<4xf32> {
+ // expected-error@+1 {{'llvm.extractelement' op Type mismatch: extracting from 'vector<4xf32>' should produce 'f32' but this op returns 'vector<4xf32>'}}
+ %b = "llvm.extractelement"(%a, %idx) : (vector<4xf32>, i32) -> vector<4xf32>
+ return %b : vector<4xf32>
+}
+
+// -----
+
+// The inserted value must be the vector's element type.
+func @invalid_vector_type_5(%a : vector<4xf32>, %idx : i32) -> vector<4xf32> {
+ // expected-error@+1 {{'llvm.insertelement' op Type mismatch: cannot insert 'vector<4xf32>' into 'vector<4xf32>'}}
+ %b = "llvm.insertelement"(%a, %a, %idx) : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32>
+ return %b : vector<4xf32>
+}
+
+// -----
+
func @null_non_llvm_type() {
// expected-error@+1 {{must be LLVM pointer type, but got 'i32'}}
llvm.mlir.null : i32
More information about the Mlir-commits
mailing list