[llvm] [mlir][TilingInterface] Add scf::tileUsingSCFForallOp method to tile using the interface to generate `scf::forall`. (PR #67083)

via llvm-commits <llvm-commits at lists.llvm.org>
Tue Oct 17 16:19:11 PDT 2023


https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/67083

From dca986262a0f85d21cbc2ff9063c06cbce6acc98 Mon Sep 17 00:00:00 2001
From: Stella Laurenzo <stellaraccident at gmail.com>
Date: Fri, 6 Oct 2023 11:30:08 -0700
Subject: [PATCH 1/5] Revert "[RISCV][CostModel] VPIntrinsics have same cost as
 their non-vp counterparts (#67178)"

This reverts commit fc865c20345860f394448c228054beafc22a1d4d.

Breaks an x86 test.
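
For context, the reverted hook (first diff hunk below) costed a VP intrinsic by delegating to its functional counterpart once the mask and EVL operands are dropped, so e.g. llvm.vp.add was priced like a plain add. A minimal standalone reproducer of the affected query, assuming a RISC-V build of opt (the RUN line and function name here are illustrative, not part of the patch):

  ; RUN: opt -passes="print<cost-model>" -mtriple=riscv64 -mattr=+v -disable-output %s 2>&1
  define <2 x i8> @vp_load_cost(ptr %p, <2 x i1> %m, i32 %evl) {
    ; Before this revert: costed like a plain <2 x i8> vector load (cost 1).
    ; After this revert: falls back to the generic intrinsic path (cost 5 in the RVI runs below).
    %v = call <2 x i8> @llvm.vp.load.v2i8.p0(ptr %p, <2 x i1> %m, i32 %evl)
    ret <2 x i8> %v
  }
  declare <2 x i8> @llvm.vp.load.v2i8.p0(ptr, <2 x i1>, i32)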
---
 llvm/include/llvm/CodeGen/BasicTTIImpl.h      |  56 ---
 llvm/test/Analysis/CostModel/RISCV/gep.ll     |   8 +-
 .../CostModel/RISCV/rvv-intrinsics.ll         | 370 +-----------------
 3 files changed, 5 insertions(+), 429 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 5d83a0d37dc3a1b..0f8fb14f7b0e3b2 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1686,62 +1686,6 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     }
     }
 
-    // VP Intrinsics should have the same cost as their non-vp counterpart.
-    // TODO: Adjust the cost to make the vp intrinsic cheaper than its non-vp
-    // counterpart when the vector length argument is smaller than the maximum
-    // vector length.
-    if (VPIntrinsic::isVPIntrinsic(ICA.getID())) {
-      std::optional<unsigned> FOp =
-          VPIntrinsic::getFunctionalOpcodeForVP(ICA.getID());
-      if (FOp) {
-        // TODO: Support other kinds of Intrinsics (i.e. reductions)
-        if (ICA.getID() == Intrinsic::vp_load) {
-          Align Alignment;
-          if (auto *VPI = dyn_cast_or_null<VPIntrinsic>(ICA.getInst()))
-            Alignment = VPI->getPointerAlignment().valueOrOne();
-          unsigned AS = 0;
-          if (ICA.getArgs().size() > 1)
-            if (auto *PtrTy =
-                    dyn_cast<PointerType>(ICA.getArgs()[0]->getType()))
-              AS = PtrTy->getAddressSpace();
-          return thisT()->getMemoryOpCost(*FOp, ICA.getReturnType(), Alignment,
-                                          AS, CostKind);
-        }
-        if (ICA.getID() == Intrinsic::vp_store) {
-          Align Alignment;
-          if (auto *VPI = dyn_cast_or_null<VPIntrinsic>(ICA.getInst()))
-            Alignment = VPI->getPointerAlignment().valueOrOne();
-          unsigned AS = 0;
-          if (ICA.getArgs().size() >= 2)
-            if (auto *PtrTy =
-                    dyn_cast<PointerType>(ICA.getArgs()[1]->getType()))
-              AS = PtrTy->getAddressSpace();
-          return thisT()->getMemoryOpCost(*FOp, Args[0]->getType(), Alignment,
-                                          AS, CostKind);
-        }
-        if (VPBinOpIntrinsic::isVPBinOp(ICA.getID())) {
-          return thisT()->getArithmeticInstrCost(*FOp, ICA.getReturnType(),
-                                                 CostKind);
-        }
-      }
-
-      std::optional<Intrinsic::ID> FID =
-          VPIntrinsic::getFunctionalIntrinsicIDForVP(ICA.getID());
-      if (FID) {
-        // Non-vp version will have same Args/Tys except mask and vector length.
-        assert(ICA.getArgs().size() >= 2 && ICA.getArgTypes().size() >= 2 &&
-               "Expected VPIntrinsic to have Mask and Vector Length args and "
-               "types");
-        ArrayRef<const Value *> NewArgs = ArrayRef(ICA.getArgs()).drop_back(2);
-        ArrayRef<Type *> NewTys = ArrayRef(ICA.getArgTypes()).drop_back(2);
-
-        IntrinsicCostAttributes NewICA(*FID, ICA.getReturnType(), NewArgs,
-                                       NewTys, ICA.getFlags(), ICA.getInst(),
-                                       ICA.getScalarizationCost());
-        return thisT()->getIntrinsicInstrCost(NewICA, CostKind);
-      }
-    }
-
     // Assume that we need to scalarize this intrinsic.
     // Compute the scalarization overhead based on Args for a vector
     // intrinsic.
diff --git a/llvm/test/Analysis/CostModel/RISCV/gep.ll b/llvm/test/Analysis/CostModel/RISCV/gep.ll
index 4fadf34c1973f83..be518faf7e05165 100644
--- a/llvm/test/Analysis/CostModel/RISCV/gep.ll
+++ b/llvm/test/Analysis/CostModel/RISCV/gep.ll
@@ -270,7 +270,7 @@ define void @non_foldable_vector_uses(ptr %base, <2 x ptr> %base.vec) {
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %4 = getelementptr i8, ptr %base, i32 42
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 5 for instruction: %x4 = call <2 x i8> @llvm.masked.expandload.v2i8(ptr %4, <2 x i1> undef, <2 x i8> undef)
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %5 = getelementptr i8, ptr %base, i32 42
-; RVI-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %x5 = call <2 x i8> @llvm.vp.load.v2i8.p0(ptr %5, <2 x i1> undef, i32 undef)
+; RVI-NEXT:  Cost Model: Found an estimated cost of 5 for instruction: %x5 = call <2 x i8> @llvm.vp.load.v2i8.p0(ptr %5, <2 x i1> undef, i32 undef)
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %6 = getelementptr i8, ptr %base, i32 42
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 5 for instruction: %x6 = call <2 x i8> @llvm.experimental.vp.strided.load.v2i8.p0.i64(ptr %6, i64 undef, <2 x i1> undef, i32 undef)
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %7 = getelementptr i8, ptr %base, i32 42
@@ -282,7 +282,7 @@ define void @non_foldable_vector_uses(ptr %base, <2 x ptr> %base.vec) {
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %10 = getelementptr i8, ptr %base, i32 42
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 12 for instruction: call void @llvm.masked.compressstore.v2i8(<2 x i8> undef, ptr %10, <2 x i1> undef)
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %11 = getelementptr i8, ptr %base, i32 42
-; RVI-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: call void @llvm.vp.store.v2i8.p0(<2 x i8> undef, ptr %11, <2 x i1> undef, i32 undef)
+; RVI-NEXT:  Cost Model: Found an estimated cost of 12 for instruction: call void @llvm.vp.store.v2i8.p0(<2 x i8> undef, ptr %11, <2 x i1> undef, i32 undef)
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %12 = getelementptr i8, ptr %base, i32 42
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 12 for instruction: call void @llvm.experimental.vp.strided.store.v2i8.p0.i64(<2 x i8> undef, ptr %12, i64 undef, <2 x i1> undef, i32 undef)
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: ret void
@@ -340,7 +340,7 @@ define void @foldable_vector_uses(ptr %base, <2 x ptr> %base.vec) {
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %4 = getelementptr i8, ptr %base, i32 0
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 5 for instruction: %x4 = call <2 x i8> @llvm.masked.expandload.v2i8(ptr %4, <2 x i1> undef, <2 x i8> undef)
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %5 = getelementptr i8, ptr %base, i32 0
-; RVI-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %x5 = call <2 x i8> @llvm.vp.load.v2i8.p0(ptr %5, <2 x i1> undef, i32 undef)
+; RVI-NEXT:  Cost Model: Found an estimated cost of 5 for instruction: %x5 = call <2 x i8> @llvm.vp.load.v2i8.p0(ptr %5, <2 x i1> undef, i32 undef)
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %6 = getelementptr i8, ptr %base, i32 0
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 5 for instruction: %x6 = call <2 x i8> @llvm.experimental.vp.strided.load.v2i8.p0.i64(ptr %6, i64 undef, <2 x i1> undef, i32 undef)
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %7 = getelementptr i8, ptr %base, i32 0
@@ -352,7 +352,7 @@ define void @foldable_vector_uses(ptr %base, <2 x ptr> %base.vec) {
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %10 = getelementptr i8, ptr %base, i32 0
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 12 for instruction: call void @llvm.masked.compressstore.v2i8(<2 x i8> undef, ptr %10, <2 x i1> undef)
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %11 = getelementptr i8, ptr %base, i32 0
-; RVI-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: call void @llvm.vp.store.v2i8.p0(<2 x i8> undef, ptr %11, <2 x i1> undef, i32 undef)
+; RVI-NEXT:  Cost Model: Found an estimated cost of 12 for instruction: call void @llvm.vp.store.v2i8.p0(<2 x i8> undef, ptr %11, <2 x i1> undef, i32 undef)
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %12 = getelementptr i8, ptr %base, i32 0
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 12 for instruction: call void @llvm.experimental.vp.strided.store.v2i8.p0.i64(<2 x i8> undef, ptr %12, i64 undef, <2 x i1> undef, i32 undef)
 ; RVI-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: ret void
diff --git a/llvm/test/Analysis/CostModel/RISCV/rvv-intrinsics.ll b/llvm/test/Analysis/CostModel/RISCV/rvv-intrinsics.ll
index 85364c935267d24..93de623cf1c6da9 100644
--- a/llvm/test/Analysis/CostModel/RISCV/rvv-intrinsics.ll
+++ b/llvm/test/Analysis/CostModel/RISCV/rvv-intrinsics.ll
@@ -206,378 +206,10 @@ define void @vp_fshl() {
   ret void
 }
 
-define void @add() {
-; CHECK-LABEL: 'add'
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t0 = call <2 x i8> @llvm.vp.add.v2i8(<2 x i8> undef, <2 x i8> undef, <2 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t1 = add <2 x i8> undef, undef
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t2 = call <4 x i8> @llvm.vp.add.v4i8(<4 x i8> undef, <4 x i8> undef, <4 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t3 = add <4 x i8> undef, undef
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t4 = call <8 x i8> @llvm.vp.add.v8i8(<8 x i8> undef, <8 x i8> undef, <8 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t5 = add <8 x i8> undef, undef
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t6 = call <16 x i8> @llvm.vp.add.v16i8(<16 x i8> undef, <16 x i8> undef, <16 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t7 = add <16 x i8> undef, undef
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t8 = call <2 x i64> @llvm.vp.add.v2i64(<2 x i64> undef, <2 x i64> undef, <2 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t9 = add <2 x i64> undef, undef
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %t10 = call <4 x i64> @llvm.vp.add.v4i64(<4 x i64> undef, <4 x i64> undef, <4 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %t12 = add <4 x i64> undef, undef
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: %t13 = call <8 x i64> @llvm.vp.add.v8i64(<8 x i64> undef, <8 x i64> undef, <8 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: %t14 = add <8 x i64> undef, undef
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 8 for instruction: %t15 = call <16 x i64> @llvm.vp.add.v16i64(<16 x i64> undef, <16 x i64> undef, <16 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 8 for instruction: %t16 = add <16 x i64> undef, undef
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t17 = call <vscale x 2 x i8> @llvm.vp.add.nxv2i8(<vscale x 2 x i8> undef, <vscale x 2 x i8> undef, <vscale x 2 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t18 = add <vscale x 2 x i8> undef, undef
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t19 = call <vscale x 4 x i8> @llvm.vp.add.nxv4i8(<vscale x 4 x i8> undef, <vscale x 4 x i8> undef, <vscale x 4 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t20 = add <vscale x 4 x i8> undef, undef
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t21 = call <vscale x 8 x i8> @llvm.vp.add.nxv8i8(<vscale x 8 x i8> undef, <vscale x 8 x i8> undef, <vscale x 8 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t22 = add <vscale x 8 x i8> undef, undef
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %t23 = call <vscale x 16 x i8> @llvm.vp.add.nxv16i8(<vscale x 16 x i8> undef, <vscale x 16 x i8> undef, <vscale x 16 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %t24 = add <vscale x 16 x i8> undef, undef
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %t25 = call <vscale x 2 x i64> @llvm.vp.add.nxv2i64(<vscale x 2 x i64> undef, <vscale x 2 x i64> undef, <vscale x 2 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %t26 = add <vscale x 2 x i64> undef, undef
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: %t27 = call <vscale x 4 x i64> @llvm.vp.add.nxv4i64(<vscale x 4 x i64> undef, <vscale x 4 x i64> undef, <vscale x 4 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: %t28 = add <vscale x 4 x i64> undef, undef
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 8 for instruction: %t29 = call <vscale x 8 x i64> @llvm.vp.add.nxv8i64(<vscale x 8 x i64> undef, <vscale x 8 x i64> undef, <vscale x 8 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 8 for instruction: %t30 = add <vscale x 8 x i64> undef, undef
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 16 for instruction: %t31 = call <vscale x 16 x i64> @llvm.vp.add.nxv16i64(<vscale x 16 x i64> undef, <vscale x 16 x i64> undef, <vscale x 16 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 16 for instruction: %t32 = add <vscale x 16 x i64> undef, undef
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: ret void
-;
-  %t0 = call <2 x i8> @llvm.vp.add.v2i8(<2 x i8> undef, <2 x i8> undef, <2 x i1> undef, i32 undef)
-  %t1 = add <2 x i8> undef, undef
-  %t2 = call <4 x i8> @llvm.vp.add.v4i8(<4 x i8> undef, <4 x i8> undef, <4 x i1> undef, i32 undef)
-  %t3 = add <4 x i8> undef, undef
-  %t4 = call <8 x i8> @llvm.vp.add.v8i8(<8 x i8> undef, <8 x i8> undef, <8 x i1> undef, i32 undef)
-  %t5 = add <8 x i8> undef, undef
-  %t6 = call <16 x i8> @llvm.vp.add.v16i8(<16 x i8> undef, <16 x i8> undef, <16 x i1> undef, i32 undef)
-  %t7 = add <16 x i8> undef, undef
-  %t8 = call <2 x i64> @llvm.vp.add.v2i64(<2 x i64> undef, <2 x i64> undef, <2 x i1> undef, i32 undef)
-  %t9 = add <2 x i64> undef, undef
-  %t10 = call <4 x i64> @llvm.vp.add.v4i64(<4 x i64> undef, <4 x i64> undef, <4 x i1> undef, i32 undef)
-  %t12 = add <4 x i64> undef, undef
-  %t13 = call <8 x i64> @llvm.vp.add.v8i64(<8 x i64> undef, <8 x i64> undef, <8 x i1> undef, i32 undef)
-  %t14 = add <8 x i64> undef, undef
-  %t15 = call <16 x i64> @llvm.vp.add.v16i64(<16 x i64> undef, <16 x i64> undef, <16 x i1> undef, i32 undef)
-  %t16 = add <16 x i64> undef, undef
-  %t17 = call <vscale x 2 x i8> @llvm.vp.add.nv2i8(<vscale x 2 x i8> undef, <vscale x 2 x i8> undef, <vscale x 2 x i1> undef, i32 undef)
-  %t18 = add <vscale x 2 x i8> undef, undef
-  %t19 = call <vscale x 4 x i8> @llvm.vp.add.nv4i8(<vscale x 4 x i8> undef, <vscale x 4 x i8> undef, <vscale x 4 x i1> undef, i32 undef)
-  %t20 = add <vscale x 4 x i8> undef, undef
-  %t21 = call <vscale x 8 x i8> @llvm.vp.add.nv8i8(<vscale x 8 x i8> undef, <vscale x 8 x i8> undef, <vscale x 8 x i1> undef, i32 undef)
-  %t22 = add <vscale x 8 x i8> undef, undef
-  %t23 = call <vscale x 16 x i8> @llvm.vp.add.nv16i8(<vscale x 16 x i8> undef, <vscale x 16 x i8> undef, <vscale x 16 x i1> undef, i32 undef)
-  %t24 = add <vscale x 16 x i8> undef, undef
-  %t25 = call <vscale x 2 x i64> @llvm.vp.add.nv2i64(<vscale x 2 x i64> undef, <vscale x 2 x i64> undef, <vscale x 2 x i1> undef, i32 undef)
-  %t26 = add <vscale x 2 x i64> undef, undef
-  %t27 = call <vscale x 4 x i64> @llvm.vp.add.nv4i64(<vscale x 4 x i64> undef, <vscale x 4 x i64> undef, <vscale x 4 x i1> undef, i32 undef)
-  %t28 = add <vscale x 4 x i64> undef, undef
-  %t29 = call <vscale x 8 x i64> @llvm.vp.add.nv8i64(<vscale x 8 x i64> undef, <vscale x 8 x i64> undef, <vscale x 8 x i1> undef, i32 undef)
-  %t30 = add <vscale x 8 x i64> undef, undef
-  %t31 = call <vscale x 16 x i64> @llvm.vp.add.nv16i64(<vscale x 16 x i64> undef, <vscale x 16 x i64> undef, <vscale x 16 x i1> undef, i32 undef)
-  %t32 = add <vscale x 16 x i64> undef, undef
-  ret void
-}
-
-define void @abs() {
-; CHECK-LABEL: 'abs'
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %1 = call <2 x i8> @llvm.vp.abs.v2i8(<2 x i8> undef, i1 false, <2 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %2 = call <2 x i8> @llvm.abs.v2i8(<2 x i8> undef, i1 false)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %3 = call <4 x i8> @llvm.vp.abs.v4i8(<4 x i8> undef, i1 false, <4 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %4 = call <4 x i8> @llvm.abs.v4i8(<4 x i8> undef, i1 false)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %5 = call <8 x i8> @llvm.vp.abs.v8i8(<8 x i8> undef, i1 false, <8 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %6 = call <8 x i8> @llvm.abs.v8i8(<8 x i8> undef, i1 false)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %7 = call <16 x i8> @llvm.vp.abs.v16i8(<16 x i8> undef, i1 false, <16 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %8 = call <16 x i8> @llvm.abs.v16i8(<16 x i8> undef, i1 false)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %9 = call <2 x i64> @llvm.vp.abs.v2i64(<2 x i64> undef, i1 false, <2 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %10 = call <2 x i64> @llvm.abs.v2i64(<2 x i64> undef, i1 false)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %11 = call <4 x i64> @llvm.vp.abs.v4i64(<4 x i64> undef, i1 false, <4 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %12 = call <4 x i64> @llvm.abs.v4i64(<4 x i64> undef, i1 false)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %13 = call <8 x i64> @llvm.vp.abs.v8i64(<8 x i64> undef, i1 false, <8 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %14 = call <8 x i64> @llvm.abs.v8i64(<8 x i64> undef, i1 false)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %15 = call <16 x i64> @llvm.vp.abs.v16i64(<16 x i64> undef, i1 false, <16 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %16 = call <16 x i64> @llvm.abs.v16i64(<16 x i64> undef, i1 false)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %17 = call <vscale x 2 x i8> @llvm.vp.abs.nxv2i8(<vscale x 2 x i8> undef, i1 false, <vscale x 2 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %18 = call <vscale x 2 x i8> @llvm.abs.nxv2i8(<vscale x 2 x i8> undef, i1 false)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %19 = call <vscale x 4 x i8> @llvm.vp.abs.nxv4i8(<vscale x 4 x i8> undef, i1 false, <vscale x 4 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %20 = call <vscale x 4 x i8> @llvm.abs.nxv4i8(<vscale x 4 x i8> undef, i1 false)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %21 = call <vscale x 8 x i8> @llvm.vp.abs.nxv8i8(<vscale x 8 x i8> undef, i1 false, <vscale x 8 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %22 = call <vscale x 8 x i8> @llvm.abs.nxv8i8(<vscale x 8 x i8> undef, i1 false)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %23 = call <vscale x 16 x i8> @llvm.vp.abs.nxv16i8(<vscale x 16 x i8> undef, i1 false, <vscale x 16 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %24 = call <vscale x 16 x i8> @llvm.abs.nxv16i8(<vscale x 16 x i8> undef, i1 false)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %25 = call <vscale x 2 x i64> @llvm.vp.abs.nxv2i64(<vscale x 2 x i64> undef, i1 false, <vscale x 2 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %26 = call <vscale x 2 x i64> @llvm.abs.nxv2i64(<vscale x 2 x i64> undef, i1 false)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %27 = call <vscale x 4 x i64> @llvm.vp.abs.nxv4i64(<vscale x 4 x i64> undef, i1 false, <vscale x 4 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %28 = call <vscale x 4 x i64> @llvm.abs.nxv4i64(<vscale x 4 x i64> undef, i1 false)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %29 = call <vscale x 8 x i64> @llvm.vp.abs.nxv8i64(<vscale x 8 x i64> undef, i1 false, <vscale x 8 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %30 = call <vscale x 8 x i64> @llvm.abs.nxv8i64(<vscale x 8 x i64> undef, i1 false)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: %31 = call <vscale x 16 x i64> @llvm.vp.abs.nxv16i64(<vscale x 16 x i64> undef, i1 false, <vscale x 16 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: %32 = call <vscale x 16 x i64> @llvm.abs.nxv16i64(<vscale x 16 x i64> undef, i1 false)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: ret void
-;
-  call <2 x i8> @llvm.vp.abs.v2i8(<2 x i8> undef, i1 0, <2 x i1> undef, i32 undef)
-  call <2 x i8> @llvm.abs.v2i8(<2 x i8> undef, i1 0)
-  call <4 x i8> @llvm.vp.abs.v4i8(<4 x i8> undef, i1 0, <4 x i1> undef, i32 undef)
-  call <4 x i8> @llvm.abs.v4i8(<4 x i8> undef, i1 0)
-  call <8 x i8> @llvm.vp.abs.v8i8(<8 x i8> undef, i1 0, <8 x i1> undef, i32 undef)
-  call <8 x i8> @llvm.abs.v8i8(<8 x i8> undef, i1 0)
-  call <16 x i8> @llvm.vp.abs.v16i8(<16 x i8> undef, i1 0, <16 x i1> undef, i32 undef)
-  call <16 x i8> @llvm.abs.v16i8(<16 x i8> undef, i1 0)
-  call <2 x i64> @llvm.vp.abs.v2i64(<2 x i64> undef, i1 0, <2 x i1> undef, i32 undef)
-  call <2 x i64> @llvm.abs.v2i64(<2 x i64> undef, i1 0)
-  call <4 x i64> @llvm.vp.abs.v4i64(<4 x i64> undef, i1 0, <4 x i1> undef, i32 undef)
-  call <4 x i64> @llvm.abs.v4i64(<4 x i64> undef, i1 0)
-  call <8 x i64> @llvm.vp.abs.v8i64(<8 x i64> undef, i1 0, <8 x i1> undef, i32 undef)
-  call <8 x i64> @llvm.abs.v8i64(<8 x i64> undef, i1 0)
-  call <16 x i64> @llvm.vp.abs.v16i64(<16 x i64> undef, i1 0, <16 x i1> undef, i32 undef)
-  call <16 x i64> @llvm.abs.v16i64(<16 x i64> undef, i1 0)
-  call <vscale x 2 x i8> @llvm.vp.abs.nv2i8(<vscale x 2 x i8> undef, i1 0, <vscale x 2 x i1> undef, i32 undef)
-  call <vscale x 2 x i8> @llvm.abs.nv2i8(<vscale x 2 x i8> undef, i1 0)
-  call <vscale x 4 x i8> @llvm.vp.abs.nv4i8(<vscale x 4 x i8> undef, i1 0, <vscale x 4 x i1> undef, i32 undef)
-  call <vscale x 4 x i8> @llvm.abs.nv4i8(<vscale x 4 x i8> undef, i1 0)
-  call <vscale x 8 x i8> @llvm.vp.abs.nv8i8(<vscale x 8 x i8> undef, i1 0, <vscale x 8 x i1> undef, i32 undef)
-  call <vscale x 8 x i8> @llvm.abs.nv8i8(<vscale x 8 x i8> undef, i1 0)
-  call <vscale x 16 x i8> @llvm.vp.abs.nv16i8(<vscale x 16 x i8> undef, i1 0, <vscale x 16 x i1> undef, i32 undef)
-  call <vscale x 16 x i8> @llvm.abs.nv16i8(<vscale x 16 x i8> undef, i1 0)
-  call <vscale x 2 x i64> @llvm.vp.abs.nv2i64(<vscale x 2 x i64> undef, i1 0, <vscale x 2 x i1> undef, i32 undef)
-  call <vscale x 2 x i64> @llvm.abs.nv2i64(<vscale x 2 x i64> undef, i1 0)
-  call <vscale x 4 x i64> @llvm.vp.abs.nv4i64(<vscale x 4 x i64> undef, i1 0, <vscale x 4 x i1> undef, i32 undef)
-  call <vscale x 4 x i64> @llvm.abs.nv4i64(<vscale x 4 x i64> undef, i1 0)
-  call <vscale x 8 x i64> @llvm.vp.abs.nv8i64(<vscale x 8 x i64> undef, i1 0, <vscale x 8 x i1> undef, i32 undef)
-  call <vscale x 8 x i64> @llvm.abs.nv8i64(<vscale x 8 x i64> undef, i1 0)
-  call <vscale x 16 x i64> @llvm.vp.abs.nv16i64(<vscale x 16 x i64> undef, i1 0, <vscale x 16 x i1> undef, i32 undef)
-  call <vscale x 16 x i64> @llvm.abs.nv16i64(<vscale x 16 x i64> undef, i1 0)
-  ret void
-}
-
-define void @load() {
-; CHECK-LABEL: 'load'
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t0 = call <2 x i8> @llvm.vp.load.v2i8.p0(ptr undef, <2 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t1 = load <2 x i8>, ptr undef, align 2
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t2 = call <4 x i8> @llvm.vp.load.v4i8.p0(ptr undef, <4 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t3 = load <4 x i8>, ptr undef, align 4
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t4 = call <8 x i8> @llvm.vp.load.v8i8.p0(ptr undef, <8 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t5 = load <8 x i8>, ptr undef, align 8
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t6 = call <16 x i8> @llvm.vp.load.v16i8.p0(ptr undef, <16 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t7 = load <16 x i8>, ptr undef, align 16
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t8 = call <2 x i64> @llvm.vp.load.v2i64.p0(ptr undef, <2 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t9 = load <2 x i64>, ptr undef, align 16
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %t10 = call <4 x i64> @llvm.vp.load.v4i64.p0(ptr undef, <4 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %t12 = load <4 x i64>, ptr undef, align 32
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: %t13 = call <8 x i64> @llvm.vp.load.v8i64.p0(ptr undef, <8 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: %t14 = load <8 x i64>, ptr undef, align 64
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 8 for instruction: %t15 = call <16 x i64> @llvm.vp.load.v16i64.p0(ptr undef, <16 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 8 for instruction: %t16 = load <16 x i64>, ptr undef, align 128
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t17 = call <vscale x 2 x i8> @llvm.vp.load.nxv2i8.p0(ptr undef, <vscale x 2 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t18 = load <vscale x 2 x i8>, ptr undef, align 2
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t19 = call <vscale x 4 x i8> @llvm.vp.load.nxv4i8.p0(ptr undef, <vscale x 4 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t20 = load <vscale x 4 x i8>, ptr undef, align 4
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t21 = call <vscale x 8 x i8> @llvm.vp.load.nxv8i8.p0(ptr undef, <vscale x 8 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %t22 = load <vscale x 8 x i8>, ptr undef, align 8
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %t23 = call <vscale x 16 x i8> @llvm.vp.load.nxv16i8.p0(ptr undef, <vscale x 16 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %t24 = load <vscale x 16 x i8>, ptr undef, align 16
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %t25 = call <vscale x 2 x i64> @llvm.vp.load.nxv2i64.p0(ptr undef, <vscale x 2 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %t26 = load <vscale x 2 x i64>, ptr undef, align 16
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: %t27 = call <vscale x 4 x i64> @llvm.vp.load.nxv4i64.p0(ptr undef, <vscale x 4 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: %t28 = load <vscale x 4 x i64>, ptr undef, align 32
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 8 for instruction: %t29 = call <vscale x 8 x i64> @llvm.vp.load.nxv8i64.p0(ptr undef, <vscale x 8 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 8 for instruction: %t30 = load <vscale x 8 x i64>, ptr undef, align 64
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 16 for instruction: %t31 = call <vscale x 16 x i64> @llvm.vp.load.nxv16i64.p0(ptr undef, <vscale x 16 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 16 for instruction: %t32 = load <vscale x 16 x i64>, ptr undef, align 128
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: ret void
-;
-  %t0 = call <2 x i8> @llvm.vp.load.v2i8(ptr undef, <2 x i1> undef, i32 undef)
-  %t1 = load <2 x i8>, ptr undef
-  %t2 = call <4 x i8> @llvm.vp.load.v4i8(ptr undef, <4 x i1> undef, i32 undef)
-  %t3 = load <4 x i8>, ptr undef
-  %t4 = call <8 x i8> @llvm.vp.load.v8i8(ptr undef, <8 x i1> undef, i32 undef)
-  %t5 = load <8 x i8>, ptr undef
-  %t6 = call <16 x i8> @llvm.vp.load.v16i8(ptr undef, <16 x i1> undef, i32 undef)
-  %t7 = load <16 x i8>, ptr undef
-  %t8 = call <2 x i64> @llvm.vp.load.v2i64(ptr undef, <2 x i1> undef, i32 undef)
-  %t9 = load <2 x i64>, ptr undef
-  %t10 = call <4 x i64> @llvm.vp.load.v4i64(ptr undef, <4 x i1> undef, i32 undef)
-  %t12 = load <4 x i64>, ptr undef
-  %t13 = call <8 x i64> @llvm.vp.load.v8i64(ptr undef, <8 x i1> undef, i32 undef)
-  %t14 = load <8 x i64>, ptr undef
-  %t15 = call <16 x i64> @llvm.vp.load.v16i64(ptr undef, <16 x i1> undef, i32 undef)
-  %t16 = load <16 x i64>, ptr undef
-  %t17 = call <vscale x 2 x i8> @llvm.vp.load.nv2i8(ptr undef, <vscale x 2 x i1> undef, i32 undef)
-  %t18 = load <vscale x 2 x i8>, ptr undef
-  %t19 = call <vscale x 4 x i8> @llvm.vp.load.nv4i8(ptr undef, <vscale x 4 x i1> undef, i32 undef)
-  %t20 = load <vscale x 4 x i8>, ptr undef
-  %t21 = call <vscale x 8 x i8> @llvm.vp.load.nv8i8(ptr undef, <vscale x 8 x i1> undef, i32 undef)
-  %t22 = load <vscale x 8 x i8>, ptr undef
-  %t23 = call <vscale x 16 x i8> @llvm.vp.load.nv16i8(ptr undef, <vscale x 16 x i1> undef, i32 undef)
-  %t24 = load <vscale x 16 x i8>, ptr undef
-  %t25 = call <vscale x 2 x i64> @llvm.vp.load.nv2i64(ptr undef, <vscale x 2 x i1> undef, i32 undef)
-  %t26 = load <vscale x 2 x i64>, ptr undef
-  %t27 = call <vscale x 4 x i64> @llvm.vp.load.nv4i64(ptr undef, <vscale x 4 x i1> undef, i32 undef)
-  %t28 = load <vscale x 4 x i64>, ptr undef
-  %t29 = call <vscale x 8 x i64> @llvm.vp.load.nv8i64(ptr undef, <vscale x 8 x i1> undef, i32 undef)
-  %t30 = load <vscale x 8 x i64>, ptr undef
-  %t31 = call <vscale x 16 x i64> @llvm.vp.load.nv16i64(ptr undef, <vscale x 16 x i1> undef, i32 undef)
-  %t32 = load <vscale x 16 x i64>, ptr undef
-  ret void
-}
-
-define void @store() {
-; CHECK-LABEL: 'store'
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: call void @llvm.vp.store.v2i8.p0(<2 x i8> undef, ptr undef, <2 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: store <2 x i8> undef, ptr undef, align 2
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: call void @llvm.vp.store.v4i8.p0(<4 x i8> undef, ptr undef, <4 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: store <4 x i8> undef, ptr undef, align 4
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: call void @llvm.vp.store.v8i8.p0(<8 x i8> undef, ptr undef, <8 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: store <8 x i8> undef, ptr undef, align 8
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: call void @llvm.vp.store.v16i8.p0(<16 x i8> undef, ptr undef, <16 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: store <16 x i8> undef, ptr undef, align 16
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: call void @llvm.vp.store.v2i64.p0(<2 x i64> undef, ptr undef, <2 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: store <2 x i64> undef, ptr undef, align 16
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: call void @llvm.vp.store.v4i64.p0(<4 x i64> undef, ptr undef, <4 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: store <4 x i64> undef, ptr undef, align 32
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: call void @llvm.vp.store.v8i64.p0(<8 x i64> undef, ptr undef, <8 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: store <8 x i64> undef, ptr undef, align 64
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 8 for instruction: call void @llvm.vp.store.v16i64.p0(<16 x i64> undef, ptr undef, <16 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 8 for instruction: store <16 x i64> undef, ptr undef, align 128
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: call void @llvm.vp.store.nxv2i8.p0(<vscale x 2 x i8> undef, ptr undef, <vscale x 2 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: store <vscale x 2 x i8> undef, ptr undef, align 2
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: call void @llvm.vp.store.nxv4i8.p0(<vscale x 4 x i8> undef, ptr undef, <vscale x 4 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: store <vscale x 4 x i8> undef, ptr undef, align 4
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: call void @llvm.vp.store.nxv8i8.p0(<vscale x 8 x i8> undef, ptr undef, <vscale x 8 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: store <vscale x 8 x i8> undef, ptr undef, align 8
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: call void @llvm.vp.store.nxv16i8.p0(<vscale x 16 x i8> undef, ptr undef, <vscale x 16 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: store <vscale x 16 x i8> undef, ptr undef, align 16
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: call void @llvm.vp.store.nxv2i64.p0(<vscale x 2 x i64> undef, ptr undef, <vscale x 2 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: store <vscale x 2 x i64> undef, ptr undef, align 16
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: call void @llvm.vp.store.nxv4i64.p0(<vscale x 4 x i64> undef, ptr undef, <vscale x 4 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: store <vscale x 4 x i64> undef, ptr undef, align 32
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 8 for instruction: call void @llvm.vp.store.nxv8i64.p0(<vscale x 8 x i64> undef, ptr undef, <vscale x 8 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 8 for instruction: store <vscale x 8 x i64> undef, ptr undef, align 64
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 16 for instruction: call void @llvm.vp.store.nxv16i64.p0(<vscale x 16 x i64> undef, ptr undef, <vscale x 16 x i1> undef, i32 undef)
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 16 for instruction: store <vscale x 16 x i64> undef, ptr undef, align 128
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: ret void
-;
-  call void @llvm.vp.store.v2i8(<2 x i8> undef, ptr undef, <2 x i1> undef, i32 undef)
-  store <2 x i8> undef, ptr undef
-  call void @llvm.vp.store.v4i8(<4 x i8> undef, ptr undef, <4 x i1> undef, i32 undef)
-  store <4 x i8> undef, ptr undef
-  call void @llvm.vp.store.v8i8(<8 x i8> undef, ptr undef, <8 x i1> undef, i32 undef)
-  store <8 x i8> undef, ptr undef
-  call void @llvm.vp.store.v16i8(<16 x i8> undef, ptr undef, <16 x i1> undef, i32 undef)
-  store <16 x i8> undef, ptr undef
-  call void @llvm.vp.store.v2i64(<2 x i64> undef, ptr undef, <2 x i1> undef, i32 undef)
-  store <2 x i64> undef, ptr undef
-  call void @llvm.vp.store.v4i64(<4 x i64> undef, ptr undef, <4 x i1> undef, i32 undef)
-  store <4 x i64> undef, ptr undef
-  call void @llvm.vp.store.v8i64(<8 x i64> undef, ptr undef, <8 x i1> undef, i32 undef)
-  store <8 x i64> undef, ptr undef
-  call void @llvm.vp.store.v16i64(<16 x i64> undef, ptr undef, <16 x i1> undef, i32 undef)
-  store <16 x i64> undef, ptr undef
-  call void @llvm.vp.store.nv2i8(<vscale x 2 x i8> undef, ptr undef, <vscale x 2 x i1> undef, i32 undef)
-  store <vscale x 2 x i8> undef, ptr undef
-  call void @llvm.vp.store.nv4i8(<vscale x 4 x i8> undef, ptr undef, <vscale x 4 x i1> undef, i32 undef)
-  store <vscale x 4 x i8> undef, ptr undef
-  call void @llvm.vp.store.nv8i8(<vscale x 8 x i8> undef, ptr undef, <vscale x 8 x i1> undef, i32 undef)
-  store <vscale x 8 x i8> undef, ptr undef
-  call void @llvm.vp.store.nv16i8(<vscale x 16 x i8> undef, ptr undef, <vscale x 16 x i1> undef, i32 undef)
-  store <vscale x 16 x i8> undef, ptr undef
-  call void @llvm.vp.store.nv2i64(<vscale x 2 x i64> undef, ptr undef, <vscale x 2 x i1> undef, i32 undef)
-  store <vscale x 2 x i64> undef, ptr undef
-  call void @llvm.vp.store.nv4i64(<vscale x 4 x i64> undef, ptr undef, <vscale x 4 x i1> undef, i32 undef)
-  store <vscale x 4 x i64> undef, ptr undef
-  call void @llvm.vp.store.nv8i64(<vscale x 8 x i64> undef, ptr undef, <vscale x 8 x i1> undef, i32 undef)
-  store <vscale x 8 x i64> undef, ptr undef
-  call void @llvm.vp.store.nv16i64(<vscale x 16 x i64> undef, ptr undef, <vscale x 16 x i1> undef, i32 undef)
-  store <vscale x 16 x i64> undef, ptr undef
-  ret void
-}
-
-declare <2 x i8> @llvm.vp.add.v2i8(<2 x i8>, <2 x i8>, <2 x i1>, i32)
-declare <4 x i8> @llvm.vp.add.v4i8(<4 x i8>, <4 x i8>, <4 x i1>, i32)
-declare <8 x i8> @llvm.vp.add.v8i8(<8 x i8>, <8 x i8>, <8 x i1>, i32)
-declare <16 x i8> @llvm.vp.add.v16i8(<16 x i8>, <16 x i8>, <16 x i1>, i32)
-declare <2 x i64> @llvm.vp.add.v2i64(<2 x i64>, <2 x i64>, <2 x i1>, i32)
-declare <4 x i64> @llvm.vp.add.v4i64(<4 x i64>, <4 x i64>, <4 x i1>, i32)
-declare <8 x i64> @llvm.vp.add.v8i64(<8 x i64>, <8 x i64>, <8 x i1>, i32)
-declare <16 x i64> @llvm.vp.add.v16i64(<16 x i64>, <16 x i64>, <16 x i1>, i32)
-declare <vscale x 2 x i8> @llvm.vp.add.nv2i8(<vscale x 2 x i8>, <vscale x 2 x i8>, <vscale x 2 x i1>, i32)
-declare <vscale x 4 x i8> @llvm.vp.add.nv4i8(<vscale x 4 x i8>, <vscale x 4 x i8>, <vscale x 4 x i1>, i32)
-declare <vscale x 8 x i8> @llvm.vp.add.nv8i8(<vscale x 8 x i8>, <vscale x 8 x i8>, <vscale x 8 x i1>, i32)
-declare <vscale x 16 x i8> @llvm.vp.add.nv16i8(<vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i1>, i32)
-declare <vscale x 2 x i64> @llvm.vp.add.nv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i1>, i32)
-declare <vscale x 4 x i64> @llvm.vp.add.nv4i64(<vscale x 4 x i64>, <vscale x 4 x i64>, <vscale x 4 x i1>, i32)
-declare <vscale x 8 x i64> @llvm.vp.add.nv8i64(<vscale x 8 x i64>, <vscale x 8 x i64>, <vscale x 8 x i1>, i32)
-declare <vscale x 16 x i64> @llvm.vp.add.nv16i64(<vscale x 16 x i64>, <vscale x 16 x i64>, <vscale x 16 x i1>, i32)
-
-declare <2 x i8> @llvm.vp.abs.v2i8(<2 x i8>, i1, <2 x i1>, i32)
-declare <4 x i8> @llvm.vp.abs.v4i8(<4 x i8>, i1, <4 x i1>, i32)
-declare <8 x i8> @llvm.vp.abs.v8i8(<8 x i8>, i1, <8 x i1>, i32)
-declare <16 x i8> @llvm.vp.abs.v16i8(<16 x i8>, i1, <16 x i1>, i32)
-declare <2 x i64> @llvm.vp.abs.v2i64(<2 x i64>, i1, <2 x i1>, i32)
-declare <4 x i64> @llvm.vp.abs.v4i64(<4 x i64>, i1, <4 x i1>, i32)
-declare <8 x i64> @llvm.vp.abs.v8i64(<8 x i64>, i1, <8 x i1>, i32)
-declare <16 x i64> @llvm.vp.abs.v16i64(<16 x i64>, i1, <16 x i1>, i32)
-declare <vscale x 2 x i8> @llvm.vp.abs.nv2i8(<vscale x 2 x i8>, i1, <vscale x 2 x i1>, i32)
-declare <vscale x 4 x i8> @llvm.vp.abs.nv4i8(<vscale x 4 x i8>, i1, <vscale x 4 x i1>, i32)
-declare <vscale x 8 x i8> @llvm.vp.abs.nv8i8(<vscale x 8 x i8>, i1, <vscale x 8 x i1>, i32)
-declare <vscale x 16 x i8> @llvm.vp.abs.nv16i8(<vscale x 16 x i8>, i1, <vscale x 16 x i1>, i32)
-declare <vscale x 2 x i64> @llvm.vp.abs.nv2i64(<vscale x 2 x i64>, i1, <vscale x 2 x i1>, i32)
-declare <vscale x 4 x i64> @llvm.vp.abs.nv4i64(<vscale x 4 x i64>, i1, <vscale x 4 x i1>, i32)
-declare <vscale x 8 x i64> @llvm.vp.abs.nv8i64(<vscale x 8 x i64>, i1, <vscale x 8 x i1>, i32)
-declare <vscale x 16 x i64> @llvm.vp.abs.nv16i64(<vscale x 16 x i64>, i1, <vscale x 16 x i1>, i32)
-
-declare <2 x i8> @llvm.abs.v2i8(<2 x i8>, i1)
-declare <4 x i8> @llvm.abs.v4i8(<4 x i8>, i1)
-declare <8 x i8> @llvm.abs.v8i8(<8 x i8>, i1)
-declare <16 x i8> @llvm.abs.v16i8(<16 x i8>, i1)
-declare <2 x i64> @llvm.abs.v2i64(<2 x i64>, i1)
-declare <4 x i64> @llvm.abs.v4i64(<4 x i64>, i1)
-declare <8 x i64> @llvm.abs.v8i64(<8 x i64>, i1)
-declare <16 x i64> @llvm.abs.v16i64(<16 x i64>, i1)
-declare <vscale x 2 x i8> @llvm.abs.nv2i8(<vscale x 2 x i8>, i1)
-declare <vscale x 4 x i8> @llvm.abs.nv4i8(<vscale x 4 x i8>, i1)
-declare <vscale x 8 x i8> @llvm.abs.nv8i8(<vscale x 8 x i8>, i1)
-declare <vscale x 16 x i8> @llvm.abs.nv16i8(<vscale x 16 x i8>, i1)
-declare <vscale x 2 x i64> @llvm.abs.nv2i64(<vscale x 2 x i64>, i1)
-declare <vscale x 4 x i64> @llvm.abs.nv4i64(<vscale x 4 x i64>, i1)
-declare <vscale x 8 x i64> @llvm.abs.nv8i64(<vscale x 8 x i64>, i1)
-declare <vscale x 16 x i64> @llvm.abs.nv16i64(<vscale x 16 x i64>, i1)
-
-declare <2 x i8> @llvm.vp.load.v2i8(ptr, <2 x i1>, i32)
-declare <4 x i8> @llvm.vp.load.v4i8(ptr, <4 x i1>, i32)
-declare <8 x i8> @llvm.vp.load.v8i8(ptr, <8 x i1>, i32)
-declare <16 x i8> @llvm.vp.load.v16i8(ptr, <16 x i1>, i32)
-declare <2 x i64> @llvm.vp.load.v2i64(ptr, <2 x i1>, i32)
-declare <4 x i64> @llvm.vp.load.v4i64(ptr, <4 x i1>, i32)
-declare <8 x i64> @llvm.vp.load.v8i64(ptr, <8 x i1>, i32)
-declare <16 x i64> @llvm.vp.load.v16i64(ptr, <16 x i1>, i32)
-declare <vscale x 2 x i8> @llvm.vp.load.nv2i8(ptr, <vscale x 2 x i1>, i32)
-declare <vscale x 4 x i8> @llvm.vp.load.nv4i8(ptr, <vscale x 4 x i1>, i32)
-declare <vscale x 8 x i8> @llvm.vp.load.nv8i8(ptr, <vscale x 8 x i1>, i32)
-declare <vscale x 16 x i8> @llvm.vp.load.nv16i8(ptr, <vscale x 16 x i1>, i32)
-declare <vscale x 2 x i64> @llvm.vp.load.nv2i64(ptr, <vscale x 2 x i1>, i32)
-declare <vscale x 4 x i64> @llvm.vp.load.nv4i64(ptr, <vscale x 4 x i1>, i32)
-declare <vscale x 8 x i64> @llvm.vp.load.nv8i64(ptr, <vscale x 8 x i1>, i32)
-declare <vscale x 16 x i64> @llvm.vp.load.nv16i64(ptr, <vscale x 16 x i1>, i32)
-
-declare void @llvm.vp.store.v2i8(<2 x i8>, ptr, <2 x i1>, i32)
-declare void @llvm.vp.store.v4i8(<4 x i8>, ptr, <4 x i1>, i32)
-declare void @llvm.vp.store.v8i8(<8 x i8>, ptr, <8 x i1>, i32)
-declare void @llvm.vp.store.v16i8(<16 x i8>, ptr, <16 x i1>, i32)
-declare void @llvm.vp.store.v2i64(<2 x i64>, ptr, <2 x i1>, i32)
-declare void @llvm.vp.store.v4i64(<4 x i64>, ptr, <4 x i1>, i32)
-declare void @llvm.vp.store.v8i64(<8 x i64>, ptr, <8 x i1>, i32)
-declare void @llvm.vp.store.v16i64(<16 x i64>, ptr, <16 x i1>, i32)
-declare void @llvm.vp.store.nv2i8(<vscale x 2 x i8>, ptr, <vscale x 2 x i1>, i32)
-declare void @llvm.vp.store.nv4i8(<vscale x 4 x i8>, ptr, <vscale x 4 x i1>, i32)
-declare void @llvm.vp.store.nv8i8(<vscale x 8 x i8>, ptr, <vscale x 8 x i1>, i32)
-declare void @llvm.vp.store.nv16i8(<vscale x 16 x i8>, ptr, <vscale x 16 x i1>, i32)
-declare void @llvm.vp.store.nv2i64(<vscale x 2 x i64>, ptr, <vscale x 2 x i1>, i32)
-declare void @llvm.vp.store.nv4i64(<vscale x 4 x i64>, ptr, <vscale x 4 x i1>, i32)
-declare void @llvm.vp.store.nv8i64(<vscale x 8 x i64>, ptr, <vscale x 8 x i1>, i32)
-declare void @llvm.vp.store.nv16i64(<vscale x 16 x i64>, ptr, <vscale x 16 x i1>, i32)
-
 declare <vscale x 1 x i32> @llvm.fshr.nxv4i32(<vscale x 1 x i32> %a, <vscale x 1 x i32> %b, <vscale x 1 x i32> %c)
 declare <vscale x 1 x i32> @llvm.fshl.nxv4i32(<vscale x 1 x i32> %a, <vscale x 1 x i32> %b, <vscale x 1 x i32> %c)
 
+
 declare <vscale x 4 x float> @llvm.pow.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>)
 declare <vscale x 4 x float> @llvm.powi.nxv4f32.i32(<vscale x 4 x float>, i32)
 declare <vscale x 4 x float> @llvm.nearbyint.nxv4f32(<vscale x 4 x float>)

From 000310d8a56b6cef37b05145c5056ef94c2a9e2e Mon Sep 17 00:00:00 2001
From: Stella Laurenzo <stellaraccident at gmail.com>
Date: Fri, 6 Oct 2023 12:22:50 -0700
Subject: [PATCH 2/5] Revert "[mlir][tosa] Align `shift` attribute of
 `TOSA_MulOp` with the spec (#67816)"

This reverts commit 363c617aac9d5e8058be549a3c59e3b085c09a54.

Temporarily reverting the TOSA asm format change to let integrations catch up.
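
The functional change being reverted is just the declared type of the shift attribute on tosa.mul (I8Attr back to I32Attr), which is visible in the textual asm. A minimal sketch of the two forms, with illustrative operands not taken from the patch:

  // Spec-aligned form from #67816 (being reverted):
  %0 = tosa.mul %a, %b {shift = 0 : i8} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
  // Pre-#67816 form restored by this revert:
  %1 = tosa.mul %a, %b {shift = 0 : i32} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>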
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  2 +-
 .../TosaToLinalg/tosa-to-linalg.mlir          |  8 ++++----
 mlir/test/Dialect/Tosa/broadcast.mlir         |  2 +-
 mlir/test/Dialect/Tosa/canonicalize.mlir      | 10 +++++-----
 mlir/test/Dialect/Tosa/constant-op-fold.mlir  | 20 +++++++++----------
 mlir/test/Dialect/Tosa/ops.mlir               |  4 ++--
 .../Tosa/tosa-decompose-depthwise.mlir        |  4 ++--
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 12 +++++------
 8 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index f4d9a251fb97839..e7da35a0c8145ac 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -798,7 +798,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
   let arguments = (ins
     Tosa_Tensor:$input1,
     Tosa_Tensor:$input2,
-    I8Attr:$shift
+    I32Attr:$shift
   );
 
   let results = (outs
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 8e0307085f1ce26..b08f4969ef50813 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -447,7 +447,7 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
 
   // CHECK: linalg.generic
   // CHECK: arith.mulf
-  %4 = tosa.mul %0, %1 {shift = 0 : i8} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %4 = tosa.mul %0, %1 {shift = 0 : i32} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: arith.negf
@@ -570,7 +570,7 @@ func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
   // CHECK: arith.extsi
   // CHECK: arith.extsi
   // CHECK: arith.muli
-  %0 = tosa.mul %arg0, %arg0 {shift = 0 : i8} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
+  %0 = tosa.mul %arg0, %arg0 {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
 
   return
 }
@@ -598,12 +598,12 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
 
   // CHECK: linalg.generic
   // CHECK: arith.muli
-  %2 = tosa.mul %arg0, %arg0 {shift = 0 : i8} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %2 = tosa.mul %arg0, %arg0 {shift = 0 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: arith.constant 2
   // CHECK: apply_scale
-  %3 = tosa.mul %arg0, %arg0 {shift = 2 : i8} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %3 = tosa.mul %arg0, %arg0 {shift = 2 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: arith.divsi
diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
index 7613aa3b8dd03d1..5dfd6433f5e3730 100644
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ b/mlir/test/Dialect/Tosa/broadcast.mlir
@@ -174,7 +174,7 @@ func.func @test_broadcast20(%arg0: tensor<3x3x4x1xf32>, %arg1: tensor<4x5xf32>)
 func.func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
   // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
   // CHECK: %[[VAR1:.*]] = tosa.mul %[[VAR0]], %arg1
-  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8 } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
+  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i32 } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
   return %0 : tensor<17x16x15x14xi32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 323864ea9013048..5ed5a383c6f6fea 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -205,7 +205,7 @@ func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: return %arg0
   // CHECK-NOT: tosa.mul
   %ones = "tosa.const"() {value = dense<1.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
-  %1 = tosa.mul %arg0, %ones {shift = 0 : i8} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %1 = tosa.mul %arg0, %ones {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
   return %1 : tensor<2x3xf32>
 }
 
@@ -214,7 +214,7 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: return %arg0
   // CHECK-NOT: tosa.mul
   %ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
-  %1 = tosa.mul %ones, %arg0 {shift = 0 : i8} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %1 = tosa.mul %ones, %arg0 {shift = 0 : i32} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
   return %1 : tensor<2x3xf32>
 }
 
@@ -223,7 +223,7 @@ func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   // CHECK: return %arg0
   // CHECK-NOT: tosa.mul
   %ones = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
-  %1 = tosa.mul %arg0, %ones {shift = 0 : i8} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  %1 = tosa.mul %arg0, %ones {shift = 0 : i32} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
   return %1 : tensor<2x3xi32>
 }
 
@@ -232,11 +232,11 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
   // CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}
   // CHECK-NOT: tosa.mul
   %zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
-  %1 = tosa.mul %arg0, %zeros {shift = 0 : i8} : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
+  %1 = tosa.mul %arg0, %zeros {shift = 0 : i32} : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
 
   // CHECK-NOT: tosa.mul
   // CHECK: return %[[ZERO]], %[[ZERO]]
-  %2 = tosa.mul %zeros, %arg0 {shift = 0 : i8} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %2 = tosa.mul %zeros, %arg0 {shift = 0 : i32} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
   return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 56619fbc560e5fa..e66082d83cb907e 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -213,7 +213,7 @@ func.func @fold_div_splat_i32() -> tensor<i32> {
 func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
-  %mul = tosa.mul %arg0, %zero {shift = 0 : i8} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %mul = tosa.mul %arg0, %zero {shift = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK: return %[[ZERO]]
   return %mul : tensor<f32>
 }
@@ -224,7 +224,7 @@ func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
 func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
-  %mul = tosa.mul %zero, %arg0 {shift = 0 : i8} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %mul = tosa.mul %zero, %arg0 {shift = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK: return %[[ZERO]]
   return %mul : tensor<f32>
 }
@@ -235,7 +235,7 @@ func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
 func.func @fold_mul_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
   %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0>
-  %mul = tosa.mul %arg0, %zero {shift = 0 : i8} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %mul = tosa.mul %arg0, %zero {shift = 0 : i32} : (tensor<i32>, tensor<i32>) -> tensor<i32>
   // CHECK: return %[[ZERO]]
   return %mul : tensor<i32>
 }
@@ -246,7 +246,7 @@ func.func @fold_mul_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
 func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
   %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0>
-  %mul = tosa.mul %zero, %arg0 {shift = 0 : i8} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %mul = tosa.mul %zero, %arg0 {shift = 0 : i32} : (tensor<i32>, tensor<i32>) -> tensor<i32>
   // CHECK: return %[[ZERO]]
   return %mul : tensor<i32>
 }
@@ -256,7 +256,7 @@ func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
 // CHECK-LABEL: @fold_mul_one_rhs_f32
 func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
-  %mul = tosa.mul %arg0, %one {shift = 0 : i8} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %mul = tosa.mul %arg0, %one {shift = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK: return %arg0
   return %mul : tensor<f32>
 }
@@ -266,7 +266,7 @@ func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
 // CHECK-LABEL: @fold_mul_one_lhs_f32
 func.func @fold_mul_one_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
-  %mul = tosa.mul %one, %arg0 {shift = 0 : i8} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %mul = tosa.mul %one, %arg0 {shift = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK: return %arg0
   return %mul : tensor<f32>
 }
@@ -276,7 +276,7 @@ func.func @fold_mul_one_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
 // CHECK-LABEL: @fold_mul_one_rhs_i32
 func.func @fold_mul_one_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
   %one = "tosa.const"() {value = dense<64> : tensor<i32>} : () -> tensor<i32>
-  %mul = tosa.mul %arg0, %one {shift = 6 : i8} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %mul = tosa.mul %arg0, %one {shift = 6 : i32} : (tensor<i32>, tensor<i32>) -> tensor<i32>
   // CHECK: return %arg0
   return %mul : tensor<i32>
 }
@@ -286,7 +286,7 @@ func.func @fold_mul_one_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
 // CHECK-LABEL: @fold_mul_one_lhs_i32
 func.func @fold_mul_one_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
   %one = "tosa.const"() {value = dense<64> : tensor<i32>} : () -> tensor<i32>
-  %mul = tosa.mul %one, %arg0 {shift = 6 : i8} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %mul = tosa.mul %one, %arg0 {shift = 6 : i32} : (tensor<i32>, tensor<i32>) -> tensor<i32>
   // CHECK: return %arg0
   return %mul : tensor<i32>
 }
@@ -297,7 +297,7 @@ func.func @fold_mul_one_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
 func.func @fold_mul_splat_i8() -> tensor<10xi32> {
   %one = "tosa.const"() {value = dense<17> : tensor<10xi8>} : () -> tensor<10xi8>
   %two = "tosa.const"() {value = dense<32> : tensor<10xi8>} : () -> tensor<10xi8>
-  %mul = tosa.mul %one, %two {shift = 3 : i8} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi32>
+  %mul = tosa.mul %one, %two {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi32>
   // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi32>}
   // CHECK: return %[[THREE]]
   return %mul : tensor<10xi32>
@@ -309,7 +309,7 @@ func.func @fold_mul_splat_i8() -> tensor<10xi32> {
 func.func @fold_mul_splat_f32() -> tensor<10xf32> {
   %one = "tosa.const"() {value = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32>
   %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %mul = tosa.mul %one, %two {shift = 0 : i8} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  %mul = tosa.mul %one, %two {shift = 0 : i32} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
   // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<6.000000e+00> : tensor<10xf32>}
   // CHECK: return %[[THREE]]
   return %mul : tensor<10xf32>
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 7d7f2d31a4244cd..754843969ef8ef5 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -224,14 +224,14 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te
 // -----
 // CHECK-LABEL: mul
 func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i32} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
 // -----
 // CHECK-LABEL: mul
 func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x1x3xi16>) -> tensor<13x21x3xi16> {
-  %0 = "tosa.mul"(%arg0, %arg1)  { shift = 1 : i8 } : (tensor<13x21x3xi16>, tensor<13x1x3xi16>) -> tensor<13x21x3xi16>
+  %0 = "tosa.mul"(%arg0, %arg1)  { shift = 1 : i32 } : (tensor<13x21x3xi16>, tensor<13x1x3xi16>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index c86bf5d056f85ee..b3aed8ae84033e4 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -34,7 +34,7 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<
   // CHECK: %[[sIn:.+]] = tosa.sub %[[cIn]], %[[iZp]]
   // CHECK: %[[sWe:.+]] = tosa.sub %[[cWe]], %[[wZp]]
   // CHECK: %[[resWe:.+]] = tosa.reshape %[[sWe]] {new_shape = array<i64: 1, 1, 1, 2, 3>}
-  // CHECK: %[[mul:.+]] = tosa.mul %[[sIn]], %[[resWe]] {shift = 0 : i8}
+  // CHECK: %[[mul:.+]] = tosa.mul %[[sIn]], %[[resWe]] {shift = 0 : i32}
   // CHECK: %[[reO:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 10, 10, 6>}
   // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 6>}
   // CHECK: %[[add:.+]] = tosa.add %[[reO]], %[[reArg2]]
@@ -51,7 +51,7 @@ func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: t
   // CHECK: %[[reIn:.+]] = tosa.reshape %arg0 {new_shape = array<i64: 4, 10, 10, 2, 1>}
   // CHECK: %[[padded:.+]] = tosa.pad %[[reIn]], %[[pad]], %[[zero]] : (tensor<4x10x10x2x1xf32>, tensor<5x2xi64>, tensor<f32>) -> tensor<4x12x12x2x1xf32>
   // CHECK: %[[reArg1:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 1, 1, 1, 2, 3>}
-  // CHECK: %[[mul:.+]] = tosa.mul %3, %[[reArg1]] {shift = 0 : i8}
+  // CHECK: %[[mul:.+]] = tosa.mul %3, %[[reArg1]] {shift = 0 : i32}
   // CHECK: %[[reOut:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 12, 12, 6>}
   // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 6>}
   // CHECK: %[[add:.+]] = tosa.add %[[reOut]], %[[reArg2]]
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 1ce4defcf4a6e65..d468ba582483cbe 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -114,8 +114,8 @@ func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<f32>) ->
   // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
   %2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
 
-  // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
-  %3 = tosa.mul %arg0, %arg1 { shift = 0 : i8 } : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
+  // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i32} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+  %3 = tosa.mul %arg0, %arg1 { shift = 0 : i32 } : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
 
   // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
   %4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<f32>) -> tensor<*xf32>
@@ -148,8 +148,8 @@ func.func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32
   // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
   %2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
-  // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
-  %3 = tosa.mul %arg0, %arg1 { shift = 0 : i8 } : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
+  // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i32} : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  %3 = tosa.mul %arg0, %arg1 { shift = 0 : i32 } : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
 
   // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
   %4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
@@ -206,8 +206,8 @@ func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<i32>) -> () {
   // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
   %10 = tosa.minimum %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
 
-  // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
-  %11 = tosa.mul %arg0, %arg1 { shift = 0 : i8 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
+  // CHECK: tosa.mul %arg0, %arg1 {shift = 0 : i32} : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
+  %11 = tosa.mul %arg0, %arg1 { shift = 0 : i32 }: (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>
 
   // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
   %12 = tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<i32>) -> tensor<*xi32>

>From d94b09d8878afbf6bc8de69010af37067180de67 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 12 Oct 2023 18:17:37 -0700
Subject: [PATCH 3/5] [mlir][affine] ValueBoundsConstraintSet: Fully compose
 affine.apply (#68899)

Fully compose `affine.apply` ops before adding them to the underlying
`FlatLinearConstraints`. This works around a limitation of
`FlatLinearConstraints`, which cannot deduce a constant bound if it
involves two identical local variables.

Details for future improvements of `FlatLinearConstraints`: the
constraint set infrastructure fails to compute a constant bound of -8
for the first variable in the following set.
```
Domain: 0, Range: 1, Symbols: 4, Locals: 2
8 constraints
(None    None    None    None    None    Local    Local    const)
 1    -1    0    0    0    0    0    0    = 0
 0    1    -1    1    0    0    0    0    = 0
 0    0    1    0    0    0    -16    0    = 0
 0    0    0    1    0    -16    0    -8    = 0
 0    0    0    0    -1    0    32    31    >= 0
 0    0    0    0    1    0    -32    0    >= 0
 0    0    0    0    -1    32    0    31    >= 0
 0    0    0    0    1    -32    0    0    >= 0
```
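
For illustration, here is the kind of IR this change addresses (a minimal
sketch mirroring the `composed_affine_apply` test added below). Without full
composition, each `floordiv` term becomes a distinct local variable that the
constraint set cannot relate; after composition, `%s` folds to the constant
-8:
```
%i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
%i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
// After full composition, the map simplifies to the constant -8.
%s = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%i2, %i3]
```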
---
 .../Affine/IR/ValueBoundsOpInterfaceImpl.h    | 14 ++++++
 .../Affine/IR/ValueBoundsOpInterfaceImpl.cpp  | 47 +++++++++++++++++--
 .../value-bounds-op-interface-impl.mlir       | 32 +++++++++++++
 .../Dialect/Affine/TestReifyValueBounds.cpp   | 11 ++++-
 4 files changed, 97 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h
index 2abbabc5bb2868c..5d4774861bdfd37 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h
@@ -9,11 +9,25 @@
 #ifndef MLIR_DIALECT_AFFINE_IR_VALUEBOUNDSOPINTERFACEIMPL_H
 #define MLIR_DIALECT_AFFINE_IR_VALUEBOUNDSOPINTERFACEIMPL_H
 
+#include "mlir/Support/LogicalResult.h"
+
 namespace mlir {
 class DialectRegistry;
+class Value;
 
 namespace affine {
 void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
+
+/// Compute whether the given values are equal. Return "failure" if equality
+/// could not be determined. `value1`/`value2` must be index-typed.
+///
+/// This function is similar to `ValueBoundsConstraintSet::areEqual`. To work
+/// around limitations in `FlatLinearConstraints`, this function fully composes
+/// `value1` and `value2` (if they are the result of affine.apply ops) before
+/// populating the constraint set. The folding/composing logic can see
+/// opportunities for simplifications that the constraint set implementation
+/// cannot see.
+FailureOr<bool> fullyComposeAndCheckIfEqual(Value value1, Value value2);
 } // namespace affine
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index 97dd70e4f1d2b7e..d47c8eb8ccb4272 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -27,12 +27,22 @@ struct AffineApplyOpInterface
     assert(applyOp.getAffineMap().getNumResults() == 1 &&
            "expected single result");
 
+    // Fully compose this affine.apply with other ops because the folding logic
+    // can see opportunities for simplifying the affine map that
+    // `FlatLinearConstraints` currently cannot see.
+    AffineMap map = applyOp.getAffineMap();
+    SmallVector<Value> operands = llvm::to_vector(applyOp.getOperands());
+    fullyComposeAffineMapAndOperands(&map, &operands);
+
     // Align affine map result with dims/symbols in the constraint set.
-    AffineExpr expr = applyOp.getAffineMap().getResult(0);
-    SmallVector<AffineExpr> dimReplacements = llvm::to_vector(llvm::map_range(
-        applyOp.getDimOperands(), [&](Value v) { return cstr.getExpr(v); }));
-    SmallVector<AffineExpr> symReplacements = llvm::to_vector(llvm::map_range(
-        applyOp.getSymbolOperands(), [&](Value v) { return cstr.getExpr(v); }));
+    AffineExpr expr = map.getResult(0);
+    SmallVector<AffineExpr> dimReplacements, symReplacements;
+    for (int64_t i = 0, e = map.getNumDims(); i < e; ++i)
+      dimReplacements.push_back(cstr.getExpr(operands[i]));
+    for (int64_t i = map.getNumDims(),
+                 e = map.getNumDims() + map.getNumSymbols();
+         i < e; ++i)
+      symReplacements.push_back(cstr.getExpr(operands[i]));
     AffineExpr bound =
         expr.replaceDimsAndSymbols(dimReplacements, symReplacements);
     cstr.bound(value) == bound;
@@ -92,3 +102,30 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels(
     AffineMinOp::attachInterface<AffineMinOpInterface>(*ctx);
   });
 }
+
+FailureOr<bool> mlir::affine::fullyComposeAndCheckIfEqual(Value value1,
+                                                          Value value2) {
+  assert(value1.getType().isIndex() && "expected index type");
+  assert(value2.getType().isIndex() && "expected index type");
+
+  // Subtract the two values/dimensions from each other. If the result is 0,
+  // both are equal.
+  Builder b(value1.getContext());
+  AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
+                                 b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
+  // Fully compose the affine map with other ops because the folding logic
+  // can see opportunities for simplifying the affine map that
+  // `FlatLinearConstraints` currently cannot see.
+  SmallVector<Value> mapOperands;
+  mapOperands.push_back(value1);
+  mapOperands.push_back(value2);
+  affine::fullyComposeAffineMapAndOperands(&map, &mapOperands);
+  ValueDimList valueDims;
+  for (Value v : mapOperands)
+    valueDims.push_back({v, std::nullopt});
+  FailureOr<int64_t> bound = ValueBoundsConstraintSet::computeConstantBound(
+      presburger::BoundType::EQ, map, valueDims);
+  if (failed(bound))
+    return failure();
+  return *bound == 0;
+}
diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
index 338c48c5b210bc1..8acf358c887a987 100644
--- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
@@ -58,3 +58,35 @@ func.func @affine_min_lb(%a: index) -> (index) {
   %2 = "test.reify_bound"(%1) {type = "LB"}: (index) -> (index)
   return %2 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @composed_affine_apply(
+//       CHECK:   %[[cst:.*]] = arith.constant -8 : index
+//       CHECK:   return %[[cst]]
+func.func @composed_affine_apply(%i1 : index) -> (index) {
+  // The ValueBoundsOpInterface implementation of affine.apply fully composes
+  // the affine map (and its operands) with other affine.apply ops drawn from
+  // its operands before adding it to the constraint set. This is to work
+  // around a limitation in `FlatLinearConstraints`, which currently cannot
+  // compute a constant bound for %s. (The affine map simplification logic can
+  // simplify %s to -8.)
+  %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
+  %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
+  %s = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%i2, %i3]
+  %reified = "test.reify_constant_bound"(%s) {type = "EQ"} : (index) -> (index)
+  return %reified : index
+}
+
+// -----
+
+// Test for affine::fullyComposeAndCheckIfEqual
+func.func @composed_are_equal(%i1 : index) {
+  %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
+  %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
+  %s = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%i2, %i3]
+  // expected-remark @below{{different}}
+  "test.are_equal"(%i2, %i3) {compose} : (index, index) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index ad017cef1b9bace..6e3c3dff759a2ed 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/Dialect/Arith/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -186,8 +187,14 @@ static LogicalResult testEquality(func::FuncOp funcOp) {
         op->emitOpError("invalid op");
         return WalkResult::skip();
       }
-      FailureOr<bool> equal = ValueBoundsConstraintSet::areEqual(
-          op->getOperand(0), op->getOperand(1));
+      FailureOr<bool> equal = failure();
+      if (op->hasAttr("compose")) {
+        equal = affine::fullyComposeAndCheckIfEqual(op->getOperand(0),
+                                                    op->getOperand(1));
+      } else {
+        equal = ValueBoundsConstraintSet::areEqual(op->getOperand(0),
+                                                   op->getOperand(1));
+      }
       if (failed(equal)) {
         op->emitError("could not determine equality");
       } else if (*equal) {

>From 1151cc387a4ff519f2fc70d37137505a448df829 Mon Sep 17 00:00:00 2001
From: Lei Zhang <antiagainst at gmail.com>
Date: Sun, 15 Oct 2023 19:37:54 -0400
Subject: [PATCH 4/5] [mlir][vector] Enable transfer op hoisting with dynamic
 indices (#68500)

Recent changes (https://github.com/llvm/llvm-project/pull/66930)
disabled hoisting of vector transfer ops across view-like intermediate
ops. The recommended way is to fold subview ops into transfer op
indices before invoking hoisting. As a result, transfer op indices now
involve dynamic values rather than the static constants they carried
with subview ops, so hoisting no longer kicks in. This breaks
downstream users.

To fix it, this commit enables hoisting transfer ops with dynamic
indices by using `ValueBoundsConstraintSet` to prove ranges are disjoint
in `isDisjointTransferIndices`. Given that this utility is used in many
places, including op folders, we introduce a flag for now and set it to
true only for the "heavy" transforms in hoisting and load-store
forwarding.
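
As a sketch of what this enables (mirroring the new tests below): the two
writes here use dynamic indices, but the trailing-dim delta of 4 is provably
greater than or equal to the vector size 4, so the accesses are disjoint and
the first store no longer blocks forwarding:
```
%i1 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0)
vector.transfer_write %v0, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
// Provably disjoint from the write above: the trailing indices differ by 4.
vector.transfer_write %v1, %buffer[%i0, %i1] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
```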
---
 .../Affine/IR/ValueBoundsOpInterfaceImpl.h    |  12 +-
 .../mlir/Dialect/Vector/IR/VectorOps.h        |  19 ++-
 .../mlir/Interfaces/ValueBoundsOpInterface.h  |  10 ++
 .../Affine/IR/ValueBoundsOpInterfaceImpl.cpp  |   9 +-
 .../Dialect/Linalg/Transforms/Hoisting.cpp    |  12 +-
 mlir/lib/Dialect/Vector/IR/CMakeLists.txt     |   2 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  65 +++++++--
 .../Transforms/VectorTransferOpTransforms.cpp |   6 +-
 .../lib/Interfaces/ValueBoundsOpInterface.cpp |  27 ++--
 mlir/test/Dialect/Linalg/hoisting.mlir        | 132 ++++++++++++++++++
 .../Dialect/Vector/vector-transferop-opt.mlir | 104 ++++++++++++++
 .../Dialect/Affine/TestReifyValueBounds.cpp   |  30 ++--
 .../llvm-project-overlay/mlir/BUILD.bazel     |   2 +
 13 files changed, 370 insertions(+), 60 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h
index 5d4774861bdfd37..6e617ef40a53d7d 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h
@@ -18,16 +18,18 @@ class Value;
 namespace affine {
 void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
 
-/// Compute whether the given values are equal. Return "failure" if equality
-/// could not be determined. `value1`/`value2` must be index-typed.
+/// Compute a constant delta of the given two values. Return "failure" if we
+/// cannot determine a constant delta. `value1`/`value2` must be index-typed.
 ///
-/// This function is similar to `ValueBoundsConstraintSet::areEqual`. To work
-/// around limitations in `FlatLinearConstraints`, this function fully composes
+/// This function is similar to
+/// `ValueBoundsConstraintSet::computeConstantDelta`. To work around
+/// limitations in `FlatLinearConstraints`, this function fully composes
 /// `value1` and `value2` (if they are the result of affine.apply ops) before
 /// populating the constraint set. The folding/composing logic can see
 /// opportunities for simplifications that the constraint set implementation
 /// cannot see.
-FailureOr<bool> fullyComposeAndCheckIfEqual(Value value1, Value value2);
+FailureOr<int64_t> fullyComposeAndComputeConstantDelta(Value value1,
+                                                       Value value2);
 } // namespace affine
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index fc0c80036ff79ad..9ab20e20d975429 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -105,16 +105,23 @@ bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read);
 /// op.
 bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite);
 
-/// Same behavior as `isDisjointTransferSet` but doesn't require the operations
-/// to have the same tensor/memref. This allows comparing operations accessing
-/// different tensors.
+/// Return true if we can prove that the transfer operations access disjoint
+/// memory, without requiring the accessed tensor/memref to be the same.
+///
+/// If `testDynamicValueUsingBounds` is true, tries to test dynamic values
+/// via ValueBoundsOpInterface.
 bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
-                               VectorTransferOpInterface transferB);
+                               VectorTransferOpInterface transferB,
+                               bool testDynamicValueUsingBounds = false);
 
 /// Return true if we can prove that the transfer operations access disjoint
-/// memory.
+/// memory, requiring the operations to access the same tensor/memref.
+///
+/// If `testDynamicValueUsingBounds` is true, tries to test dynamic values
+/// via ValueBoundsOpInterface.
 bool isDisjointTransferSet(VectorTransferOpInterface transferA,
-                           VectorTransferOpInterface transferB);
+                           VectorTransferOpInterface transferB,
+                           bool testDynamicValueUsingBounds = false);
 
 /// Return the result value of reducing two scalar/vector values with the
 /// corresponding arith operation.
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 2687d79aec68ebb..8f11c563e0cbd91 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -176,6 +176,16 @@ class ValueBoundsConstraintSet {
       presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
       StopConditionFn stopCondition = nullptr, bool closedUB = false);
 
+  /// Compute a constant delta between the given two values. Return "failure"
+  /// if a constant delta could not be determined.
+  ///
+  /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
+  /// index-typed.
+  static FailureOr<int64_t>
+  computeConstantDelta(Value value1, Value value2,
+                       std::optional<int64_t> dim1 = std::nullopt,
+                       std::optional<int64_t> dim2 = std::nullopt);
+
   /// Compute whether the given values/dimensions are equal. Return "failure" if
   /// equality could not be determined.
   ///
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index d47c8eb8ccb4272..e0c3abe7a0f71d1 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -103,8 +103,8 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels(
   });
 }
 
-FailureOr<bool> mlir::affine::fullyComposeAndCheckIfEqual(Value value1,
-                                                          Value value2) {
+FailureOr<int64_t>
+mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
   assert(value1.getType().isIndex() && "expected index type");
   assert(value2.getType().isIndex() && "expected index type");
 
@@ -123,9 +123,6 @@ FailureOr<bool> mlir::affine::fullyComposeAndCheckIfEqual(Value value1,
   ValueDimList valueDims;
   for (Value v : mapOperands)
     valueDims.push_back({v, std::nullopt});
-  FailureOr<int64_t> bound = ValueBoundsConstraintSet::computeConstantBound(
+  return ValueBoundsConstraintSet::computeConstantBound(
       presburger::BoundType::EQ, map, valueDims);
-  if (failed(bound))
-    return failure();
-  return *bound == 0;
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 221bec713b38aa3..cbb2c507de69f9e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -173,16 +173,16 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
         if (auto transferWriteUse =
                 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
           if (!vector::isDisjointTransferSet(
-                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
-                  cast<VectorTransferOpInterface>(
-                      transferWriteUse.getOperation())))
+                  cast<VectorTransferOpInterface>(*transferWrite),
+                  cast<VectorTransferOpInterface>(*transferWriteUse),
+                  /*testDynamicValueUsingBounds=*/true))
             return WalkResult::advance();
         } else if (auto transferReadUse =
                        dyn_cast<vector::TransferReadOp>(use.getOwner())) {
           if (!vector::isDisjointTransferSet(
-                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
-                  cast<VectorTransferOpInterface>(
-                      transferReadUse.getOperation())))
+                  cast<VectorTransferOpInterface>(*transferWrite),
+                  cast<VectorTransferOpInterface>(*transferReadUse),
+                  /*testDynamicValueUsingBounds=*/true))
             return WalkResult::advance();
         } else {
           // Unknown use, we cannot prove that it doesn't alias with the
diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
index 9ec919423b3428f..70f3fa8c297d4bc 100644
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorDialect
   MLIRVectorAttributesIncGen
 
   LINK_LIBS PUBLIC
+  MLIRAffineDialect
   MLIRArithDialect
   MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
@@ -22,5 +23,6 @@ add_mlir_dialect_library(MLIRVectorDialect
   MLIRMemRefDialect
   MLIRSideEffectInterfaces
   MLIRTensorDialect
+  MLIRValueBoundsOpInterface
   MLIRVectorInterfaces
   )
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 027ef3605aeba46..d09a226e34ad35a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 
+#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -30,6 +31,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
@@ -168,39 +170,76 @@ bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
 }
 
 bool mlir::vector::isDisjointTransferIndices(
-    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) {
+    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
+    bool testDynamicValueUsingBounds) {
   // For simplicity only look at transfer of same type.
   if (transferA.getVectorType() != transferB.getVectorType())
     return false;
   unsigned rankOffset = transferA.getLeadingShapedRank();
   for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
-    auto indexA = getConstantIntValue(transferA.indices()[i]);
-    auto indexB = getConstantIntValue(transferB.indices()[i]);
-    // If any of the indices are dynamic we cannot prove anything.
-    if (!indexA.has_value() || !indexB.has_value())
-      continue;
+    Value indexA = transferA.indices()[i];
+    Value indexB = transferB.indices()[i];
+    std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
+    std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
 
     if (i < rankOffset) {
       // For leading dimensions, if we can prove that index are different we
       // know we are accessing disjoint slices.
-      if (*indexA != *indexB)
-        return true;
+      if (cstIndexA.has_value() && cstIndexB.has_value()) {
+        if (*cstIndexA != *cstIndexB)
+          return true;
+        continue;
+      }
+      if (testDynamicValueUsingBounds) {
+        // First try to see if we can fully compose and simplify the affine
+        // expression as a fast track.
+        FailureOr<int64_t> delta =
+            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
+        if (succeeded(delta) && *delta != 0)
+          return true;
+
+        FailureOr<bool> testEqual =
+            ValueBoundsConstraintSet::areEqual(indexA, indexB);
+        if (succeeded(testEqual) && !testEqual.value())
+          return true;
+      }
     } else {
       // For this dimension, we slice a part of the memref we need to make sure
       // the intervals accessed don't overlap.
-      int64_t distance = std::abs(*indexA - *indexB);
-      if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
-        return true;
+      int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
+      if (cstIndexA.has_value() && cstIndexB.has_value()) {
+        int64_t distance = std::abs(*cstIndexA - *cstIndexB);
+        if (distance >= vectorDim)
+          return true;
+        continue;
+      }
+      if (testDynamicValueUsingBounds) {
+        // First try to see if we can fully compose and simplify the affine
+        // expression as a fast track.
+        FailureOr<int64_t> delta =
+            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
+        if (succeeded(delta) && std::abs(*delta) >= vectorDim)
+          return true;
+
+        FailureOr<int64_t> computeDelta =
+            ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
+        if (succeeded(computeDelta)) {
+          if (std::abs(computeDelta.value()) >= vectorDim)
+            return true;
+        }
+      }
     }
   }
   return false;
 }
 
 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
-                                         VectorTransferOpInterface transferB) {
+                                         VectorTransferOpInterface transferB,
+                                         bool testDynamicValueUsingBounds) {
   if (transferA.source() != transferB.source())
     return false;
-  return isDisjointTransferIndices(transferA, transferB);
+  return isDisjointTransferIndices(transferA, transferB,
+                                   testDynamicValueUsingBounds);
 }
 
 // Helper to iterate over n-D vector slice elements. Calculate the next
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 603b88f11c8e007..a5f1b28152b9bde 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -142,7 +142,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
       // Don't need to consider disjoint accesses.
       if (vector::isDisjointTransferSet(
               cast<VectorTransferOpInterface>(write.getOperation()),
-              cast<VectorTransferOpInterface>(transferOp.getOperation())))
+              cast<VectorTransferOpInterface>(transferOp.getOperation()),
+              /*testDynamicValueUsingBounds=*/true))
         continue;
     }
     blockingAccesses.push_back(user);
@@ -217,7 +218,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
       // the write.
       if (vector::isDisjointTransferSet(
               cast<VectorTransferOpInterface>(write.getOperation()),
-              cast<VectorTransferOpInterface>(read.getOperation())))
+              cast<VectorTransferOpInterface>(read.getOperation()),
+              /*testDynamicValueUsingBounds=*/true))
         continue;
       if (write.getSource() == read.getSource() &&
           dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index c00ee0315a9639a..ff941115219f68b 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -484,25 +484,32 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
   return failure();
 }
 
-FailureOr<bool>
-ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
-                                   std::optional<int64_t> dim1,
-                                   std::optional<int64_t> dim2) {
+FailureOr<int64_t>
+ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
+                                               std::optional<int64_t> dim1,
+                                               std::optional<int64_t> dim2) {
 #ifndef NDEBUG
   assertValidValueDim(value1, dim1);
   assertValidValueDim(value2, dim2);
 #endif // NDEBUG
 
-  // Subtract the two values/dimensions from each other. If the result is 0,
-  // both are equal.
   Builder b(value1.getContext());
   AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                  b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
-  FailureOr<int64_t> bound = computeConstantBound(
-      presburger::BoundType::EQ, map, {{value1, dim1}, {value2, dim2}});
-  if (failed(bound))
+  return computeConstantBound(presburger::BoundType::EQ, map,
+                              {{value1, dim1}, {value2, dim2}});
+}
+
+FailureOr<bool>
+ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
+                                   std::optional<int64_t> dim1,
+                                   std::optional<int64_t> dim2) {
+  // Subtract the two values/dimensions from each other. If the result is 0,
+  // both are equal.
+  FailureOr<int64_t> delta = computeConstantDelta(value1, value2, dim1, dim2);
+  if (failed(delta))
     return failure();
-  return *bound == 0;
+  return *delta == 0;
 }
 
 ValueBoundsConstraintSet::BoundBuilder &
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 7d0c3648c344b1d..11bf4b58b95c82e 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -872,3 +872,135 @@ transform.sequence failures(propagate) {
   transform.structured.hoist_redundant_vector_transfers %0
     : (!transform.any_op) -> !transform.any_op
 }
+
+// -----
+
+// Test that we can hoist out 1-D read-write pairs whose indices are dynamic values.
+
+// CHECK: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 1)>
+// CHECK: #[[$MAP4:.+]] = affine_map<()[s0] -> (s0 + 4)>
+
+//   CHECK-LABEL: func.func @hoist_vector_transfer_pairs_disjoint_dynamic
+//    CHECK-SAME: (%[[BUFFER:.+]]: memref<?x?xf32>, %{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[I0:.+]]: index)
+
+//         CHECK:   %[[PLUS1:.+]] = affine.apply #[[$MAP1]]()[%[[I0]]]
+//         CHECK:   %[[PLUS4:.+]] = affine.apply #[[$MAP4]]()[%[[I0]]]
+//         CHECK:   %2 = vector.transfer_read %[[BUFFER]][%[[I0]], %[[I0]]]
+//         CHECK:   %3 = vector.transfer_read %[[BUFFER]][%[[PLUS1]], %[[I0]]]
+//         CHECK:   %4 = vector.transfer_read %[[BUFFER]][%[[PLUS1]], %[[PLUS4]]]
+// CHECK-COUNT-2:   scf.for %{{.+}} = {{.+}} -> (vector<4xf32>, vector<4xf32>, vector<4xf32>)
+// CHECK-COUNT-3:     "some_use"
+// CHECK-COUNT-2:   scf.yield {{.+}} : vector<4xf32>, vector<4xf32>, vector<4xf32>
+//         CHECK:   vector.transfer_write %{{.+}}, %[[BUFFER]][%[[PLUS1]], %[[PLUS4]]]
+//         CHECK:   vector.transfer_write %{{.+}}, %[[BUFFER]][%[[PLUS1]], %[[I0]]]
+//         CHECK:   vector.transfer_write %{{.+}}, %[[BUFFER]][%[[I0]], %[[I0]]]
+
+func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
+    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
+  %cst = arith.constant 0.0 : f32
+  %i1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i0)
+  %i2 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0)
+
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %r0 = vector.transfer_read %buffer[%i0, %i0], %cst: memref<?x?xf32>, vector<4xf32>
+      // Disjoint leading dim
+      %r1 = vector.transfer_read %buffer[%i1, %i0], %cst: memref<?x?xf32>, vector<4xf32>
+      // Non-overlap trailing dim
+      %r2 = vector.transfer_read %buffer[%i1, %i2], %cst: memref<?x?xf32>, vector<4xf32>
+      %u0 = "some_use"(%r0) : (vector<4xf32>) -> vector<4xf32>
+      %u1 = "some_use"(%r1) : (vector<4xf32>) -> vector<4xf32>
+      %u2 = "some_use"(%r2) : (vector<4xf32>) -> vector<4xf32>
+      vector.transfer_write %u0, %buffer[%i0, %i0] : vector<4xf32>, memref<?x?xf32>
+      vector.transfer_write %u1, %buffer[%i1, %i0] : vector<4xf32>, memref<?x?xf32>
+      vector.transfer_write %u2, %buffer[%i1, %i2] : vector<4xf32>, memref<?x?xf32>
+    }
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+// Test that we cannot hoist out read-write pairs whose indices are overlapping.
+
+//   CHECK-LABEL: func.func @hoist_vector_transfer_pairs_overlapping_dynamic
+// CHECK-COUNT-2:   scf.for
+// CHECK-COUNT-2:     vector.transfer_read
+// CHECK-COUNT-2:     vector.transfer_write
+
+func.func @hoist_vector_transfer_pairs_overlapping_dynamic(
+    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
+  %cst = arith.constant 0.0 : f32
+  %i1 = affine.apply affine_map<(d0) -> (d0 + 3)>(%i0)
+
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %r0 = vector.transfer_read %buffer[%i0, %i0], %cst: memref<?x?xf32>, vector<4xf32>
+      // Overlapping range with the above
+      %r1 = vector.transfer_read %buffer[%i0, %i1], %cst: memref<?x?xf32>, vector<4xf32>
+      %u0 = "some_use"(%r0) : (vector<4xf32>) -> vector<4xf32>
+      %u1 = "some_use"(%r1) : (vector<4xf32>) -> vector<4xf32>
+      vector.transfer_write %u0, %buffer[%i0, %i0] : vector<4xf32>, memref<?x?xf32>
+      vector.transfer_write %u1, %buffer[%i0, %i1] : vector<4xf32>, memref<?x?xf32>
+    }
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+// Test that we can hoist out 2-D read-write pairs whose indices are dynamic values.
+
+//   CHECK-LABEL: func.func @hoist_vector_transfer_pairs_disjoint_dynamic
+// CHECK-COUNT-3:   vector.transfer_read
+// CHECK-COUNT-2:   %{{.+}}:3 = scf.for {{.+}} -> (vector<16x8xf32>, vector<16x8xf32>, vector<16x8xf32>)
+// CHECK-COUNT-2:   scf.yield {{.+}} : vector<16x8xf32>, vector<16x8xf32>, vector<16x8xf32>
+// CHECK-COUNT-3:   vector.transfer_write
+//         CHECK:   return
+
+func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
+    %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index, %i1 : index) {
+  %cst = arith.constant 0.0 : f32
+  %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
+  %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
+  %i4 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 16)>(%i1)
+
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %r0 = vector.transfer_read %buffer[%i0, %i2], %cst: memref<?x?xf32>, vector<16x8xf32>
+      %r1 = vector.transfer_read %buffer[%i0, %i3], %cst: memref<?x?xf32>, vector<16x8xf32>
+      %r2 = vector.transfer_read %buffer[%i0, %i4], %cst: memref<?x?xf32>, vector<16x8xf32>
+      %u0 = "some_use"(%r0) : (vector<16x8xf32>) -> vector<16x8xf32>
+      %u1 = "some_use"(%r1) : (vector<16x8xf32>) -> vector<16x8xf32>
+      %u2 = "some_use"(%r2) : (vector<16x8xf32>) -> vector<16x8xf32>
+      vector.transfer_write %u2, %buffer[%i0, %i4] : vector<16x8xf32>, memref<?x?xf32>
+      vector.transfer_write %u1, %buffer[%i0, %i3] : vector<16x8xf32>, memref<?x?xf32>
+      vector.transfer_write %u0, %buffer[%i0, %i2] : vector<16x8xf32>, memref<?x?xf32>
+    }
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!transform.any_op) -> !transform.any_op
+}
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index f43367ab4aeba7d..13957af014b89ed 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -256,3 +256,107 @@ func.func @collapse_shape(%in_0: memref<1x20x1xi32>, %vec: vector<4xi32>) {
   }
   return
 }
+
+// CHECK-LABEL: func @forward_dead_store_dynamic_same_index
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_store_dynamic_same_index(
+    %buffer : memref<?x?xf32>, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i : index) {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  vector.transfer_write %v0, %buffer[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  // The following transfer op reads/writes to the same address so that we can forward.
+  %0 = vector.transfer_read %buffer[%i, %i], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<4xf32>
+  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) {
+    %1 = arith.addf %acc, %acc : vector<4xf32>
+    scf.yield %1 : vector<4xf32>
+  }
+  vector.transfer_write %x, %buffer[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  return
+}
+
+//   CHECK-LABEL: func @dont_forward_dead_store_dynamic_overlap
+// CHECK-COUNT-2:   vector.transfer_write
+//         CHECK:   vector.transfer_read
+//         CHECK:   scf.for
+//         CHECK:   }
+//         CHECK:   vector.transfer_write
+//         CHECK:   return
+func.func @dont_forward_dead_store_dynamic_overlap(
+    %buffer : memref<?x?xf32>, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i0 : index) {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %i1 = affine.apply affine_map<(d0) -> (d0 + 3)>(%i0)
+  vector.transfer_write %v0, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  // The following transfer op writes to an overlapping range so we cannot forward.
+  vector.transfer_write %v0, %buffer[%i0, %i1] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%i0, %i0], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<4xf32>
+  %x = scf.for %iv = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) {
+    %1 = arith.addf %acc, %acc : vector<4xf32>
+    scf.yield %1 : vector<4xf32>
+  }
+  vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  return
+}
+
+// CHECK-LABEL: func @forward_dead_store_dynamic_non_overlap_leading_dim
+//       CHECK:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_store_dynamic_non_overlap_leading_dim(
+    %buffer : memref<?x?xf32>, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i0 : index) {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %i1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i0)
+  vector.transfer_write %v0, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  // The following transfer op writes to a non-overlapping range so we can forward.
+  vector.transfer_write %v0, %buffer[%i1, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%i0, %i0], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<4xf32>
+  %x = scf.for %iv = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) {
+    %1 = arith.addf %acc, %acc : vector<4xf32>
+    scf.yield %1 : vector<4xf32>
+  }
+  vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  return
+}
+
+// CHECK-LABEL: func @forward_dead_store_dynamic_non_overlap_trailing_dim
+//       CHECK:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
+    %buffer : memref<?x?xf32>, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i0 : index) {
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %i1 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0)
+  vector.transfer_write %v0, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  // The following transfer op writes to a non-overlapping range so we can forward.
+  vector.transfer_write %v0, %buffer[%i0, %i1] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %buffer[%i0, %i0], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<4xf32>
+  %x = scf.for %iv = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) {
+    %1 = arith.addf %acc, %acc : vector<4xf32>
+    scf.yield %1 : vector<4xf32>
+  }
+  vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  return
+}
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 6e3c3dff759a2ed..2f1631cbdb02e01 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -187,20 +187,26 @@ static LogicalResult testEquality(func::FuncOp funcOp) {
         op->emitOpError("invalid op");
         return WalkResult::skip();
       }
-      FailureOr<bool> equal = failure();
       if (op->hasAttr("compose")) {
-        equal = affine::fullyComposeAndCheckIfEqual(op->getOperand(0),
-                                                    op->getOperand(1));
-      } else {
-        equal = ValueBoundsConstraintSet::areEqual(op->getOperand(0),
-                                                   op->getOperand(1));
-      }
-      if (failed(equal)) {
-        op->emitError("could not determine equality");
-      } else if (*equal) {
-        op->emitRemark("equal");
+        FailureOr<int64_t> equal = affine::fullyComposeAndComputeConstantDelta(
+            op->getOperand(0), op->getOperand(1));
+        if (failed(equal)) {
+          op->emitError("could not determine equality");
+        } else if (*equal == 0) {
+          op->emitRemark("equal");
+        } else {
+          op->emitRemark("different");
+        }
       } else {
-        op->emitRemark("different");
+        FailureOr<bool> equal = ValueBoundsConstraintSet::areEqual(
+            op->getOperand(0), op->getOperand(1));
+        if (failed(equal)) {
+          op->emitError("could not determine equality");
+        } else if (*equal) {
+          op->emitRemark("equal");
+        } else {
+          op->emitRemark("different");
+        }
       }
     }
     return WalkResult::advance();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 53b626996f8bbfa..f03f3c6737ebafa 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4321,6 +4321,7 @@ cc_library(
     ]),
     includes = ["include"],
     deps = [
+        ":AffineDialect",
         ":ArithDialect",
         ":ArithUtils",
         ":ControlFlowInterfaces",
@@ -4335,6 +4336,7 @@ cc_library(
         ":SideEffectInterfaces",
         ":Support",
         ":TensorDialect",
+        ":ValueBoundsOpInterface",
         ":VectorInterfaces",
         ":VectorAttributesIncGen",
         ":VectorDialectIncGen",

>From 1f2b48cd3a0866fff82e98efbf2d1369f90e01fd Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Thu, 21 Sep 2023 16:24:11 -0700
Subject: [PATCH 5/5] [mlir][TilingInterface] Add `scf::tileUsingSCFForallOp`
 method to tile using the interface to generate `scf::forall`.

Similar to `scf::tileUsingSCFForOp`, which tiles operations that
implement the `TilingInterface` using `scf.for` loops, this method
introduces tiling using `scf.forall`. Most of the implementation is
derived from the `linalg::tileToForallOp` method. Eventually that
method will either be
deprecated or moved to use the method introduced here.
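
As a rough usage sketch (not part of this patch): `rewriter` and
`tilingInterfaceOp` are assumed to be supplied by the calling pass, the tile
sizes are illustrative, and this assumes the `OpFoldResult`-based tile-size
callback and the `replacements` field of `SCFTilingResult` as constructed at
the end of the new function:
```
scf::SCFTilingOptions options;
// Tile the first two loops with sizes 10 and 20; remaining loops stay untiled.
options.tileSizeComputationFunction = [](OpBuilder &b, Operation *) {
  return SmallVector<OpFoldResult>{b.getIndexAttr(10), b.getIndexAttr(20)};
};
FailureOr<scf::SCFTilingResult> tilingResult =
    scf::tileUsingSCFForallOp(rewriter, tilingInterfaceOp, options);
if (failed(tilingResult))
  return failure();
// Replace the untiled op with the values yielded by the scf.forall.
rewriter.replaceOp(tilingInterfaceOp, tilingResult->replacements);
```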
---
 .../SCF/Transforms/TileUsingInterface.h       |  17 +++
 .../SCF/Transforms/TileUsingInterface.cpp     | 133 ++++++++++++++++++
 .../TilingInterface/tile-using-scfforall.mlir |  37 +++++
 .../TilingInterface/TestTilingInterface.cpp   |  69 +++++++++
 4 files changed, 256 insertions(+)
 create mode 100644 mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 9f49d97e141e0c8..06cce19894e9f5a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -51,6 +51,17 @@ struct SCFTilingOptions {
     interchangeVector = llvm::to_vector(interchange);
     return *this;
   }
+
+  /// Specify mapping of loops to devices. This is only respected when the loop
+  /// constructs support such a mapping (like `scf.forall`). Will be ignored
+  /// when using loop constructs that don't support such a mapping (like
+  /// `scf.for`).
+  SmallVector<Attribute> mappingVector = {};
+  SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
+    mappingVector = llvm::to_vector(
+        llvm::map_range(mapping, [](auto attr) -> Attribute { return attr; }));
+    return *this;
+  }
 };
 
 /// Transformation information returned after tiling.
@@ -82,6 +93,12 @@ struct SCFTileAndFuseOptions {
   }
 };
 
+/// Method to tile an op that implements the `TilingInterface` using
+/// `scf.forall`.
+FailureOr<SCFTilingResult>
+tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
+                     const SCFTilingOptions &options);
+
 /// Fuse the producer of the source of `candidateSliceOp` by computing the
 /// required slice of the producer in-place.  Note that the method
 /// replaces the uses of `candidateSliceOp` with the tiled and fused producer
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 96d6169111b3856..a58cd7a7541a515 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -122,6 +122,24 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
       b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
 }
 
+/// Clones the operation and updates the destination if the operation
+/// implements the `DestinationStyleOpInterface`.
+static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
+                                                  Operation *op,
+                                                  ValueRange newDestArgs) {
+  Operation *clonedOp = rewriter.clone(*op);
+  if (auto destinationStyleOp =
+          dyn_cast<DestinationStyleOpInterface>(clonedOp)) {
+    // Note that this is assuming that the cloned op's destination (init)
+    // operands line up one-to-one with `newDestArgs`.
+    auto [start, end] = destinationStyleOp.getDpsInitsPositionRange();
+    assert((end - start == newDestArgs.size()) &&
+           "expected as many new destination args as number of inits of the "
+           "operation");
+    clonedOp->setOperands(start, end - start, newDestArgs);
+  }
+  return clonedOp;
+}
+
 /// Generate an empty loop nest that represents the tiled loop nest shell.
 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
@@ -728,6 +746,121 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
                                    getAsOperations(forLoops), replacements};
 }
 
+//===----------------------------------------------------------------------===//
+// tileUsingSCFForallOp implementation.
+//===----------------------------------------------------------------------===//
+
+FailureOr<scf::SCFTilingResult>
+mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
+                                const scf::SCFTilingOptions &options) {
+  Location loc = op->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+
+  // 1. Get the range of loops that are represented by the operation.
+  SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
+  if (loopRanges.empty())
+    return op->emitOpError("expected non-empty loop ranges");
+  auto hasNonUnitStride = [](Range r) {
+    return !isConstantIntValue(r.stride, 1);
+  };
+  if (llvm::any_of(loopRanges, hasNonUnitStride))
+    return op->emitOpError("only stride-1 is supported for now");
+
+  // 2. Get the tile sizes. If a tile size is 0, the corresponding loop is not
+  // tiled and distributed. To simplify later logic, pad the tile sizes to
+  // `loopRanges.size()` with zeros.
+  SmallVector<OpFoldResult> tileSizeVector =
+      options.tileSizeComputationFunction(rewriter, op);
+  tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0));
+
+  // 3. Build the lower bounds, upper bounds and steps for the tiled and
+  // distributed loops.
+  SmallVector<OpFoldResult> lbs, ubs, steps;
+  for (auto [index, tileSize, loopRange] :
+       llvm::enumerate(tileSizeVector, loopRanges)) {
+    if (isConstantIntValue(tileSize, 0))
+      continue;
+    lbs.push_back(loopRange.offset);
+    ubs.push_back(loopRange.size);
+    steps.push_back(tileSize);
+  }
+
+  // 4. Gather destination tensors.
+  SmallVector<Value> dest;
+  if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
+    return op->emitOpError("failed to get destination tensors");
+
+  // 5. Build the device mapping attribute.
+  std::optional<ArrayAttr> mappingAttr;
+  if (!options.mappingVector.empty()) {
+    mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector));
+  }
+
+  // 6. Create the ForallOp. We don't use the lambda body-builder
+  // version because we require the use of RewriterBase in the body, so we
+  // manually move the insertion point to the body below.
+  auto forallOp =
+      rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, dest, mappingAttr);
+
+  // 7. Get the tile offsets and sizes.
+  rewriter.setInsertionPoint(forallOp.getTerminator());
+  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+  tiledOffsets.reserve(loopRanges.size());
+  tiledSizes.reserve(loopRanges.size());
+  ValueRange ivs = forallOp.getInductionVars();
+  {
+    int materializedLoopNum = 0;
+    for (auto [index, tileSize, loopRange] :
+         llvm::enumerate(tileSizeVector, loopRanges)) {
+      if (isConstantIntValue(tileSize, 0)) {
+        tiledOffsets.push_back(loopRange.offset);
+        tiledSizes.push_back(loopRange.size);
+        continue;
+      }
+      Value iv = ivs[materializedLoopNum++];
+      tiledOffsets.push_back(iv);
+      tiledSizes.push_back(
+          getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+    }
+  }
+
+  // 8. Tile the operation. Clone the operation to allow fixing up the
+  // destination operands.
+  ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
+  Operation *clonedOp =
+      cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs);
+  FailureOr<TilingResult> tilingResult =
+      cast<TilingInterface>(clonedOp).getTiledImplementation(
+          rewriter, tiledOffsets, tiledSizes);
+  if (failed(tilingResult))
+    return clonedOp->emitError("failed to tile operation");
+  rewriter.eraseOp(clonedOp);
+
+  // 9. Parallel insert back into the result tensor.
+  for (auto [index, tiledValue, destBBArg] :
+       llvm::enumerate(tilingResult->tiledValues, destBbArgs)) {
+    // 9.a. Partial subset information is inserted just before the terminator.
+    rewriter.setInsertionPoint(forallOp.getTerminator());
+
+    SmallVector<OpFoldResult> resultOffsets, resultSizes;
+    if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets,
+                                        tiledSizes, resultOffsets,
+                                        resultSizes)))
+      return op->emitOpError("output offsets couldn't be calculated");
+    SmallVector<OpFoldResult> strides(resultSizes.size(),
+                                      rewriter.getIndexAttr(1));
+
+    // 9.b. Parallel insertions are inserted at the end of the combining
+    // terminator.
+    rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides);
+  }
+
+  // 10. Return the tiling result.
+  return scf::SCFTilingResult{
+      tilingResult->tiledOps,
+      {forallOp.getOperation()},
+      llvm::to_vector(llvm::map_range(forallOp.getResults(),
+                                      [](auto val) -> Value { return val; }))};
+}
+
 //===----------------------------------------------------------------------===//
 // lowerToLoopsUsingSCFForOp implementation.
 //===----------------------------------------------------------------------===//
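
For reviewers who want to invoke the new entry point directly (outside the
test pass below), a minimal sketch follows. It assumes a `TilingInterface` op
and a rewriter anchored at it; `tileByForall` is a hypothetical wrapper, not
part of this change, and it reuses the `SCFTilingOptions` and
`getAsIndexOpFoldResult` helpers that appear elsewhere in the patch.

  // Sketch only: tile an op into an scf.forall with fixed tile sizes
  // {10, 20}; a tile size of 0 would leave the corresponding loop untiled.
  static FailureOr<scf::SCFTilingResult>
  tileByForall(RewriterBase &rewriter, TilingInterface op) {
    scf::SCFTilingOptions options;
    SmallVector<int64_t> tileSizes = {10, 20};
    options.setTileSizes(
        getAsIndexOpFoldResult(rewriter.getContext(), tileSizes));
    FailureOr<scf::SCFTilingResult> result =
        scf::tileUsingSCFForallOp(rewriter, op, options);
    if (failed(result))
      return failure();
    // Replace uses of the original op with the scf.forall results.
    rewriter.replaceOp(op, result->replacements);
    return result;
  }

The new test below exercises exactly this configuration through the test
pass.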
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
new file mode 100644
index 000000000000000..bfc352c764ad11a
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt -test-tiling-interface=tile-using-scf-forall -split-input-file %s | FileCheck %s
+
+func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul {__internal_transform__ = "simple_gemm"}
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
+//      CHECK: func.func @simple_matmul(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+//      CHECK:   %[[RESULT:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
+// CHECK-SAME:       (0, 0) to (%[[M]], %[[N]]) step (10, 20) shared_outs(%[[INIT:.+]] = %[[ARG2]])
+//      CHECK:     %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
+//      CHECK:     %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[N]]]
+//      CHECK:     %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME:         [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1]
+//      CHECK:     %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]]
+// CHECK-SAME:         [0, %[[IV1]]] [%[[K]], %[[TS_X]]] [1, 1]
+//      CHECK:     %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]]
+// CHECK-SAME:         [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
+//      CHECK:     %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME:         ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:         outs(%[[INIT_TILE]] :
+//      CHECK:     scf.forall.in_parallel {
+//      CHECK:       tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT]]
+// CHECK-SAME:           [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
+//      CHECK:   return %[[RESULT]]
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 2573e11979dbc47..2bec859b50f26ba 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -186,6 +186,51 @@ struct TestTileUsingSCFForOp
   TransformationFilter filter;
 };
 
+/// Pattern for testing `tileUsingSCFForallOp` (that tiles operations using
+/// the `TilingInterface` with `scf.forall` ops for iterating over the tiles)
+/// while using a `filter` to avoid recursive application.
+struct TestTileUsingSCFForallOp
+    : public OpInterfaceRewritePattern<TilingInterface> {
+  TestTileUsingSCFForallOp(MLIRContext *context, scf::SCFTilingOptions options,
+                           TransformationFilter filter = TransformationFilter(),
+                           PatternBenefit benefit = 1)
+      : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+        options(std::move(options)), filter(std::move(filter)) {}
+
+  /// Construct a generic pattern applied to `opName`.
+  TestTileUsingSCFForallOp(StringRef opName, MLIRContext *context,
+                           scf::SCFTilingOptions options,
+                           TransformationFilter filter = TransformationFilter(),
+                           PatternBenefit benefit = 1)
+      : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+        options(std::move(options)), filter(std::move(filter)) {}
+
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(filter.checkAndNotify(rewriter, op)))
+      return failure();
+
+    FailureOr<scf::SCFTilingResult> tilingResult =
+        scf::tileUsingSCFForallOp(rewriter, op, options);
+    if (failed(tilingResult))
+      return rewriter.notifyMatchFailure(op, "failed to tile operation");
+
+    if (op->getNumResults()) {
+      rewriter.replaceOp(op, tilingResult->replacements);
+    } else {
+      rewriter.eraseOp(op);
+    }
+
+    for (auto *tiledOp : tilingResult->tiledOps)
+      filter.replaceTransformationFilter(rewriter, tiledOp);
+    return success();
+  }
+
+private:
+  scf::SCFTilingOptions options;
+  TransformationFilter filter;
+};
+
 /// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern
 /// (that tiles and fuses operations using the `TilingInterface` with `scf.for`
 /// ops for iterating over the tiles) while using a `filter` to avoid recursive
@@ -415,6 +460,12 @@ struct TestTilingInterfacePass
           "Test tiling using TilingInterface with scf.for operations"),
       llvm::cl::init(false)};
 
+  Option<bool> testTilingForAll{
+      *this, "tile-using-scf-forall",
+      llvm::cl::desc(
+          "Test tiling using TilingInterface with scf.forall operations"),
+      llvm::cl::init(false)};
+
   Option<bool> testTileConsumerFuseAndYieldProducer{
       *this, "tile-consumer-fuse-and-yield-producer-using-scf-for",
       llvm::cl::desc(
@@ -455,6 +506,20 @@ static void addPatternForTiling(MLIRContext *context,
   patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
 }
 
+static void addPatternForTilingUsingForall(MLIRContext *context,
+                                           RewritePatternSet &patterns,
+                                           StringRef filterName,
+                                           ArrayRef<int64_t> tileSizes,
+                                           ArrayRef<int64_t> interchange = {}) {
+  scf::SCFTilingOptions tilingOptions;
+  SmallVector<OpFoldResult> tileSizesOfr =
+      getAsIndexOpFoldResult(context, tileSizes);
+  tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
+  TransformationFilter filter(StringAttr::get(context, filterName),
+                              StringAttr::get(context, "tiled"));
+  patterns.add<TestTileUsingSCFForallOp>(context, tilingOptions, filter);
+}
+
 static void addPatternForTileFuseAndYield(MLIRContext *context,
                                           RewritePatternSet &patterns,
                                           StringRef filterName,
@@ -514,6 +579,10 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
     addPatternForTiling(context, patterns, "simple_copy_memref", {10, 20});
     return;
   }
+  if (testTilingForAll) {
+    addPatternForTilingUsingForall(context, patterns, "simple_gemm", {10, 20});
+    return;
+  }
   if (testTileConsumerAndFuseProducer) {
     // 1. Tile and fuse of gemm with fill producer and bias-add consumer.
     addPatternForTileAndFuse(context, patterns, "fusion", {10, 20});
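
Outside of mlir-opt, a hypothetical driver (not part of this patch) could
apply the new pattern through the greedy rewriter; only
`addPatternForTilingUsingForall` comes from the patch, the rest is standard
pattern-driver boilerplate.

  // Sketch only: apply the forall tiling pattern to a function. Requires
  // mlir/Transforms/GreedyPatternRewriteDriver.h.
  static LogicalResult runForallTiling(func::FuncOp funcOp) {
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);
    // Matches ops tagged __internal_transform__ = "simple_gemm" (as in the
    // new test) and tiles the two outer loops by 10 and 20.
    addPatternForTilingUsingForall(context, patterns, "simple_gemm", {10, 20});
    return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  }

From the command line, the same path is exercised by the RUN line in the new
test: mlir-opt -test-tiling-interface=tile-using-scf-forall.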


