Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for nested record patterns and record patterns in switches #417

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,42 @@ private static boolean checkBranch(Exprent exprent, IfStatement statement, State
Exprent left = first.getAllExprents().get(0);
Exprent right = first.getAllExprents().get(1);

// Right side needs to be a cast function
// If it's not, we might be a record pattern match
if (!(right instanceof FunctionExprent)) {
return identifyRecordPatternMatch(statement, branch, iof, (AssignmentExprent) first);
boolean result = findPatternMatchingInstanceof(left, right, source, target, branch, iof, head);

if (head.getExprents() != null && !head.getExprents().isEmpty() && head.getExprents().get(0) instanceof AssignmentExprent assignment) {
// If it's an assignement, get both sides
left = assignment.getAllExprents().get(0);
right = assignment.getAllExprents().get(1);

// Right side needs to be a cast function
// If it's not, we might be a record pattern match
if (!(right instanceof FunctionExprent)) {
result |= identifyIfRecordPatternMatch(statement, branch, iof, assignment);
}
}

if (((FunctionExprent) right).getFuncType() != FunctionType.CAST) {
statement.setPatternMatched(true);

BasicBlockStatement before = statement.getBasichead();
if (before.getExprents() != null && before.getExprents().size() > 0) {
Exprent last = before.getExprents().get(before.getExprents().size() - 1);
if (last instanceof AssignmentExprent && source instanceof VarExprent) {
Exprent stored = last.getAllExprents().get(0);
Exprent method = last.getAllExprents().get(1);
VarExprent checked = (VarExprent) source;
if ((!(method instanceof FunctionExprent) || ((FunctionExprent) method).getFuncType() != FunctionType.CAST)
&& checked.equals(stored) && !checked.isVarReferenced(root, (VarExprent) stored)) {
iof.getLstOperands().set(0, last.getAllExprents().get(1));
before.getExprents().remove(before.getExprents().size() - 1);
}
}
}

return result;
}

private static boolean findPatternMatchingInstanceof(Exprent left, Exprent right, Exprent source, Exprent target, Statement branch, FunctionExprent iof, Statement head) {
if (!(right instanceof FunctionExprent function) || function.getFuncType() != FunctionType.CAST) {
return false;
}

Expand Down Expand Up @@ -173,24 +202,6 @@ private static boolean checkBranch(Exprent exprent, IfStatement statement, State
if (storeType.isGeneric()) {
iof.getLstOperands().set(1, new ConstExprent(storeType, null, iof.getLstOperands().get(1).bytecode));
}

statement.setPatternMatched(true);

BasicBlockStatement before = statement.getBasichead();
if (before.getExprents() != null && before.getExprents().size() > 0) {
Exprent last = before.getExprents().get(before.getExprents().size() - 1);
if (last instanceof AssignmentExprent && source instanceof VarExprent) {
Exprent stored = last.getAllExprents().get(0);
Exprent method = last.getAllExprents().get(1);
VarExprent checked = (VarExprent) source;
if ((!(method instanceof FunctionExprent) || ((FunctionExprent) method).getFuncType() != FunctionType.CAST)
&& checked.equals(stored) && !checked.isVarReferenced(root, (VarExprent) stored)) {
iof.getLstOperands().set(0, last.getAllExprents().get(1));
before.getExprents().remove(before.getExprents().size() - 1);
}
}
}

return true;
}

Expand Down Expand Up @@ -236,7 +247,7 @@ private static void findVarsInPredecessors(List<VarVersionPair> vvs, Statement r
}
}

private static boolean identifyRecordPatternMatch(IfStatement stat, Statement branch, FunctionExprent instOf, AssignmentExprent head) {
private static boolean identifyIfRecordPatternMatch(IfStatement stat, Statement branch, FunctionExprent instOf, AssignmentExprent head) {
if (!stat.getTopParent().mt.getBytecodeVersion().hasRecordPatternMatching()) {
return false;
}
Expand All @@ -249,19 +260,72 @@ private static boolean identifyRecordPatternMatch(IfStatement stat, Statement br
// if (v instanceof MyType) {
// var10000 = v;
// ...
if (!instOf.getLstOperands().get(0).equals(headRight)) {
// or:
//
// if (v instanceof MyType var10000) {
// ...

if (!(instOf.getLstOperands().size() > 2 ? instOf.getLstOperands().get(2) : instOf.getLstOperands().get(0)).equals(headRight)) {
return false;
}

VarType type = instOf.getLstOperands().get(1).getExprType();

PatternExprent exprent = identifyRecordPatternMatch(stat, branch, headRight, type, false);
if (exprent == null) {
return false;
}

if (instOf.getLstOperands().size() > 2) {
instOf.getLstOperands().set(2, exprent);
} else {
instOf.getLstOperands().add(2, exprent);
}

stat.setPatternMatched(true);
return true;
}

public static PatternExprent identifyRecordPatternMatch(Statement parent, Statement branch, Exprent storeVariable, VarType type, boolean simulate) {
Statement original = branch;

StructClass cl = DecompilerContext.getStructContext().getClass(type.value);
if (cl == null || cl.getRecordComponents() == null) {
return false; // No idea what class, or not a record!

if (cl == null || cl.getRecordComponents() == null || cl.getRecordComponents().isEmpty()) {
return null;
}

List<StructRecordComponent> comp = cl.getRecordComponents();
// Ending exprents we may want to remove
Map<BasicBlockStatement, Exprent> remove = new HashMap<>();
// Statements that ought to be destroyed as a result of creating the pattern
List<Statement> toDestroy = new ArrayList<>();

PatternData pattern = getChildPattern(cl, storeVariable, type, branch, 1, toDestroy, remove);
if (pattern == null) {
return null;
}
branch = pattern.stat;

if (simulate) {
return pattern.exp;
}

if (original != branch) {
parent.replaceStatement(original, branch);
}

for (Statement st : toDestroy) {
st.replaceWithEmpty();
}

for (Map.Entry<BasicBlockStatement, Exprent> e : remove.entrySet()) {
e.getKey().getExprents().remove(e.getValue());
}

return pattern.exp;
}

private static PatternData getChildPattern(StructClass cl, Exprent storeVariable, VarType type, Statement branch, int stIdx, List<Statement> toDestroy, Map<BasicBlockStatement, Exprent> remove) {
// Iteratively go through the sequence to see if it extracts from the record

// The general strategy is to identify an "extracting try" [1] for each record component.
Expand All @@ -280,19 +344,20 @@ private static boolean identifyRecordPatternMatch(IfStatement stat, Statement br
// realVar = exVar;
// <stackVar> = <originalVar>;

int stIdx = 1;

// Map which variable refers to which part of the record
Map<StructRecordComponent, VarExprent> vars = new LinkedHashMap<>();
if (cl == null || cl.getRecordComponents() == null) {
return null; // No idea what class, or not a record!
}

// Ending exprents we may want to remove
Map<BasicBlockStatement, Exprent> remove = new HashMap<>();
// Statements that ought to be destroyed as a result of creating the pattern
List<Statement> toDestroy = new ArrayList<>();
record PatternStore(StructRecordComponent component, StructClass cl, VarType type, VarExprent store) {
}
List<PatternStore> patternStores = new ArrayList<>();
List<StructRecordComponent> comp = cl.getRecordComponents();

// Map which variable refers to which part of the record
Map<StructRecordComponent, Exprent> vars = new LinkedHashMap<>();
for (StructRecordComponent c : comp) {
if (branch.getStats().size() <= stIdx) {
return false;
return null;
}

Statement next = branch.getStats().get(stIdx);
Expand All @@ -317,7 +382,7 @@ private static boolean identifyRecordPatternMatch(IfStatement stat, Statement br
}

if (foundVar == null) {
return false;
return null;
}

toDestroy.add(next);
Expand All @@ -340,7 +405,7 @@ private static boolean identifyRecordPatternMatch(IfStatement stat, Statement br
// If that's the only other thing in the statement, then we can destroy it!
boolean destroyed = false;
if (next.getExprents().size() == 2) {
if (next.getExprents().get(1) instanceof AssignmentExprent nAssign && nAssign.getRight().equals(headRight)) {
if (next.getExprents().get(1) instanceof AssignmentExprent nAssign && nAssign.getRight().equals(storeVariable)) {
toDestroy.add(next);

destroyed = true;
Expand All @@ -353,11 +418,21 @@ private static boolean identifyRecordPatternMatch(IfStatement stat, Statement br
}
}
}
} else {
} else if (next instanceof IfStatement ifSt && ifSt.iftype == IfStatement.IFTYPE_IF && ifSt.getHeadexprent().getCondition() instanceof FunctionExprent func) {
// Is the next statement an if with an instanceof inside? It might be a type-improving if. Search inside it too.
if (next instanceof IfStatement ifSt && ifSt.iftype == IfStatement.IFTYPE_IF
&& ifSt.getHeadexprent().getCondition() instanceof FunctionExprent func && func.getFuncType() == FunctionType.INSTANCEOF) {
FunctionExprent function = null;
boolean found = false;
boolean inverted = false;
if (func.getFuncType() == FunctionType.INSTANCEOF) {
found = true;
function = func;
} else if (func.getFuncType() == FunctionType.BOOL_NOT && func.getLstOperands().get(0) instanceof FunctionExprent inner && inner.getFuncType() == FunctionType.INSTANCEOF) {
found = true;
inverted = true;
function = inner;
}

if (found) {
// "<stackVar> = <originalVar>;" idiom
// Ensure this is the right idiom be fore we mark it for destruction.
if (branch.getBasichead().getExprents().size() == 1) {
Expand All @@ -367,8 +442,20 @@ private static boolean identifyRecordPatternMatch(IfStatement stat, Statement br
}
}

branch = ifSt.getIfstat();
stIdx = 0;
Exprent store = function.getLstOperands().size() > 2 ? function.getLstOperands().get(2) : function.getLstOperands().get(0);
if (store instanceof VarExprent variable) {
patternStores.add(new PatternStore(c, DecompilerContext.getStructContext().getClass(variable.getExprType().value), variable.getExprType(), variable));
vars.put(c, variable);
ok = true;
}

if (inverted) {
stIdx++;
toDestroy.add(ifSt);
} else {
branch = ifSt.getIfstat();
stIdx = 0;
}
}
}

Expand All @@ -381,26 +468,30 @@ private static boolean identifyRecordPatternMatch(IfStatement stat, Statement br
stIdx++;
}
} else {
return false;
return null;
}
}

PatternExprent pattern = new PatternExprent(PatternExprent.recordData(cl), type, new ArrayList<>(vars.values()));

instOf.getLstOperands().add(2, pattern);
stat.setPatternMatched(true);

for (Statement st : toDestroy) {
st.replaceWithEmpty();
}

for (Map.Entry<BasicBlockStatement, Exprent> e : remove.entrySet()) {
e.getKey().getExprents().remove(e.getValue());
// Check for any nested record patterns
for (PatternStore patternStore : patternStores) {
List<Statement> tmpToDestroy = new ArrayList<>();
Map<BasicBlockStatement, Exprent> tmpRemove = new HashMap<>();
PatternData patternData = getChildPattern(patternStore.cl, patternStore.store, patternStore.type, branch, stIdx, tmpToDestroy, tmpRemove);
if (patternData != null) {
vars.put(patternStore.component, patternData.exp);
branch = patternData.stat;
stIdx = patternData.index;
toDestroy.addAll(tmpToDestroy);
remove.putAll(tmpRemove);
}
}

return true;
PatternExprent pattern = new PatternExprent(PatternExprent.recordData(cl), type, new ArrayList<>(vars.values()));
return new PatternData(pattern, branch, stIdx);
}

private record PatternData(PatternExprent exp, Statement stat, int index) {}

public static boolean isStatementMatchThrow(Statement st) {
if (st instanceof BasicBlockStatement && st.getExprents().size() == 1) {
// throw ...
Expand Down
Loading
Loading