[Mlir-commits] [mlir] 00b12a9 - [mlir][llvm] Introduce some constant folding.

Théo Degioanni llvmlistbot at llvm.org
Mon Jun 26 04:02:00 PDT 2023


Author: Théo Degioanni
Date: 2023-06-26T12:52:48+02:00
New Revision: 00b12a94322dc7779f42d7f863957f8ea4304dad

URL: https://github.com/llvm/llvm-project/commit/00b12a94322dc7779f42d7f863957f8ea4304dad
DIFF: https://github.com/llvm/llvm-project/commit/00b12a94322dc7779f42d7f863957f8ea4304dad.diff

LOG: [mlir][llvm] Introduce some constant folding.

This revision introduces some constant folding features to the LLVM
dialect. This specific choice of operations to cover is intended to
allow the elimination of logic generated by mem2reg with memset in the
common case of memsets of constant values.

This also introduces new verifiers for integer extension operations.
This led to a fix in SPIRV to LLVM conversion, as it would sometimes
generate invalid ZExt and SExt operations.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D153135

Added: 
    mlir/test/Dialect/LLVMIR/constant-folding.mlir

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
index 6f0f8f7d5e840..e445ca85e0f35 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
@@ -15,6 +15,7 @@ def LLVM_Dialect : Dialect {
   let name = "llvm";
   let cppNamespace = "::mlir::LLVM";
 
+  let hasConstantMaterializer = 1;
   let useDefaultAttributePrinterParser = 1;
   let hasRegionArgAttrVerify = 1;
   let hasRegionResultAttrVerify = 1;

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 8536cbf835332..e7aca6da5d68e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -98,9 +98,13 @@ def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">;
 def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
 def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
 def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
-def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or">;
+def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
+  let hasFolder = 1;
+}
 def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
-def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl">;
+def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl"> {
+  let hasFolder = 1;
+}
 def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">;
 def LLVM_AShrOp : LLVM_IntArithmeticOp<"ashr", "AShr">;
 
@@ -495,10 +499,15 @@ def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "PtrToInt",
                                   LLVM_ScalarOrVectorOf<AnyInteger>>;
 def LLVM_SExtOp : LLVM_CastOp<"sext", "SExt",
                               LLVM_ScalarOrVectorOf<AnyInteger>,
-                              LLVM_ScalarOrVectorOf<AnyInteger>>;
+                              LLVM_ScalarOrVectorOf<AnyInteger>> {
+  let hasVerifier = 1;
+}
 def LLVM_ZExtOp : LLVM_CastOp<"zext", "ZExt",
                               LLVM_ScalarOrVectorOf<AnyInteger>,
-                              LLVM_ScalarOrVectorOf<AnyInteger>>;
+                              LLVM_ScalarOrVectorOf<AnyInteger>> {
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
 def LLVM_TruncOp : LLVM_CastOp<"trunc", "Trunc",
                                LLVM_ScalarOrVectorOf<AnyInteger>,
                                LLVM_ScalarOrVectorOf<AnyInteger>>;

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 8b4380870a26d..28e587a066e4e 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -51,6 +51,17 @@ static bool isUnsignedIntegerOrVector(Type type) {
   return false;
 }
 
+/// Returns the width of an integer or of the element type of an integer vector,
+/// if applicable.
+static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
+  if (auto intType = dyn_cast<IntegerType>(type))
+    return intType.getWidth();
+  if (auto vecType = dyn_cast<VectorType>(type))
+    if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
+      return intType.getWidth();
+  return std::nullopt;
+}
+
 /// Returns the bit width of integer, float or vector of float or integer values
 static unsigned getBitWidth(Type type) {
   assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
@@ -1183,15 +1194,30 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
       return success();
     }
 
+    std::optional<uint64_t> dstTypeWidth =
+        getIntegerOrVectorElementWidth(dstType);
+    std::optional<uint64_t> op2TypeWidth =
+        getIntegerOrVectorElementWidth(op2Type);
+
+    if (!dstTypeWidth || !op2TypeWidth)
+      return failure();
+
     Location loc = operation.getLoc();
     Value extended;
-    if (isUnsignedIntegerOrVector(op2Type)) {
-      extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
-                                                        adaptor.getOperand2());
+    if (op2TypeWidth < dstTypeWidth) {
+      if (isUnsignedIntegerOrVector(op2Type)) {
+        extended = rewriter.template create<LLVM::ZExtOp>(
+            loc, dstType, adaptor.getOperand2());
+      } else {
+        extended = rewriter.template create<LLVM::SExtOp>(
+            loc, dstType, adaptor.getOperand2());
+      }
+    } else if (op2TypeWidth == dstTypeWidth) {
+      extended = adaptor.getOperand2();
     } else {
-      extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
-                                                        adaptor.getOperand2());
+      return failure();
     }
+
     Value result = rewriter.template create<LLVMOp>(
         loc, dstType, adaptor.getOperand1(), extended);
     rewriter.replaceOp(operation, result);

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index b24e2ca47b688..9b91304fadc2d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2531,6 +2531,64 @@ LogicalResult FenceOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Verifier for extension ops
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the given extension operation operates on consistent scalars
+/// or vectors, and that the target width is larger than the input width.
+template <class ExtOp>
+static LogicalResult verifyExtOp(ExtOp op) {
+  IntegerType inputType, outputType;
+  if (isCompatibleVectorType(op.getArg().getType())) {
+    if (!isCompatibleVectorType(op.getResult().getType()))
+      return op.emitError(
+          "input type is a vector but output type is an integer");
+    if (getVectorNumElements(op.getArg().getType()) !=
+        getVectorNumElements(op.getResult().getType()))
+      return op.emitError("input and output vectors are of incompatible shape");
+    // Because this is a CastOp, the element of vectors is guaranteed to be an
+    // integer.
+    inputType = cast<IntegerType>(getVectorElementType(op.getArg().getType()));
+    outputType =
+        cast<IntegerType>(getVectorElementType(op.getResult().getType()));
+  } else {
+    // Because this is a CastOp and arg is not a vector, arg is guaranteed to be
+    // an integer.
+    inputType = cast<IntegerType>(op.getArg().getType());
+    outputType = dyn_cast<IntegerType>(op.getResult().getType());
+    if (!outputType)
+      return op.emitError(
+          "input type is an integer but output type is a vector");
+  }
+
+  if (outputType.getWidth() <= inputType.getWidth())
+    return op.emitError("integer width of the output type is smaller or "
+                        "equal to the integer width of the input type");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ZExtOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ZExtOp::verify() { return verifyExtOp<ZExtOp>(*this); }
+
+OpFoldResult LLVM::ZExtOp::fold(FoldAdaptor adaptor) {
+  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
+  if (!arg)
+    return {};
+
+  size_t targetSize = cast<IntegerType>(getType()).getWidth();
+  return IntegerAttr::get(getType(), arg.getValue().zext(targetSize));
+}
+
+//===----------------------------------------------------------------------===//
+// SExtOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
+
 //===----------------------------------------------------------------------===//
 // Folder and verifier for LLVM::BitcastOp
 //===----------------------------------------------------------------------===//
@@ -2647,6 +2705,42 @@ OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// ShlOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult LLVM::ShlOp::fold(FoldAdaptor adaptor) {
+  auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
+  if (!rhs)
+    return {};
+
+  if (rhs.getValue().getZExtValue() >=
+      getLhs().getType().getIntOrFloatBitWidth())
+    return {}; // TODO: Fold into poison.
+
+  auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
+  if (!lhs)
+    return {};
+
+  return IntegerAttr::get(getType(), lhs.getValue().shl(rhs.getValue()));
+}
+
+//===----------------------------------------------------------------------===//
+// OrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) {
+  auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
+  if (!lhs)
+    return {};
+
+  auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
+  if (!rhs)
+    return {};
+
+  return IntegerAttr::get(getType(), lhs.getValue() | rhs.getValue());
+}
+
 //===----------------------------------------------------------------------===//
 // Utilities for LLVM::MetadataOp
 //===----------------------------------------------------------------------===//
@@ -3186,6 +3280,15 @@ LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
   return verifyParameterAttribute(op, resType, resAttr);
 }
 
+Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value,
+                                            Type type, Location loc) {
+  // TODO: Accept more possible attributes. So far, only IntegerAttr may come
+  // up.
+  if (!isa<IntegerAttr>(value))
+    return nullptr;
+  return builder.create<LLVM::ConstantOp>(loc, type, value);
+}
+
 //===----------------------------------------------------------------------===//
 // Utility functions.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/LLVMIR/constant-folding.mlir b/mlir/test/Dialect/LLVMIR/constant-folding.mlir
new file mode 100644
index 0000000000000..f800f2690467d
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/constant-folding.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(canonicalize))" --split-input-file | FileCheck %s
+
+// CHECK-LABEL: llvm.func @zext_basic
+llvm.func @zext_basic() -> i64 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.zext %0 : i32 to i64
+  // CHECK: %[[RES:.*]] = llvm.mlir.constant(1 : i64) : i64
+  // CHECK: llvm.return %[[RES]] : i64
+  llvm.return %1 : i64
+}
+
+// CHECK-LABEL: llvm.func @zext_neg
+llvm.func @zext_neg() -> i64 {
+  %0 = llvm.mlir.constant(-1 : i32) : i32
+  %1 = llvm.zext %0 : i32 to i64
+  // CHECK: %[[RES:.*]] = llvm.mlir.constant(4294967295 : i64) : i64
+  // CHECK: llvm.return %[[RES]] : i64
+  llvm.return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @shl_basic
+llvm.func @shl_basic() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.mlir.constant(1 : i32) : i32
+  %2 = llvm.shl %0, %1 : i32
+  // CHECK: %[[RES:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %2 : i32
+}
+
+// CHECK-LABEL: llvm.func @shl_multiple
+llvm.func @shl_multiple() -> i32 {
+  %0 = llvm.mlir.constant(45 : i32) : i32
+  %1 = llvm.mlir.constant(7 : i32) : i32
+  %2 = llvm.shl %0, %1 : i32
+  // CHECK: %[[RES:.*]] = llvm.mlir.constant(5760 : i32) : i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @or_basic
+llvm.func @or_basic() -> i32 {
+  %0 = llvm.mlir.constant(5 : i32) : i32
+  %1 = llvm.mlir.constant(9 : i32) : i32
+  %2 = llvm.or %0, %1 : i32
+  // CHECK: %[[RES:.*]] = llvm.mlir.constant(13 : i32) : i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %2 : i32
+}

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index c88ed03f0c6e0..e6e11923439e0 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1447,3 +1447,45 @@ llvm.comdat @__llvm_comdat {
 llvm.mlir.global @not_comdat(0 : i32) : i32
 // expected-error @+1 {{expected comdat symbol}}
 llvm.mlir.global @invalid_comdat_use(0 : i32) comdat(@not_comdat) : i32
+
+// -----
+
+func.func @invalid_zext_target_size_equal(%arg: i32)  {
+  // expected-error @+1 {{integer width of the output type is smaller or equal to the integer width of the input type}}
+  %0 = llvm.zext %arg : i32 to i32
+}
+
+// -----
+
+func.func @invalid_zext_target_size(%arg: i32)  {
+  // expected-error @+1 {{integer width of the output type is smaller or equal to the integer width of the input type}}
+  %0 = llvm.zext %arg : i32 to i16
+}
+
+// -----
+
+func.func @invalid_zext_target_size_vector(%arg: vector<1xi32>)  {
+  // expected-error @+1 {{integer width of the output type is smaller or equal to the integer width of the input type}}
+  %0 = llvm.zext %arg : vector<1xi32> to vector<1xi16>
+}
+
+// -----
+
+func.func @invalid_zext_target_shape(%arg: vector<1xi32>)  {
+  // expected-error @+1 {{input and output vectors are of incompatible shape}}
+  %0 = llvm.zext %arg : vector<1xi32> to vector<2xi64>
+}
+
+// -----
+
+func.func @invalid_zext_target_type(%arg: i32)  {
+  // expected-error @+1 {{input type is an integer but output type is a vector}}
+  %0 = llvm.zext %arg : i32 to vector<1xi64>
+}
+
+// -----
+
+func.func @invalid_zext_target_type_two(%arg: vector<1xi32>)  {
+  // expected-error @+1 {{input type is a vector but output type is an integer}}
+  %0 = llvm.zext %arg : vector<1xi32> to i64
+}

diff  --git a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
index 4a262f777509d..ce6338fb34883 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
@@ -1,12 +1,11 @@
 // RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg{region-simplify=false}))" --split-input-file | FileCheck %s
 
 // CHECK-LABEL: llvm.func @basic_memset
-llvm.func @basic_memset() -> i32 {
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @basic_memset(%memset_value: i8) -> i32 {
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %memset_value = llvm.mlir.constant(42 : i8) : i8
   %memset_len = llvm.mlir.constant(4 : i32) : i32
-  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
   // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
   // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
@@ -24,36 +23,27 @@ llvm.func @basic_memset() -> i32 {
 
 // -----
 
-// CHECK-LABEL: llvm.func @allow_dynamic_value_memset
-// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
-llvm.func @allow_dynamic_value_memset(%memset_value: i8) -> i32 {
+// CHECK-LABEL: llvm.func @basic_memset_constant
+llvm.func @basic_memset_constant() -> i32 {
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
   %memset_len = llvm.mlir.constant(4 : i32) : i32
-  // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
-  // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
-  // CHECK-NOT: "llvm.intr.memset"
-  // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
-  // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
-  // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
-  // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
-  // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
-  // CHECK-NOT: "llvm.intr.memset"
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: llvm.return %[[VALUE_32]] : i32
+  // CHECK: %[[RES:.*]] = llvm.mlir.constant(707406378 : i32) : i32
+  // CHECK: llvm.return %[[RES]] : i32
   llvm.return %2 : i32
 }
 
 // -----
 
 // CHECK-LABEL: llvm.func @exotic_target_memset
-llvm.func @exotic_target_memset() -> i40 {
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @exotic_target_memset(%memset_value: i8) -> i40 {
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %memset_value = llvm.mlir.constant(42 : i8) : i8
   %memset_len = llvm.mlir.constant(5 : i32) : i32
-  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
   // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
   // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
   // CHECK-DAG: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
@@ -74,6 +64,21 @@ llvm.func @exotic_target_memset() -> i40 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @exotic_target_memset_constant
+llvm.func @exotic_target_memset_constant() -> i40 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  %memset_len = llvm.mlir.constant(5 : i32) : i32
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i40
+  // CHECK: %[[RES:.*]] = llvm.mlir.constant(181096032810 : i40) : i40
+  // CHECK: llvm.return %[[RES]] : i40
+  llvm.return %2 : i40
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @no_volatile_memset
 llvm.func @no_volatile_memset() -> i32 {
   // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32


        


More information about the Mlir-commits mailing list