shader: Implement BRX

This commit is contained in:
FernandoS27 2021-03-27 22:30:24 +01:00 committed by ameerj
parent 39a379632e
commit 34aba9627a
21 changed files with 437 additions and 48 deletions

View file

@ -14,6 +14,7 @@
#include "shader_recompiler/exception.h"
#include "shader_recompiler/frontend/maxwell/control_flow.h"
#include "shader_recompiler/frontend/maxwell/decode.h"
#include "shader_recompiler/frontend/maxwell/indirect_branch_table_track.h"
#include "shader_recompiler/frontend/maxwell/location.h"
namespace Shader::Maxwell::Flow {
@ -252,9 +253,7 @@ CFG::AnalysisState CFG::AnalyzeInst(Block* block, FunctionId function_id, Locati
const Opcode opcode{Decode(inst.raw)};
switch (opcode) {
case Opcode::BRA:
case Opcode::BRX:
case Opcode::JMP:
case Opcode::JMX:
case Opcode::RET:
if (!AnalyzeBranch(block, function_id, pc, inst, opcode)) {
return AnalysisState::Continue;
@ -264,10 +263,6 @@ CFG::AnalysisState CFG::AnalyzeInst(Block* block, FunctionId function_id, Locati
case Opcode::JMP:
AnalyzeBRA(block, function_id, pc, inst, IsAbsoluteJump(opcode));
break;
case Opcode::BRX:
case Opcode::JMX:
AnalyzeBRX(block, pc, inst, IsAbsoluteJump(opcode));
break;
case Opcode::RET:
block->end_class = EndClass::Return;
break;
@ -302,6 +297,9 @@ CFG::AnalysisState CFG::AnalyzeInst(Block* block, FunctionId function_id, Locati
case Opcode::SSY:
block->stack.Push(OpcodeToken(opcode), BranchOffset(pc, inst));
return AnalysisState::Continue;
case Opcode::BRX:
case Opcode::JMX:
return AnalyzeBRX(block, pc, inst, IsAbsoluteJump(opcode), function_id);
case Opcode::EXIT:
return AnalyzeEXIT(block, function_id, pc, inst);
case Opcode::PRET:
@ -407,8 +405,46 @@ void CFG::AnalyzeBRA(Block* block, FunctionId function_id, Location pc, Instruct
block->branch_true = AddLabel(block, block->stack, bra_pc, function_id);
}
void CFG::AnalyzeBRX(Block*, Location, Instruction, bool is_absolute) {
throw NotImplementedException("{}", is_absolute ? "JMX" : "BRX");
CFG::AnalysisState CFG::AnalyzeBRX(Block* block, Location pc, Instruction inst, bool is_absolute,
FunctionId function_id) {
const std::optional brx_table{TrackIndirectBranchTable(env, pc, block->begin)};
if (!brx_table) {
TrackIndirectBranchTable(env, pc, block->begin);
throw NotImplementedException("Failed to track indirect branch");
}
const IR::FlowTest flow_test{inst.branch.flow_test};
const Predicate pred{inst.Pred()};
if (flow_test != IR::FlowTest::T || pred != Predicate{true}) {
throw NotImplementedException("Conditional indirect branch");
}
std::vector<u32> targets;
targets.reserve(brx_table->num_entries);
for (u32 i = 0; i < brx_table->num_entries; ++i) {
u32 target{env.ReadCbufValue(brx_table->cbuf_index, brx_table->cbuf_offset + i * 4)};
if (!is_absolute) {
target += pc.Offset();
}
target += brx_table->branch_offset;
target += 8;
targets.push_back(target);
}
std::ranges::sort(targets);
targets.erase(std::unique(targets.begin(), targets.end()), targets.end());
block->indirect_branches.reserve(targets.size());
for (const u32 target : targets) {
Block* const branch{AddLabel(block, block->stack, target, function_id)};
block->indirect_branches.push_back(branch);
}
block->cond = IR::Condition{true};
block->end = pc + 1;
block->end_class = EndClass::IndirectBranch;
block->branch_reg = brx_table->branch_reg;
block->branch_offset = brx_table->branch_offset + 8;
if (!is_absolute) {
block->branch_offset += pc.Offset();
}
return AnalysisState::Branch;
}
CFG::AnalysisState CFG::AnalyzeEXIT(Block* block, FunctionId function_id, Location pc,
@ -449,7 +485,6 @@ Block* CFG::AddLabel(Block* block, Stack stack, Location pc, FunctionId function
// Block already exists and it has been visited
return &*it;
}
// TODO: FIX DANGLING BLOCKS
Block* const new_block{block_pool.Create(Block{
.begin{pc},
.end{pc},
@ -494,6 +529,11 @@ std::string CFG::Dot() const {
add_branch(block.branch_false, false);
}
break;
case EndClass::IndirectBranch:
for (Block* const branch : block.indirect_branches) {
add_branch(branch, false);
}
break;
case EndClass::Call:
dot += fmt::format("\t\t{}->N{};\n", name, node_uid);
dot += fmt::format("\t\tN{}->{};\n", node_uid, NameOf(*block.return_block));