diff --git a/tests/unit/compiler/venom/test_load_elimination.py b/tests/unit/compiler/venom/test_load_elimination.py new file mode 100644 index 0000000000..ee8c082f23 --- /dev/null +++ b/tests/unit/compiler/venom/test_load_elimination.py @@ -0,0 +1,134 @@ +from vyper.venom.analysis.analysis import IRAnalysesCache +from vyper.venom.basicblock import IRLiteral, IRVariable +from vyper.venom.context import IRContext +from vyper.venom.passes.load_elimination import LoadElimination + + +def test_simple_load_elimination(): + ctx = IRContext() + fn = ctx.create_function("test") + + bb = fn.get_basic_block() + + ptr = IRLiteral(11) + bb.append_instruction("mload", ptr) + bb.append_instruction("mload", ptr) + bb.append_instruction("stop") + + ac = IRAnalysesCache(fn) + LoadElimination(ac, fn).run_pass() + + assert len([inst for inst in bb.instructions if inst.opcode == "mload"]) == 1 + + inst0, inst1, inst2 = bb.instructions + + assert inst0.opcode == "mload" + assert inst1.opcode == "store" + assert inst1.operands[0] == inst0.output + assert inst2.opcode == "stop" + + +def test_equivalent_var_elimination(): + ctx = IRContext() + fn = ctx.create_function("test") + + bb = fn.get_basic_block() + + ptr1 = bb.append_instruction("store", IRLiteral(11)) + ptr2 = bb.append_instruction("store", ptr1) + bb.append_instruction("mload", ptr1) + bb.append_instruction("mload", ptr2) + bb.append_instruction("stop") + + ac = IRAnalysesCache(fn) + LoadElimination(ac, fn).run_pass() + + assert len([inst for inst in bb.instructions if inst.opcode == "mload"]) == 1 + + inst0, inst1, inst2, inst3, inst4 = bb.instructions + + assert inst0.opcode == "store" + assert inst1.opcode == "store" + assert inst2.opcode == "mload" + assert inst2.operands[0] == inst0.output + assert inst3.opcode == "store" + assert inst3.operands[0] == inst2.output + assert inst4.opcode == "stop" + + +def test_elimination_barrier(): + ctx = IRContext() + fn = ctx.create_function("test") + + bb = fn.get_basic_block() + + ptr = IRLiteral(11) + bb.append_instruction("mload", ptr) + + arbitrary = IRVariable("%100") + # fence, writes to memory + bb.append_instruction("staticcall", arbitrary, arbitrary, arbitrary, arbitrary) + + bb.append_instruction("mload", ptr) + bb.append_instruction("stop") + + ac = IRAnalysesCache(fn) + + instructions = bb.instructions.copy() + LoadElimination(ac, fn).run_pass() + + assert instructions == bb.instructions # no change + + +def test_store_load_elimination(): + ctx = IRContext() + fn = ctx.create_function("test") + + bb = fn.get_basic_block() + + val = IRLiteral(55) + ptr1 = bb.append_instruction("store", IRLiteral(11)) + ptr2 = bb.append_instruction("store", ptr1) + bb.append_instruction("mstore", val, ptr1) + bb.append_instruction("mload", ptr2) + bb.append_instruction("stop") + + ac = IRAnalysesCache(fn) + LoadElimination(ac, fn).run_pass() + + assert len([inst for inst in bb.instructions if inst.opcode == "mload"]) == 0 + + inst0, inst1, inst2, inst3, inst4 = bb.instructions + + assert inst0.opcode == "store" + assert inst1.opcode == "store" + assert inst2.opcode == "mstore" + assert inst3.opcode == "store" + assert inst3.operands[0] == inst2.operands[0] + assert inst4.opcode == "stop" + + +def test_store_load_barrier(): + ctx = IRContext() + fn = ctx.create_function("test") + + bb = fn.get_basic_block() + + val = IRLiteral(55) + ptr1 = bb.append_instruction("store", IRLiteral(11)) + ptr2 = bb.append_instruction("store", ptr1) + bb.append_instruction("mstore", val, ptr1) + + arbitrary = IRVariable("%100") + # fence, writes to memory + bb.append_instruction("staticcall", arbitrary, arbitrary, arbitrary, arbitrary) + + bb.append_instruction("mload", ptr2) + bb.append_instruction("stop") + + ac = IRAnalysesCache(fn) + + instructions = bb.instructions.copy() + LoadElimination(ac, fn).run_pass() + + assert instructions == bb.instructions diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py index bf3115b4dd..06d316d34b 100644 --- a/vyper/venom/__init__.py +++ b/vyper/venom/__init__.py @@ -14,6 +14,7 @@ AlgebraicOptimizationPass, BranchOptimizationPass, DFTPass, + LoadElimination, MakeSSA, Mem2Var, RemoveUnusedVariablesPass, @@ -52,8 +53,11 @@ def _run_passes(fn: IRFunction, optimize: OptimizationLevel) -> None: Mem2Var(ac, fn).run_pass() MakeSSA(ac, fn).run_pass() SCCP(ac, fn).run_pass() + StoreElimination(ac, fn).run_pass() SimplifyCFGPass(ac, fn).run_pass() + LoadElimination(ac, fn).run_pass() + AlgebraicOptimizationPass(ac, fn).run_pass() # NOTE: MakeSSA is after algebraic optimization it currently produces # smaller code by adding some redundant phi nodes. This is not a diff --git a/vyper/venom/passes/__init__.py b/vyper/venom/passes/__init__.py index 83098234c1..f2ce0045cb 100644 --- a/vyper/venom/passes/__init__.py +++ b/vyper/venom/passes/__init__.py @@ -1,6 +1,7 @@ from .algebraic_optimization import AlgebraicOptimizationPass from .branch_optimization import BranchOptimizationPass from .dft import DFTPass +from .load_elimination import LoadElimination from .make_ssa import MakeSSA from .mem2var import Mem2Var from .normalization import NormalizationPass diff --git a/vyper/venom/passes/load_elimination.py b/vyper/venom/passes/load_elimination.py new file mode 100644 index 0000000000..c9d6f8c07a --- /dev/null +++ b/vyper/venom/passes/load_elimination.py @@ -0,0 +1,73 @@ +from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis, VarEquivalenceAnalysis +from vyper.venom.passes.base_pass import IRPass + + +class LoadElimination(IRPass): + """ + Eliminate sloads, mloads and tloads + """ + + # should this be renamed to EffectsElimination? + + def run_pass(self): + self.equivalence = self.analyses_cache.request_analysis(VarEquivalenceAnalysis) + + for bb in self.function.get_basic_blocks(): + self._process_bb(bb) + + self.analyses_cache.invalidate_analysis(LivenessAnalysis) + self.analyses_cache.invalidate_analysis(DFGAnalysis) + + def equivalent(self, op1, op2): + return op1 == op2 or self.equivalence.equivalent(op1, op2) + + def _process_bb(self, bb): + transient = () + storage = () + memory = () + + for inst in bb.instructions: + if "memory" in inst.get_write_effects(): + memory = () + if "storage" in inst.get_write_effects(): + storage = () + if "transient" in inst.get_write_effects(): + transient = () + + if inst.opcode == "mstore": + # mstore [val, ptr] + memory = (inst.operands[1], inst.operands[0]) + if inst.opcode == "sstore": + storage = (inst.operands[1], inst.operands[0]) + if inst.opcode == "tstore": + transient = (inst.operands[1], inst.operands[0]) + + if inst.opcode == "mload": + prev_memory = memory + memory = (inst.operands[0], inst.output) + if not prev_memory: + continue + if not self.equivalent(inst.operands[0], prev_memory[0]): + continue + inst.opcode = "store" + inst.operands = [prev_memory[1]] + + if inst.opcode == "sload": + prev_storage = storage + storage = (inst.operands[0], inst.output) + if not prev_storage: + continue + if not self.equivalent(inst.operands[0], prev_storage[0]): + continue + inst.opcode = "store" + inst.operands = [prev_storage[1]] + + if inst.opcode == "tload": + prev_transient = transient + transient = (inst.operands[0], inst.output) + if not prev_transient: + continue + if not self.equivalent(inst.operands[0], prev_transient[0]): + continue + inst.opcode = "store" + inst.operands = [prev_transient[1]]