From 746d27372f9f6a7629a3e4d142e2b85072f5d3ec Mon Sep 17 00:00:00 2001
From: ameerj <52414509+ameerj@users.noreply.github.com>
Date: Tue, 28 Sep 2021 21:29:17 -0400
Subject: [PATCH] rescaling_pass: Scale ImageFetch offset if it exists

Plus some code deduplication
---
 .../ir_opt/rescaling_pass.cpp                 | 118 +++++++-----------
 1 file changed, 48 insertions(+), 70 deletions(-)

diff --git a/src/shader_recompiler/ir_opt/rescaling_pass.cpp b/src/shader_recompiler/ir_opt/rescaling_pass.cpp
index 51125f45a3..2aa9c31dce 100644
--- a/src/shader_recompiler/ir_opt/rescaling_pass.cpp
+++ b/src/shader_recompiler/ir_opt/rescaling_pass.cpp
@@ -137,21 +137,50 @@ void PatchImageQueryDimensions(IR::Block& block, IR::Inst& inst) {
     }
 }
 
-void ScaleIntegerCoord(IR::IREmitter& ir, IR::Inst& inst, const IR::U1& is_scaled) {
+void ScaleIntegerComposite(IR::IREmitter& ir, IR::Inst& inst, const IR::U1& is_scaled,
+                           size_t index) {
+    const IR::Value composite{inst.Arg(index)};
+    if (composite.IsEmpty()) {
+        return;
+    }
     const auto info{inst.Flags<IR::TextureInstInfo>()};
-    const IR::Value coord{inst.Arg(1)};
+    const IR::U32 x{Scale(ir, is_scaled, IR::U32{ir.CompositeExtract(composite, 0)})};
+    const IR::U32 y{Scale(ir, is_scaled, IR::U32{ir.CompositeExtract(composite, 1)})};
     switch (info.type) {
-    case TextureType::Color2D: {
-        const IR::U32 x{Scale(ir, is_scaled, IR::U32{ir.CompositeExtract(coord, 0)})};
-        const IR::U32 y{Scale(ir, is_scaled, IR::U32{ir.CompositeExtract(coord, 1)})};
-        inst.SetArg(1, ir.CompositeConstruct(x, y));
+    case TextureType::Color2D:
+        inst.SetArg(index, ir.CompositeConstruct(x, y));
+        break;
+    case TextureType::ColorArray2D: {
+        const IR::U32 z{ir.CompositeExtract(composite, 2)};
+        inst.SetArg(index, ir.CompositeConstruct(x, y, z));
         break;
     }
+    case TextureType::Color1D:
+    case TextureType::ColorArray1D:
+    case TextureType::Color3D:
+    case TextureType::ColorCube:
+    case TextureType::ColorArrayCube:
+    case TextureType::Buffer:
+        // Nothing to patch here
+        break;
+    }
+}
+
+void SubScaleCoord(IR::IREmitter& ir, IR::Inst& inst, const IR::U1& is_scaled) {
+    const auto info{inst.Flags<IR::TextureInstInfo>()};
+    const IR::Value coord{inst.Arg(1)};
+    const IR::U32 coord_x{ir.CompositeExtract(coord, 0)};
+    const IR::U32 coord_y{ir.CompositeExtract(coord, 1)};
+
+    const IR::U32 scaled_x{SubScale(ir, is_scaled, coord_x, IR::Attribute::PositionX)};
+    const IR::U32 scaled_y{SubScale(ir, is_scaled, coord_y, IR::Attribute::PositionY)};
+    switch (info.type) {
+    case TextureType::Color2D:
+        inst.SetArg(1, ir.CompositeConstruct(scaled_x, scaled_y));
+        break;
     case TextureType::ColorArray2D: {
-        const IR::U32 x{Scale(ir, is_scaled, IR::U32{ir.CompositeExtract(coord, 0)})};
-        const IR::U32 y{Scale(ir, is_scaled, IR::U32{ir.CompositeExtract(coord, 1)})};
         const IR::U32 z{ir.CompositeExtract(coord, 2)};
-        inst.SetArg(1, ir.CompositeConstruct(x, y, z));
+        inst.SetArg(1, ir.CompositeConstruct(scaled_x, scaled_y, z));
         break;
     }
     case TextureType::Color1D:
@@ -169,87 +198,36 @@ void SubScaleImageFetch(IR::Block& block, IR::Inst& inst) {
     IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
     const auto info{inst.Flags<IR::TextureInstInfo>()};
     const IR::U1 is_scaled{ir.IsTextureScaled(ir.Imm32(info.descriptor_index))};
-    const IR::Value coord{inst.Arg(1)};
-    switch (info.type) {
-    case TextureType::Color2D: {
-        const IR::U32 x{SubScale(ir, is_scaled, IR::U32{ir.CompositeExtract(coord, 0)},
-                                 IR::Attribute::PositionX)};
-        const IR::U32 y{SubScale(ir, is_scaled, IR::U32{ir.CompositeExtract(coord, 1)},
-                                 IR::Attribute::PositionY)};
-        inst.SetArg(1, ir.CompositeConstruct(x, y));
-        break;
-    }
-    case TextureType::ColorArray2D: {
-        const IR::U32 x{SubScale(ir, is_scaled, IR::U32{ir.CompositeExtract(coord, 0)},
-                                 IR::Attribute::PositionX)};
-        const IR::U32 y{SubScale(ir, is_scaled, IR::U32{ir.CompositeExtract(coord, 1)},
-                                 IR::Attribute::PositionY)};
-        const IR::U32 z{ir.CompositeExtract(coord, 2)};
-        inst.SetArg(1, ir.CompositeConstruct(x, y, z));
-        break;
-    }
-    case TextureType::Color1D:
-    case TextureType::ColorArray1D:
-    case TextureType::Color3D:
-    case TextureType::ColorCube:
-    case TextureType::ColorArrayCube:
-    case TextureType::Buffer:
-        // Nothing to patch here
-        break;
-    }
+    SubScaleCoord(ir, inst, is_scaled);
+    // Scale ImageFetch offset
+    ScaleIntegerComposite(ir, inst, is_scaled, 2);
 }
 
 void SubScaleImageRead(IR::Block& block, IR::Inst& inst) {
     IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
     const auto info{inst.Flags<IR::TextureInstInfo>()};
     const IR::U1 is_scaled{ir.IsImageScaled(ir.Imm32(info.descriptor_index))};
-    const IR::Value coord{inst.Arg(1)};
-    switch (info.type) {
-    case TextureType::Color2D: {
-        const IR::U32 x{SubScale(ir, is_scaled, IR::U32{ir.CompositeExtract(coord, 0)},
-                                 IR::Attribute::PositionX)};
-        const IR::U32 y{SubScale(ir, is_scaled, IR::U32{ir.CompositeExtract(coord, 1)},
-                                 IR::Attribute::PositionY)};
-        inst.SetArg(1, ir.CompositeConstruct(x, y));
-        break;
-    }
-    case TextureType::ColorArray2D: {
-        const IR::U32 x{SubScale(ir, is_scaled, IR::U32{ir.CompositeExtract(coord, 0)},
-                                 IR::Attribute::PositionX)};
-        const IR::U32 y{SubScale(ir, is_scaled, IR::U32{ir.CompositeExtract(coord, 1)},
-                                 IR::Attribute::PositionY)};
-        const IR::U32 z{ir.CompositeExtract(coord, 2)};
-        inst.SetArg(1, ir.CompositeConstruct(x, y, z));
-        break;
-    }
-    case TextureType::Color1D:
-    case TextureType::ColorArray1D:
-    case TextureType::Color3D:
-    case TextureType::ColorCube:
-    case TextureType::ColorArrayCube:
-    case TextureType::Buffer:
-        // Nothing to patch here
-        break;
-    }
+    SubScaleCoord(ir, inst, is_scaled);
 }
 
 void PatchImageFetch(IR::Block& block, IR::Inst& inst) {
     IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
     const auto info{inst.Flags<IR::TextureInstInfo>()};
     const IR::U1 is_scaled{ir.IsTextureScaled(ir.Imm32(info.descriptor_index))};
-    ScaleIntegerCoord(ir, inst, is_scaled);
+    ScaleIntegerComposite(ir, inst, is_scaled, 1);
+    // Scale ImageFetch offset
+    ScaleIntegerComposite(ir, inst, is_scaled, 2);
 }
 
 void PatchImageRead(IR::Block& block, IR::Inst& inst) {
     IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
     const auto info{inst.Flags<IR::TextureInstInfo>()};
     const IR::U1 is_scaled{ir.IsImageScaled(ir.Imm32(info.descriptor_index))};
-    ScaleIntegerCoord(ir, inst, is_scaled);
+    ScaleIntegerComposite(ir, inst, is_scaled, 1);
 }
 
 void Visit(const IR::Program& program, IR::Block& block, IR::Inst& inst) {
     const bool is_fragment_shader{program.stage == Stage::Fragment};
-    const bool is_compute_shader{program.stage == Stage::Compute};
     switch (inst.GetOpcode()) {
     case IR::Opcode::GetAttribute: {
         const IR::Attribute attr{inst.Arg(0).Attribute()};
@@ -271,14 +249,14 @@ void Visit(const IR::Program& program, IR::Block& block, IR::Inst& inst) {
     case IR::Opcode::ImageFetch:
         if (is_fragment_shader) {
             SubScaleImageFetch(block, inst);
-        } else if (is_compute_shader) {
+        } else {
             PatchImageFetch(block, inst);
         }
         break;
     case IR::Opcode::ImageRead:
         if (is_fragment_shader) {
             SubScaleImageRead(block, inst);
-        } else if (is_compute_shader) {
+        } else {
             PatchImageRead(block, inst);
         }
         break;