Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CalculateCurveFitAction and support for 3PL curve fit #6053

Open
wants to merge 13 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions api/src/org/labkey/api/data/statistics/CurveFit.java
Original file line number Diff line number Diff line change
Expand Up @@ -146,4 +146,83 @@ default JSONObject toJSON()
* @return The integrated area under the curve.
*/
double calculateAUC(StatsService.AUCType type, double startX, double endX) throws FitFailedException;

cnathe marked this conversation as resolved.
Show resolved Hide resolved
/**
 * Computes the residual sum of squares (RSS) of the curve fit against the data points
 * (https://en.wikipedia.org/wiki/Residual_sum_of_squares): for every data point, the difference between
 * the fitted value at the point's X and the point's Y is squared, and the squares are summed.
 * @param parameters the parameters to use for the given calculated curve fit
 * @return the calculated residual sum of squares
 */
default double residualSumSquares(P parameters)
{
    double total = 0;
    for (DoublePoint observation : getData())
    {
        // residual = value predicted by the fit minus the observed value
        double residual = fitCurve(observation.getX(), parameters) - observation.getY();
        total += Math.pow(residual, 2);
    }
    return total;
}

/**
 * Computes the root mean square error (RMSE), also called the root mean square deviation (RMSD),
 * of the curve fit (https://en.wikipedia.org/wiki/Root_mean_square_deviation): the square root of
 * the residual sum of squares divided by the number of data points.
 * @param parameters the parameters to use for the given calculated curve fit
 * @return the calculated root mean square error
 */
default double rootMeanSquareError(P parameters)
{
    double meanSquaredResidual = residualSumSquares(parameters) / getData().length;
    return Math.sqrt(meanSquaredResidual);
}

/**
 * Computes the total sum of squares (TSS) of the data points (https://en.wikipedia.org/wiki/Total_sum_of_squares):
 * the sum of the squared differences between each point's Y value and the mean Y value.
 * This value is used in the R^2 calculation.
 * @return the calculated total sum of squares
 */
default double totalSumSquares()
{
    // first pass: mean of the observed Y values
    double ySum = 0;
    for (DoublePoint observation : getData())
    {
        ySum += observation.getY();
    }
    final double yMean = ySum / getData().length;

    // second pass: accumulate the squared deviations from that mean
    double total = 0;
    for (DoublePoint observation : getData())
    {
        double deviation = observation.getY() - yMean;
        total += Math.pow(deviation, 2);
    }
    return total;
}

/**
 * Computes the R^2 value (coefficient of determination) of the curve fit
 * (https://en.wikipedia.org/wiki/Coefficient_of_determination) as one minus the ratio of
 * residualSumSquares() to totalSumSquares().
 * @param parameters the parameters to use for the given calculated curve fit
 * @return the calculated R^2 value
 */
default double rSquared(P parameters)
{
    double unexplained = residualSumSquares(parameters);
    double total = totalSumSquares();
    return 1 - unexplained / total;
}

/**
 * Adjusted R^2 value for the curve fit, without an explicit parameter count.
 * This default implementation always returns {@code Double.NaN}; it exists so that each applicable
 * curve fit can override it and supply the correct p value for the degrees of freedom.
 * @param parameters the parameters to use for the given calculated curve fit
 * @return {@code Double.NaN} unless overridden by the curve fit implementation
 */
default double adjustedRSquared(P parameters)
{
    return Double.NaN;
}

/**
 * Calculates the adjusted R^2 value for the curve fit (https://en.wikipedia.org/wiki/Coefficient_of_determination)
 * as {@code 1 - (1 - R^2) * (n - 1) / (n - p - 1)}, where n is the number of data points.
 * @param parameters the parameters to use for the given calculated curve fit
 * @param p the number of fitted parameters, used to compute the residual degrees of freedom (n - p - 1)
 * @return the calculated adjusted R^2 value, or {@code Double.NaN} when it is not defined because there are
 *         not more than p + 1 data points
 */
default double adjustedRSquared(P parameters, int p)
{
    int n = getData().length;
    int degreesOfFreedom = n - p - 1;

    // With zero or negative residual degrees of freedom the formula divides by zero or yields a
    // meaningless value, so report "not available" the same way the single-argument overload does.
    if (degreesOfFreedom <= 0)
        return Double.NaN;

    return 1 - (1 - rSquared(parameters)) * (n - 1) / degreesOfFreedom;
}
}
3 changes: 3 additions & 0 deletions api/src/org/labkey/api/data/statistics/StatsService.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ static void setInstance(StatsService impl)

enum CurveFitType
{
THREE_PARAMETER("Three Parameter", "3pl"),
FOUR_PARAMETER("Four Parameter", "4pl"),
FIVE_PARAMETER("Five Parameter", "5pl"),
THREE_PARAMETER_ALT("3 Parameter", "3param"),
FOUR_PARAMETER_SIMPLEX("4 Parameter", "4param"),
POLYNOMIAL("Polynomial", "poly"),
LINEAR("Linear", "linear"),
Expand Down Expand Up @@ -120,4 +122,5 @@ public String getLabel()
* @param data an array of {@code DoublePoint} instances to initialize the curve fit with.
*/
CurveFit getCurveFit(CurveFitType type, DoublePoint[] data);
CurveFit getCurveFit(CurveFitType type, DoublePoint[] data, @Nullable Double asymptoteMin, @Nullable Double asymptoteMax);
}
9 changes: 1 addition & 8 deletions core/src/org/labkey/core/statistics/DefaultCurveFit.java
Original file line number Diff line number Diff line change
Expand Up @@ -318,14 +318,7 @@ public double getFitError() throws FitFailedException

protected double calculateFitError(P parameters)
{
double deviationValue = 0;
for (DoublePoint point : getData())
{
double expectedValue = point.getY();
double foundValue = fitCurve(point.getX(), parameters);
deviationValue += Math.pow(foundValue - expectedValue, 2);
}
return Math.sqrt(deviationValue / getData().length);
return rootMeanSquareError(parameters);
}

@Override
Expand Down
37 changes: 8 additions & 29 deletions core/src/org/labkey/core/statistics/FourParameterSimplex.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class FourParameterSimplex extends ParameterCurveFit implements Multivari

public FourParameterSimplex(DoublePoint[] data)
{
super(data, StatsService.CurveFitType.FOUR_PARAMETER_SIMPLEX);
super(data, StatsService.CurveFitType.FOUR_PARAMETER_SIMPLEX, null, null);
}

@Override
Expand Down Expand Up @@ -106,26 +106,13 @@ private void optimize(MultivariateOptimizer optimizer, double[] start)

protected double calculateFitError(SigmoidalParameters parameters)
{
double deviationValue = 0;
double varianceValue = 0;
double total = 0;

// find the mean
for (DoublePoint point : getData())
{
total += point.getY();
}
double mean = total / getData().length;

for (DoublePoint point : getData())
{
double expectedValue = point.getY();
double foundValue = fitCurve(point.getX(), parameters);
deviationValue += Math.pow(foundValue - expectedValue, 2);
varianceValue += Math.pow(expectedValue - mean, 2);
}
return rSquared(parameters);
}

return 1 - deviationValue / varianceValue;
@Override
public double adjustedRSquared(SigmoidalParameters parameters)
{
return adjustedRSquared(parameters, 4);
}

@Override
Expand All @@ -139,15 +126,7 @@ public double value(double[] point)
private double sumSquares(double[] params)
{
SigmoidalParameters parameters = createParams(params);
double sumSq = 0;
for (DoublePoint point : getData())
{
double expectedValue = point.getY();
double foundValue = fitCurve(point.getX(), parameters);

sumSq += Math.pow(foundValue - expectedValue, 2);
}
return sumSq;
return residualSumSquares(parameters);
}

private SigmoidalParameters createParams(double[] params)
Expand Down
62 changes: 56 additions & 6 deletions core/src/org/labkey/core/statistics/ParameterCurveFit.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.labkey.core.statistics;

import org.jetbrains.annotations.Nullable;
import org.json.JSONObject;
import org.labkey.api.data.statistics.CurveFit;
import org.labkey.api.data.statistics.DoublePoint;
Expand All @@ -32,7 +33,9 @@
*/
public class ParameterCurveFit extends DefaultCurveFit<ParameterCurveFit.SigmoidalParameters> implements CurveFit<ParameterCurveFit.SigmoidalParameters>
{
private final StatsService.CurveFitType _fitType;
protected final StatsService.CurveFitType _fitType;
private Double _asymptoteMin;
private Double _asymptoteMax;

public static class SigmoidalParameters implements CurveFit.Parameters, Cloneable
{
Expand Down Expand Up @@ -112,10 +115,12 @@ public static SigmoidalParameters fromJSON(JSONObject json)
}
}

public ParameterCurveFit(DoublePoint[] data, StatsService.CurveFitType fitType)
public ParameterCurveFit(DoublePoint[] data, StatsService.CurveFitType fitType, @Nullable Double asymptoteMin, @Nullable Double asymptoteMax)
{
super(data);
_fitType = fitType;
setAsymptoteMin(asymptoteMin);
setAsymptoteMax(asymptoteMax);
}

@Override
Expand All @@ -124,6 +129,26 @@ public StatsService.CurveFitType getType()
return _fitType;
}

/**
 * @return the minimum asymptote configured for this fit, or null if none was set
 */
public Double getAsymptoteMin()
{
    return this._asymptoteMin;
}

/**
 * Stores the minimum asymptote for this fit.
 * @param asymptoteMin the minimum asymptote value; may be null
 */
public void setAsymptoteMin(Double asymptoteMin)
{
    this._asymptoteMin = asymptoteMin;
}

/**
 * @return the maximum asymptote configured for this fit, or null if none was set
 */
public Double getAsymptoteMax()
{
    return this._asymptoteMax;
}

/**
 * Stores the maximum asymptote for this fit.
 * @param asymptoteMax the maximum asymptote value; may be null
 */
public void setAsymptoteMax(Double asymptoteMax)
{
    this._asymptoteMax = asymptoteMax;
}

@Override
protected SigmoidalParameters computeParameters()
{
Expand Down Expand Up @@ -196,27 +221,50 @@ public double solveForX(double y)
}
}

/**
 * Calculates the adjusted R^2 value, choosing the parameter count from the fit type:
 * 3 for the three parameter types, 4 for FOUR_PARAMETER and 5 for FIVE_PARAMETER.
 * @param parameters the parameters to use for the given calculated curve fit
 * @return the adjusted R^2 value computed with the parameter count for this fit type
 * @throws IllegalStateException if the fit type is not one of the types listed above
 */
@Override
public double adjustedRSquared(SigmoidalParameters parameters)
{
    final int parameterCount = switch (_fitType)
    {
        case THREE_PARAMETER, THREE_PARAMETER_ALT -> 3;
        case FOUR_PARAMETER -> 4;
        case FIVE_PARAMETER -> 5;
        default -> throw new IllegalStateException("Unsupported curve fit type: " + _fitType.name());
    };
    return adjustedRSquared(parameters, parameterCount);
}

/**
 * @return true if this fit's type is THREE_PARAMETER or THREE_PARAMETER_ALT
 */
private boolean is3Parameter()
{
    if (_fitType == StatsService.CurveFitType.THREE_PARAMETER)
        return true;
    return _fitType == StatsService.CurveFitType.THREE_PARAMETER_ALT;
}

/**
 * @return true if this fit's type is FOUR_PARAMETER
 */
private boolean is4Parameter()
{
    return StatsService.CurveFitType.FOUR_PARAMETER == _fitType;
}

protected SigmoidalParameters calculateFitParameters(double minValue, double maxValue)
{
SigmoidalParameters bestFit = null;
SigmoidalParameters parameters = new SigmoidalParameters();
double step = 10;
if (_fitType == StatsService.CurveFitType.FOUR_PARAMETER)
Double asymptoteDiff = getAsymptoteMax() != null && getAsymptoteMin() != null ? Math.abs(getAsymptoteMax() - getAsymptoteMin()) : null;
double step = asymptoteDiff != null ? Math.min(asymptoteDiff / 100, 10) : 10;
if (is3Parameter() || is4Parameter())
parameters.asymmetry = 1;

// try reasonable variants of max and min, in case there's a better fit. We'll keep going past "reasonable" if
// we haven't found a single bestFit option, but we need to bail out at some point. We currently quit once max
// reaches 200 or min reaches -100, since these values don't seem biologically reasonable.
for (double min = minValue; (bestFit == null || min > 0 - step) && min > (minValue - 100); min -= step )
{
parameters.min = min;
parameters.min = getAsymptoteMin() != null ? getAsymptoteMin() : min;
for (double max = maxValue; (bestFit == null || max <= 100 + step) && max < (maxValue + 100); max += step )
{
double absoluteCutoff = min + (0.5 * (max - min));
double relativeEC50 = getInterpolatedCutoffXValue(absoluteCutoff);
if (!Double.isInfinite(relativeEC50) && !Double.isNaN(relativeEC50))
{
parameters.max = max;
parameters.max = getAsymptoteMax() != null ? getAsymptoteMax() : max;
parameters.inflection = relativeEC50;
for (double slopeRadians = 0; slopeRadians < Math.PI; slopeRadians += Math.PI / 30)
{
Expand All @@ -232,6 +280,8 @@ protected SigmoidalParameters calculateFitParameters(double minValue, double max
bestFit = parameters.copy();
}
break;
case THREE_PARAMETER:
case THREE_PARAMETER_ALT:
case FOUR_PARAMETER:
parameters.asymmetry = 1;
parameters.fitError = calculateFitError(parameters);
Expand Down
Loading