fix: CQDG-00 fix top node filtering
adipaul1981 committed Dec 8, 2023
1 parent 95d5b82 commit 3e76c06
Showing 4 changed files with 58 additions and 45 deletions.
23 changes: 7 additions & 16 deletions src/main/scala/bio/ferlab/HPOMain.scala
@@ -2,9 +2,8 @@ package bio.ferlab
 
 import bio.ferlab.config.Config
 import bio.ferlab.ontology.{ICDTerm, OntologyTerm}
-import bio.ferlab.transform.DownloadTransformer.filterOntologiesForTopNode
 import bio.ferlab.transform.{DownloadTransformer, WriteJson, WriteParquet}
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{SaveMode, SparkSession}
 import pureconfig.ConfigReader.Result
 import pureconfig._
 import pureconfig.generic.auto._
@@ -39,12 +38,14 @@ object HPOMain extends App {
     WriteJson.toJson(resultICD10)(outputDir)
   } else {
     val fileBuffer = Source.fromURL(inputOboFileUrl)
-    val result = generateTermsWithAncestors(fileBuffer, topNode)
+    val result = generateTermsWithAncestors(fileBuffer)
 
-    WriteParquet.toParquet(result)(outputDir)
+    val filteredDf = WriteParquet.filterForTopNode(result, topNode)
+
+    filteredDf.write.mode(SaveMode.Overwrite).parquet(outputDir)
   }
 
-  def generateTermsWithAncestors(fileBuffer: BufferedSource, topNode: Option[String]) = {
+  def generateTermsWithAncestors(fileBuffer: BufferedSource) = {
     val dT: Seq[OntologyTerm] = DownloadTransformer.downloadOntologyData(fileBuffer)
 
     val mapDT = dT map (d => d.id -> d) toMap
@@ -55,17 +56,7 @@ def generateTermsWithAncestors(fileBuffer: BufferedSource, topNode: Option[Strin
 
     val ontologyWithParents = DownloadTransformer.transformOntologyData(dTwAncestorsParents)
 
-    val excludedParentsIds = topNode.flatMap(node => ontologyWithParents
-      .find{ case(term, _) => term.id == node }
-      .map{ case(_, parents) => parents.map(_.id) })
-
-    val ontologyWithParentsFiltered = (excludedParentsIds, topNode) match {
-      case (Some(parentsIds), Some(node)) => filterOntologiesForTopNode(ontologyWithParents, node, parentsIds)
-
-      case _ => ontologyWithParents
-    }
-
-    ontologyWithParentsFiltered.map {
+    ontologyWithParents.map {
       case (k, v) if allParents.contains(k.id) => k -> (v, false)
       case (k, v) => k -> (v, true)
     }
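The net effect in this file: top-node filtering no longer happens while the ancestor map is built; it is applied to the resulting DataFrame just before the parquet write. A minimal sketch of the new flow (a paraphrase of the diff above, not verbatim repository code; inputOboFileUrl, topNode and outputDir come from the surrounding App, and the Boolean in the result tuple appears to flag leaf terms, i.e. ids absent from allParents):

  val fileBuffer = Source.fromURL(inputOboFileUrl)
  // Map[OntologyTerm, (Set[OntologyTerm], Boolean)]: term -> (ancestors, isLeaf?)
  val result = generateTermsWithAncestors(fileBuffer)
  // Branch filtering is now a DataFrame operation, driven by the optional topNode
  val filteredDf = WriteParquet.filterForTopNode(result, topNode)
  filteredDf.write.mode(SaveMode.Overwrite).parquet(outputDir)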
12 changes: 0 additions & 12 deletions src/main/scala/bio/ferlab/transform/DownloadTransformer.scala
@@ -75,18 +75,6 @@ object DownloadTransformer {
     })
   }
 
-  def filterOntologiesForTopNode(ontologyWithParents: Map[OntologyTerm, Set[OntologyTerm]], desiredTopNode: String, unwantedParentsIds: Set[String]):
-  Map[OntologyTerm, Set[OntologyTerm]] = {
-
-    val topOntologyTerm = ontologyWithParents
-      .find(t => t._1.id == desiredTopNode)
-      .map{ case(term, _) => (term.copy(parents = Seq.empty[OntologyTerm]), Set.empty[OntologyTerm]) }
-
-    ontologyWithParents
-      .filter { case(_, parents) => parents.map(_.id).contains(desiredTopNode) }
-      .map{ case(term, parents) => (term, parents.filter( r => !unwantedParentsIds.contains(r.id))) } ++ topOntologyTerm
-  }
-
   def getAllParentPath(term: OntologyTerm, originalTerm: OntologyTerm, data: Map[String, OntologyTerm], list: Set[OntologyTerm], cumulativeList: mutable.Map[OntologyTerm, Set[OntologyTerm]], allParents: Set[String]): mutable.Map[OntologyTerm, Set[OntologyTerm]] = {
     term.parents.foreach(p => {
       val parentTerm = data(p.id)
35 changes: 32 additions & 3 deletions src/main/scala/bio/ferlab/transform/WriteParquet.scala
@@ -2,20 +2,49 @@ package bio.ferlab.transform
 
 import bio.ferlab.ontology.OntologyTerm
 import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
 
 
 object WriteParquet {
 
-  def toParquet(data: Map[OntologyTerm, (Set[OntologyTerm], Boolean)])(outputDir: String)(implicit spark: SparkSession): Unit = {
+  def filterForTopNode(data: Map[OntologyTerm, (Set[OntologyTerm], Boolean)], excludedNode: Option[String])
+                      (implicit spark: SparkSession): DataFrame = {
     import spark.implicits._
-    data.map{ case(k, v) =>
+
+    val df = data.map { case (k, v) =>
       OntologyTermOutput(
         k.id,
         k.name,
         k.parents.map(_.toString),
         v._1.map(i => BasicOntologyTermOutput(i.id, i.name, i.parents.map(_.toString))).toSeq,
         v._2
-      )}.toSeq.toDF().write.mode(SaveMode.Overwrite).parquet(outputDir)
+      )
+    }.toSeq.toDF()
+
+    val excludedParents = excludedNode match {
+      case Some(node) => data.find { case (term, _) => term.id == node }.map { case (_, parents) => parents._1.map(_.id) }
+      case None => None
+    }
+
+    (excludedParents, excludedNode) match {
+      case (Some(targetParents), Some(node)) => df
+        // filter out all terms not part of the target root
+        .where(array_contains(col("ancestors")("id"), node))
+        .withColumn("ancestors_exp", explode(col("ancestors")))
+        // filter out parents of the target term
+        .filter(!col("ancestors_exp")("id").isin(targetParents.toSeq: _*))
+        // make sure the target term in ancestors does not contain any parents (it is the top root)
+        .withColumn("ancestors_exp_f", when(col("ancestors_exp")("id").equalTo(node),
+          struct(col("ancestors_exp")("id") as "id",
+            col("ancestors_exp")("name") as "name",
+            array().cast("array<string>") as "parents"))
+          .otherwise(col("ancestors_exp"))
+        )
+        .groupBy("id", "name", "parents")
+        .agg(collect_list(col("ancestors_exp_f")) as "ancestors")
+
+      case _ => df
+    }
   }
 }
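Taken together, filterForTopNode keeps only the rows whose ancestors contain the requested node, drops ancestor entries that are parents of that node, and rewrites the node's own ancestor entry with an empty parents array so that it acts as the new root. A minimal usage sketch (the fixture path, the HP:22 node id and the expected ids are borrowed from HPOMainSpec below; the local SparkSession wiring is an assumption):

import bio.ferlab.HPOMain
import bio.ferlab.transform.WriteParquet
import org.apache.spark.sql.SparkSession
import scala.io.Source

implicit val spark: SparkSession = SparkSession.builder().master("local").getOrCreate()

val terms = HPOMain.generateTermsWithAncestors(Source.fromFile("src/test/scala/resources/hp.obo"))

// None exports the whole ontology; Some("HP:22") keeps only the HP:22 branch.
val branch = WriteParquet.filterForTopNode(terms, Some("HP:22"))
branch.select("id").show() // HP:33, HP:34, HP:43, HP:44 per the spec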

33 changes: 19 additions & 14 deletions src/test/scala/HPOMainSpec.scala
@@ -1,5 +1,7 @@
 import bio.ferlab.HPOMain
-import org.apache.spark.sql.SparkSession
+import bio.ferlab.transform.WriteParquet
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.{SparkSession, functions}
 import org.scalatest.flatspec.AnyFlatSpec
 import org.scalatest.matchers.should.Matchers

@@ -8,7 +10,6 @@ import scala.io.Source
 
 class HPOMainSpec extends AnyFlatSpec with Matchers {
 
-
   val study = "STU0000001"
   val release = "1"
 
@@ -19,28 +20,32 @@ class HPOMainSpec extends AnyFlatSpec with Matchers {
     .master("local")
     .getOrCreate()
 
+  import spark.implicits._
+
   spark.sparkContext.setLogLevel("ERROR")
 
   "generateTermsWithAncestors" should "return all the terms" in {
     val file = Source.fromFile("src/test/scala/resources/hp.obo")
-    val result = HPOMain.generateTermsWithAncestors(file, None)
+    val hpoAncestors = HPOMain.generateTermsWithAncestors(file)
+    val filteredDf = WriteParquet.filterForTopNode(hpoAncestors, None)
+    val result = filteredDf.select(functions.col("id")).as[String].collect()
 
-    result.filter(t => t._1.id.startsWith("HP:2")).keySet.map(_.id) shouldEqual Set("HP:21", "HP:22")
-    result.filter(t => t._1.id.startsWith("HP:3")).keySet.map(_.id) shouldEqual Set("HP:31","HP:32","HP:33","HP:34")
-    result.filter(t => t._1.id.startsWith("HP:4")).keySet.map(_.id) shouldEqual Set("HP:41","HP:42","HP:43","HP:44")
+    result should contain theSameElementsAs Set("HP:21", "HP:22", "HP:31", "HP:32", "HP:33", "HP:34", "HP:41", "HP:42", "HP:43", "HP:44")
   }
 
   "generateTermsWithAncestors" should "return only desired branch" in {
     val file = Source.fromFile("src/test/scala/resources/hp.obo")
-    val result = HPOMain.generateTermsWithAncestors(file, Some("HP:22"))
+    val hpoAncestors = HPOMain.generateTermsWithAncestors(file)
+    val filteredDf = WriteParquet.filterForTopNode(hpoAncestors, Some("HP:22"))
+    val result = filteredDf.select(functions.col("id")).as[String].collect()
 
+    val test = filteredDf.select(col("id"), col("ancestors")("id")).as[(String, Seq[String])].collect
+
+    result should contain theSameElementsAs Set("HP:33", "HP:34", "HP:43", "HP:44")
+
-    result.filter(t => t._1.id.startsWith("HP:2")).keySet.map(_.id) should contain theSameElementsAs Set("HP:22")
-    result.filter(t => t._1.id.startsWith("HP:3")).keySet.map(_.id) should contain theSameElementsAs Set("HP:33","HP:34")
-    result.filter(t => t._1.id.startsWith("HP:4")).keySet.map(_.id) should contain theSameElementsAs Set("HP:43","HP:44")
+    val hp4 = test.find(_._1 == "HP:44").get
 
-    val testLeaf = result.find(t => t._1.id.equals("HP:44")).get._2._1
-    val testLeafParentsIds = testLeaf.map(_.id)
-    // Should not have HP:1 as one of its parents
-    testLeafParentsIds.toList should contain theSameElementsAs Seq("HP:22", "HP:34")
+    hp4._2 should contain theSameElementsAs Seq("HP:22", "HP:34")
   }
 }
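The assertions pin down the relevant shape of the hp.obo fixture; a hedged reading (the fixture itself is not shown, so edges outside the HP:22 branch are inferred):

// HP:1 (old root)
// ├── HP:21 — carries HP:31, HP:32, HP:41, HP:42 (exact edges not asserted)
// └── HP:22 ── HP:33, HP:34 ── HP:43, HP:44 (HP:44 under HP:34, per its ancestors)
//
// filterForTopNode(_, Some("HP:22")) keeps exactly the HP:22 subtree, and
// HP:44's ancestors collapse to {HP:22, HP:34}: the old root HP:1 (which the
// deleted comment singled out) is pruned and HP:22 becomes a parentless root.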
