Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/bg-18 ACMG - Adding PP2 #195

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,33 @@ import org.apache.spark.sql.{Column, DataFrame}

object ACMGImplicits {

// Columns that uniquely identify a variant (locus + alleles); reused as join keys across criteria.
val variantColumns = Array("chromosome", "start", "end", "reference", "alternate")

/**
 * Builds the schema resulting from a join on `joinedKeys`: all left-side fields,
 * followed by the right-side fields that are not join keys.
 */
private def getJoinedSchema(leftSchema: StructType, rightSchema: StructType, joinedKeys: Seq[String]): StructType = {
  val rightOnlyFields = rightSchema.fields.filterNot(field => joinedKeys.contains(field.name))
  StructType(leftSchema.fields ++ rightOnlyFields)
}

/**
 * Checks that every DataFrame in `map` contains its required columns.
 *
 * @param map          DataFrame -> (human-readable name, required column names)
 * @param criteriaName criterion label used in the error message
 * @throws IllegalArgumentException (via require) naming the first missing column
 */
private def validateRequiredColumns(map: Map[DataFrame, (String, Array[String])], criteriaName: String = "criteria"): Unit = {
  for {
    (df, (dfName, columns)) <- map
    column <- columns
  } require(
    df.columns.contains(column),
    s"Column `$column` is required in DataFrame $dfName for $criteriaName.")
}

/**
 * inColArray
 * Helper generating a boolean Column for array-typed columns.
 *
 * Designed to be used in a withColumn statement. Yields true when the array in
 * `colName` contains at least one value from `values`.
 *
 * Fix: the original `reduce(_ || _)` throws UnsupportedOperationException on an
 * empty `values` list; `reduceOption(...).getOrElse(lit(false))` returns a
 * constant false Column instead.
 *
 * @return A Column of boolean
 */
val inColArray = (colName: String, values: List[String]) =>
  values
    .map(value => array_contains(col(colName), value))
    .reduceOption(_ || _)
    .getOrElse(lit(false))

implicit class ACMGOperations(df: DataFrame) {

/**
Expand Down Expand Up @@ -46,6 +73,111 @@ object ACMGImplicits {
}
}

/**
 * getBS2 (ACMG criterion BS2)
 * Flags variants observed at a frequency suggesting benignity: enough gnomAD
 * homozygote carriers, or — for dominant-inheritance genes with non-adult
 * onset per Orphanet — enough allele-count carriers.
 *
 * Adds a `BS2` struct column (gnomad_ac, gnomad_hom, is_dominant, score) where
 * `score` is true when gnomad_hom >= threshold, or the gene is dominant and
 * gnomad_ac >= threshold. Variants missing from `frequencies` get null counts
 * and a false score.
 *
 * @param orphanet    Orphanet gene table with `gene_symbol`, `average_age_of_onset`,
 *                    `type_of_inheritance` (array columns).
 * @param frequencies Variant-level table exposing `external_frequencies.gnomad_genomes_3_1_1`.
 * @return The source DataFrame with the `BS2` struct column appended.
 * @throws IllegalArgumentException when a required column is missing from any input.
 */
def getBS2(orphanet: DataFrame, frequencies: DataFrame): DataFrame = {

  val map = Map(
    df -> ("df", Array("symbol") ++ variantColumns),
    orphanet -> ("orphanet", Array("gene_symbol", "average_age_of_onset", "type_of_inheritance")),
    frequencies -> ("frequencies", Array("external_frequencies", "genes_symbol") ++ variantColumns)
  )
  // Fix: this is the BS2 criterion; error messages previously reported "PM2".
  validateRequiredColumns(map, "BS2")

  // Minimum carrier count before a variant is considered "seen in healthy individuals".
  val threshold = 4

  // Onset categories treated as (possibly) adult onset; genes restricted to these are excluded.
  val onsets = List(
    "Adult",
    "Elderly",
    "All ages",
    "No data available")

  val is_dominant_inheritance = List(
    "Autosomal dominant",
    "X-linked dominant",
    "Y-linked",
    "Mitochondrial inheritance")

  // TODO(review): maybe we should prepare this gene-level table offline and join
  // on it, instead of recalculating this DataFrame on each call.
  val orphanetDF = orphanet.select("gene_symbol", "average_age_of_onset", "type_of_inheritance")
    .withColumn("is_adult_onset", inColArray("average_age_of_onset", onsets))
    .filter(col("is_adult_onset") === false)
    .withColumn("is_dominant", inColArray("type_of_inheritance", is_dominant_inheritance))
    .select(
      col("gene_symbol").as("symbol"),
      col("is_dominant"))
    .distinct()

  // One row per (variant, gene) with the gnomAD v3.1.1 counts flattened out.
  val freqDF = frequencies
    .select(
      col("chromosome"),
      col("start"),
      col("end"),
      col("reference"),
      col("alternate"),
      explode(col("genes_symbol")).as("symbol"),
      col("external_frequencies.gnomad_genomes_3_1_1.ac").as("gnomad_ac"),
      col("external_frequencies.gnomad_genomes_3_1_1.hom").as("gnomad_hom"))

  df
    .join(orphanetDF, Seq("symbol"), "leftouter")
    // Genes absent from the filtered Orphanet panel default to non-dominant.
    .na.fill(false, Seq("is_dominant"))
    .join(freqDF, Seq("chromosome", "start", "end", "reference", "alternate", "symbol"), "leftouter")
    .withColumn("BS2", struct(
      col("gnomad_ac"),
      col("gnomad_hom"),
      col("is_dominant"),
      (
        col("gnomad_hom").isNotNull &&
        (
          col("gnomad_hom") >= threshold ||
          (col("is_dominant") && col("gnomad_ac").isNotNull && col("gnomad_ac") >= threshold)
        )
      ).as("score")
    ))
    .drop("gnomad_ac", "gnomad_hom", "is_dominant")

}

/**
 * getPP2 (ACMG criterion PP2)
 * Missense variant in a gene where missense changes are a common mechanism of
 * disease and benign missense variation is rare.
 *
 * Per gene symbol, counts unambiguous ClinVar missense entries as pathogenic or
 * benign; a gene is `is_missense_pathogenic` when n_pathogenic >= 3 and
 * n_pathogenic > 2 * n_benign. Adds a `pp2` struct column
 * (n_benign, n_pathogenic, is_missense_pathogenic, score) where `score` is true
 * for missense variants in such genes.
 *
 * @param clinvar ClinVar DataFrame with `geneinfo`, `mc` and `clin_sig` columns.
 * @return The source DataFrame with the `pp2` struct column appended.
 * @throws IllegalArgumentException when a required column is missing from any input.
 */
def getPP2(clinvar: DataFrame): DataFrame = {

  val map = Map(
    df -> ("df", Array("symbol", "consequences") ++ variantColumns),
    clinvar -> ("clinvar", Array("geneinfo", "mc", "clin_sig"))
  )
  // Fix: this is the PP2 criterion; error messages previously reported "PM2".
  validateRequiredColumns(map, "PP2")

  val pathogenicClinSig = List("Pathogenic", "Likely_pathogenic")
  val benignClinSig = List("Benign", "Likely_benign")

  val clinvarDF = clinvar
    .filter(array_contains(col("mc"), "missense_variant") === true)
    .withColumn("is_pathogenic", inColArray("clin_sig", pathogenicClinSig))
    .withColumn("is_benign", inColArray("clin_sig", benignClinSig))
    // Keep only entries with an unambiguous classification (pathogenic XOR benign).
    .filter((col("is_pathogenic") =!= col("is_benign")) === true)
    // geneinfo is pipe-separated; each entry is SYMBOL or SYMBOL:GeneID.
    .withColumn("symbol", explode(split(col("geneinfo"), "\\|")))
    // Fix: strip the GeneID suffix BEFORE aggregating. The original split after
    // groupBy grouped on SYMBOL:GeneID, which could leave several rows per
    // symbol and fan out the left-outer join below.
    .withColumn("symbol", split(col("symbol"), ":").getItem(0))
    .groupBy("symbol").agg(
      sum(col("is_pathogenic").cast("int")).alias("n_pathogenic"),
      sum(col("is_benign").cast("int")).alias("n_benign")
    )
    .withColumn("is_missense_pathogenic", col("n_pathogenic") >= 3 && col("n_pathogenic") > col("n_benign") * 2)

  // Preserve the original column ordering: df columns first, then new clinvar-derived ones.
  val joinedSchema = getJoinedSchema(df.schema, clinvarDF.schema, Seq("symbol")).fieldNames

  df.join(clinvarDF, Seq("symbol"), "leftouter")
    .select(joinedSchema.head, joinedSchema.tail: _*)
    // Genes with no qualifying ClinVar entries default to zero counts / false.
    .na.fill(0, Seq("n_benign", "n_pathogenic"))
    .na.fill(false, Seq("is_missense_pathogenic"))
    .withColumn("pp2", struct(
      col("n_benign"),
      col("n_pathogenic"),
      col("is_missense_pathogenic"),
      (
        // isNotNull check dropped: the na.fill above guarantees non-null values.
        col("is_missense_pathogenic") === true &&
        array_contains(col("consequences"), "missense_variant")
      ).as("score")
    )).drop("n_pathogenic", "n_benign", "is_missense_pathogenic")
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ class ACMGImplicitsSpec extends AnyFlatSpec with WithSparkSession with Matchers

spark.sparkContext.setLogLevel("ERROR")

// Shared variant-locus schema reused by the BS2/PP2 fixtures below.
val variantSchema = StructType(Seq(
  StructField("chromosome", StringType, nullable = true),
  StructField("start", IntegerType, nullable = true),
  StructField("end", IntegerType, nullable = true),
  StructField("reference", StringType, nullable = true),
  StructField("alternate", StringType, nullable = true)))

def ba1Fixture = {
new {
val querySchema = new StructType()
Expand Down Expand Up @@ -65,4 +72,160 @@ class ACMGImplicitsSpec extends AnyFlatSpec with WithSparkSession with Matchers
f.result.collect() should contain theSameElementsAs f.resultData
}

// Fixture for the BS2 tests: minimal Orphanet, frequency and query DataFrames.
def bs2Fixture = {
new {
val orphanetSchema = new StructType()
.add("gene_symbol", StringType, false)
.add("average_age_of_onset", new ArrayType(StringType, true), true)
.add("type_of_inheritance", new ArrayType(StringType, true), true)

// gene1: recessive, early onset; gene2: dominant, early onset;
// gene3: dominant but "All ages" onset (excluded by getBS2's onset filter).
val orphanetData = Seq(
Row("gene1", Array("Neonatal", "Antenatal"), Array("Autosomal recessive")),
Row("gene2", Array("Neonatal"), Array("Autosomal dominant")),
Row("gene3", Array("All ages"), Array("Autosomal dominant")),
)

val orphanetDF = spark.createDataFrame(spark.sparkContext.parallelize(orphanetData), orphanetSchema)

// Variant columns + per-gene gnomAD v3.1.1 counts nested as
// external_frequencies.gnomad_genomes_3_1_1.{ac, hom}.
val freqSchema = variantSchema
.add("genes_symbol", new ArrayType(StringType, true), true)
.add("external_frequencies", new StructType()
.add("gnomad_genomes_3_1_1", new StructType()
.add("ac", IntegerType, true)
.add("hom", IntegerType, true)))

val freqData = Seq(Row("1", 1, 2, "A", "C", Array("gene1"), Row(Row(10, 0))))

val freqDF = spark.createDataFrame(spark.sparkContext.parallelize(freqData), freqSchema)

// Query rows: variant columns plus the gene symbol under test.
val querySchema = variantSchema
.add("symbol", StringType, true)

val queryData = Seq(Row("1", 1, 2, "A", "C", "gene1"))

val queryDF = spark.createDataFrame(spark.sparkContext.parallelize(queryData), querySchema)
}
}

"get_BS2" should "throw IllegalArgumentException if `average_age_of_onset` column is absent from the Orphanet DataFrame" in {
  val f = bs2Fixture

  // Removing a required column must make input validation fail fast.
  val incompleteOrphanet = f.orphanetDF.drop("average_age_of_onset")
  an[IllegalArgumentException] should be thrownBy f.queryDF.getBS2(incompleteOrphanet, f.freqDF)
}

it should "return observed homozygote alleles as BS2 true" in {
  val f = bs2Fixture

  // First variant has 4 homozygotes (at threshold), second has none.
  val frequencyRows = Seq(
    Row("1", 1, 2, "A", "C", Array("gene1"), Row(Row(25, 4))),
    Row("1", 3, 4, "T", "C", Array("gene1"), Row(Row(10, 0))))
  val frequencyDF = spark.createDataFrame(spark.sparkContext.parallelize(frequencyRows), f.freqSchema)

  // Third query variant is absent from frequencies -> null counts expected.
  val queryRows = Seq(
    Row("1", 1, 2, "A", "C", "gene1"),
    Row("1", 3, 4, "T", "C", "gene1"),
    Row("1", 5, 6, "G", "C", "gene1"))
  val query = spark.createDataFrame(spark.sparkContext.parallelize(queryRows), f.querySchema)

  // Expected BS2 struct: (gnomad_ac, gnomad_hom, is_dominant, score).
  val expected = Seq(
    Row("1", 1, 2, "A", "C", "gene1", Row(25, 4, false, true)),
    Row("1", 3, 4, "T", "C", "gene1", Row(10, 0, false, false)),
    Row("1", 5, 6, "G", "C", "gene1", Row(null, null, false, false)))

  query.getBS2(f.orphanetDF, frequencyDF).collect() should contain theSameElementsAs expected
}

it should "return observed heterozygote allele in recessive non-adult onset diseases as BS2 true" in {
  val f = bs2Fixture

  // Identical counts (ac=15, hom=0) across three genes that differ only in
  // inheritance mode / onset category.
  val frequencyRows = Seq(
    Row("1", 1, 2, "A", "C", Array("gene1"), Row(Row(15, 0))),
    Row("1", 1, 2, "A", "C", Array("gene2"), Row(Row(15, 0))),
    Row("1", 1, 2, "A", "C", Array("gene3"), Row(Row(15, 0))))
  val frequencyDF = spark.createDataFrame(spark.sparkContext.parallelize(frequencyRows), f.freqSchema)

  val queryRows = Seq(
    Row("1", 1, 2, "A", "C", "gene1"),
    Row("1", 1, 2, "A", "C", "gene2"),
    Row("1", 1, 2, "A", "C", "gene3"))
  val query = spark.createDataFrame(spark.sparkContext.parallelize(queryRows), f.querySchema)

  // Only gene2 (dominant, non-adult onset) scores true on allele count alone.
  val expected = Seq(
    Row("1", 1, 2, "A", "C", "gene1", Row(15, 0, false, false)),
    Row("1", 1, 2, "A", "C", "gene2", Row(15, 0, true, true)),
    Row("1", 1, 2, "A", "C", "gene3", Row(15, 0, false, false)))

  query.getBS2(f.orphanetDF, frequencyDF).collect() should contain theSameElementsAs expected
}

// Fixture for the PP2 tests: ClinVar entries, query variants and expected output.
def pp2Fixture = {
new {
val clinvarSchema = new StructType()
.add("geneinfo", StringType, true)
.add("clin_sig", new ArrayType(StringType, true), true)
.add("mc", new ArrayType(StringType, true), true)

// gene1: 3 pathogenic + 1 benign missense entries (qualifies as
// missense-pathogenic: n_pathogenic >= 3 and > 2 * n_benign); the
// upstream_gene_variant row is dropped by getPP2's `mc` filter.
// gene2: a single benign missense entry (does not qualify).
val clinvarData = Seq(
Row("gene1", Array("Pathogenic"), Array("missense_variant")),
Row("gene1", Array("Pathogenic"), Array("missense_variant")),
Row("gene1", Array("Pathogenic"), Array("missense_variant")),
Row("gene1", Array("Benign"), Array("missense_variant")),
Row("gene1", Array("Pathogenic"), Array("upstream_gene_variant")),
Row("gene2", Array("Benign"), Array("missense_variant")),
)

val clinvarDF = spark.createDataFrame(spark.sparkContext.parallelize(clinvarData), clinvarSchema)

val querySchema = variantSchema
.add("symbol", StringType, true)
.add("consequences", new ArrayType(StringType, true), true)

val queryData = Seq(
Row("1", 1, 2, "A", "C", "gene1", Array("missense_variant")),
Row("1", 1, 2, "A", "T", "gene1", Array("upstream_gene_variant")),
Row("1", 1, 2, "A", "T", "gene2", Array("missense_variant"))
)

val queryDF = spark.createDataFrame(spark.sparkContext.parallelize(queryData), querySchema)

// NOTE(review): Spark's sum() over IntegerType produces LongType columns;
// confirm these IntegerType count fields match the schema getPP2 actually
// returns, otherwise Row equality in the comparison may fail.
val resultSchema = variantSchema
.add("symbol", StringType, true)
.add("consequences", new ArrayType(StringType, true), true)
.add("pp2", new StructType()
.add("n_benign", IntegerType, false)
.add("n_pathogenic", IntegerType, false)
.add("is_missense_pathogenic", BooleanType, false)
.add("score", BooleanType, true), false
)

// Expected pp2 struct: (n_benign, n_pathogenic, is_missense_pathogenic, score).
val resultData = Seq(
Row("1", 1, 2, "A", "C", "gene1", Array("missense_variant"), Row(1, 3, true, true)),
Row("1", 1, 2, "A", "T", "gene1", Array("upstream_gene_variant"), Row(1, 3, true, false)),
Row("1", 1, 2, "A", "T", "gene2", Array("missense_variant"), Row(1, 0, false, false)),
)

val resultDF = spark.createDataFrame(spark.sparkContext.parallelize(resultData), resultSchema)

}
}

"get_PP2" should "throw IllegalArgumentException if `mc` column is absent from the clinvar DataFrame" in {
  val f = pp2Fixture

  // Removing a required column must make input validation fail fast.
  val incompleteClinvar = f.clinvarDF.drop("mc")
  an[IllegalArgumentException] should be thrownBy f.queryDF.getPP2(incompleteClinvar)
}

it should "correctly classify PP2 variants" in {
  val f = pp2Fixture

  // Compare the full annotated output against the fixture's expected rows.
  val actual = f.queryDF.getPP2(f.clinvarDF).collect()
  actual should contain theSameElementsAs f.resultDF.collect()
}


}