Skip to content

Commit

Permalink
Polymorphic span replacement (#200)
Browse files Browse the repository at this point in the history
* WIP: Polymorphic span replacement

* fmt

* Handle escaped chars in replacement

* Fixes

* Docs and environment cleanup

* Style

* Check for any arm platform
  • Loading branch information
undfined authored Sep 20, 2024
1 parent 621a6f4 commit 2436a6c
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 33 deletions.
26 changes: 1 addition & 25 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,29 +1,5 @@
UNAME := $(shell uname)

ifeq ($(UNAME), Darwin)
OS_MESSAGE := "MacOS detected"
CMAKE_SETUP := "which cmake || brew install cmake"
PROTOBUF_SETUP := "which protoc || brew install protobuf"
OPENSSL_SETUP := "which openssl || brew install openssl"
else ifeq ($(UNAME), Linux)
OS_MESSAGE := "Linux detected"
CMAKE_SETUP := "which cmake || sudo apt-get install --yes build-essential cmake"
PROTOBUF_SETUP := "which protoc || sudo apt-get install --yes protobuf-compiler"
OPENSSL_SETUP := "which openssl || sudo apt-get install --yes libssl-dev"
else
OS_MESSAGE := "Unsupported OS; please install rust, cmake, protobuf, and openssl manually"
CMAKE_SETUP := ""
PROTOBUF_SETUP := ""
OPENSSL_SETUP := ""
endif

setup:
@echo "${OS_MESSAGE}: installing..."
$(shell "${CMAKE_SETUP}")
$(shell "${PROTOBUF_SETUP}")
$(shell "${OPENSSL_SETUP}")
which cargo || curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
which maturin || pip install maturin[patchelf]
@./setup.sh

publish:
maturin publish
Expand Down
2 changes: 1 addition & 1 deletion docs/mixer.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ The following parameters are supported either via CLI (e.g. `dolma mix --paramet
|`streams[].span_replacement`|No| A list of objects specifying spans of text to be replaced. |
|`streams[].span_replacement[].span`|No| A json-path expression for an attribute that contains an array of spans. Each span should be a list of length three: `[start, end, score]`. |
|`streams[].span_replacement[].min_score`|No| If the span score is less than this value, the span will not be replaced. |
|`streams[].span_replacement[].replacement`|No| The text that should be inserted in place of the span. Use `{}` to represent the original text. |
|`streams[].span_replacement[].replacement`|No| The text that should be inserted in place of the span. Use `{}` to represent the original text. Field selection from the document is also supported by prefixing a jq selector with `$`. Note: escape a leading `$` if you do not wish to use the jq selector pattern. |
|`work_dir.input`|No| Path to a local scratch directory where temporary input files can be placed. If not provided, Dolma will make one for you and delete it upon completion. |
|`work_dir.output`|No| Path to a local scratch directory where temporary output files can be placed. If not provided, Dolma will make one for you and delete it upon completion. |
|`processes`|No| Number of processes to use for mixing. By default 1 process is used. |
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "dolma"
version = "1.0.12"
version = "1.0.13"
description = "Data filters"
license = { text = "Apache-2.0" }
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion python/dolma/cli/mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class SpanReplacementConfig:
default=None,
help="Maximum score for the span to be replaced. Either min_score or max_score must be specified.",
)
replacement: str = field(default="", help="Replacement for the span")
replacement: str = field(default="", help="Replacement config for the span(s).")
syntax: str = field(
default="jsonpath",
help="Syntax to use for filter expressions. Currently only JSONPath is supported. Defaults to JSONPath.",
Expand Down
40 changes: 40 additions & 0 deletions setup.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/bin/bash
# Installs the build prerequisites for this project: cmake, protobuf, openssl,
# the Rust toolchain (via rustup) and maturin.
#
# Each tool is installed only when `which` cannot already find it on PATH.
# Exit status:
#   1 - the operating system is neither MacOS (Darwin) nor Linux
#   0 - otherwise, including an unrecognised CPU architecture, in which case
#       maturin has to be installed manually
set -e

UNAME="$(uname)"
PLATFORM="$(uname -m)"

# System packages: Homebrew on MacOS, apt-get on Linux.
if [[ $UNAME == "Darwin" ]]; then
    echo "MacOS detected..."
    which cmake || brew install cmake
    which protoc || brew install protobuf
    which openssl || brew install openssl
elif [[ $UNAME == "Linux" ]]; then
    echo "Linux detected..."
    which cmake || sudo apt-get install --yes build-essential cmake
    which protoc || sudo apt-get install --yes protobuf-compiler
    which openssl || sudo apt-get install --yes libssl-dev
else
    echo "Unsupported OS; please install rust, cmake, protobuf, maturin and openssl manually!"
    exit 1
fi

# Rust toolchain; `-y` accepts the rustup defaults without prompting.
which cargo || curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y

# maturin: exactly one branch must run. The previous version used three
# independent `if` statements with the `else` attached only to the last one,
# so x86_64 and aarch64 hosts installed maturin and were then still told the
# platform was unsupported.
if [[ $PLATFORM == "x86_64" ]]; then
    echo "x86_64 detected..."
    # Quoted so the shell cannot treat [patchelf] as a glob pattern.
    which maturin || pip install 'maturin[patchelf]'
elif [[ $PLATFORM == "aarch64" ]]; then
    echo "aarch64 detected..."
    which maturin || pip install maturin
elif [[ $PLATFORM == arm* ]]; then
    # Unquoted right-hand side: matched as a pattern, so any arm variant hits.
    echo "arm detected..."
    which maturin || pip install maturin
else
    echo "Unsupported platform; please install maturin manually"
    exit 0
fi
38 changes: 33 additions & 5 deletions src/shard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,10 @@ impl Shard {
);
new_text.push_str(&replacement_text);
}

data["text"] = Value::String(new_text);
}
// }

for f in self.discard_fields.iter().flatten() {
data.as_object_mut().unwrap().remove(f);
}
Expand Down Expand Up @@ -467,7 +468,7 @@ impl Shard {
}

pub mod shard_config {
use crate::filters::Selector;
use crate::filters::{JqSelector, Selector};
use jsonpath_rust::JsonPathFinder;
use serde::{Deserialize, Serialize};
use serde_json::Value;
Expand Down Expand Up @@ -543,7 +544,34 @@ pub mod shard_config {
selector: Selector,
min_score: f64,
max_score: f64,
replacement: String,
replacement: Replacement,
}

/// Source of the text inserted in place of a matched span.
///
/// Built by `Replacement::new`: a config string starting with `$` becomes
/// `Selectors`, any other string becomes `String`.
pub enum Replacement {
/// A jq selector that `Replacement::get` runs against the document.
Selectors(JqSelector),
/// Fixed replacement text, returned as-is (cloned) by `Replacement::get`.
String(String),
}

impl Replacement {
    /// Builds a `Replacement` from the configured replacement string.
    ///
    /// A string that starts with `$` is treated as a jq selector: the leading
    /// `$` is removed and the remainder is compiled with `JqSelector::new`.
    /// Any other string is stored verbatim as fixed replacement text.
    ///
    /// # Errors
    /// Propagates the error returned by `JqSelector::new` when the selector
    /// expression is rejected.
    pub fn new(string: &str) -> Result<Replacement, IoError> {
        // Note: Users should escape leading $ in replacement strings
        match string.strip_prefix('$') {
            // `expression` is the input with the leading `$` removed.
            Some(expression) => Ok(Replacement::Selectors(JqSelector::new(expression)?)),
            None => Ok(Replacement::String(string.to_string())),
        }
    }

    /// Resolves the replacement text for the document `json`.
    ///
    /// For `Selectors`, the selector is evaluated against `json` and the
    /// result is deserialized as a `String`. For `String`, a clone of the
    /// stored text is returned and `json` is not consulted.
    ///
    /// # Errors
    /// Propagates selector evaluation errors. A selected value that cannot be
    /// deserialized as a string is also returned as an error; it previously
    /// caused a panic via `unwrap()`.
    pub fn get(&self, json: &Value) -> Result<String, IoError> {
        match self {
            Replacement::Selectors(selector) => {
                // NOTE(review): `?` on the serde_json error relies on IoError
                // being std::io::Error (serde_json provides that `From` impl);
                // confirm against the module's `use` lines.
                let text: String = serde_json::from_value(selector.select(json)?.to_owned())?;
                Ok(text)
            }
            Replacement::String(s) => Ok(s.clone()),
        }
    }
}

impl SpanReplacer {
Expand All @@ -553,7 +581,7 @@ pub mod shard_config {
selector: Selector::new(&config).unwrap(),
min_score: config.min_score.unwrap_or(f64::NEG_INFINITY),
max_score: config.max_score.unwrap_or(f64::INFINITY),
replacement: config.replacement.clone(),
replacement: Replacement::new(&config.replacement).unwrap(),
}
}

Expand All @@ -575,7 +603,7 @@ pub mod shard_config {
let replacement = SpanReplacement {
start: start as usize,
end: end as usize,
replacement: self.replacement.clone(),
replacement: self.replacement.get(json).unwrap(),
};
Some(replacement)
} else {
Expand Down

0 comments on commit 2436a6c

Please sign in to comment.