diff --git a/Makefile b/Makefile index d7e2a73a..eab71919 100644 --- a/Makefile +++ b/Makefile @@ -1,29 +1,5 @@ -UNAME := $(shell uname) - -ifeq ($(UNAME), Darwin) - OS_MESSAGE := "MacOS detected" - CMAKE_SETUP := "which cmake || brew install cmake" - PROTOBUF_SETUP := "which protoc || brew install protobuf" - OPENSSL_SETUP := "which openssl || brew install openssl" -else ifeq ($(UNAME), Linux) - OS_MESSAGE := "Linux detected" - CMAKE_SETUP := "which cmake || sudo apt-get install --yes build-essential cmake" - PROTOBUF_SETUP := "which protoc || sudo apt-get install --yes protobuf-compiler" - OPENSSL_SETUP := "which openssl || sudo apt-get install --yes libssl-dev" -else - OS_MESSAGE := "Unsupported OS; please install rust, cmake, protobuf, and openssl manually" - CMAKE_SETUP := "" - PROTOBUF_SETUP := "" - OPENSSL_SETUP := "" -endif - setup: - @echo "${OS_MESSAGE}: installing..." - $(shell "${CMAKE_SETUP}") - $(shell "${PROTOBUF_SETUP}") - $(shell "${OPENSSL_SETUP}") - which cargo || curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y - which maturin || pip install maturin[patchelf] + @./setup.sh publish: maturin publish diff --git a/docs/mixer.md b/docs/mixer.md index afa49fd6..6a20add0 100644 --- a/docs/mixer.md +++ b/docs/mixer.md @@ -25,7 +25,7 @@ The following parameters are supported either via CLI (e.g. `dolma mix --paramet |`streams[].span_replacement`|No| A list of objects specifying spans of text to be replaced. | |`streams[].span_replacement[].span`|No| A json-path expression for an attribute that contains an array of spans. Each span should be list of length three: `[start, end, score]`. | |`streams[].span_replacement[].min_score`|No| If the span score is less than this value, the span will not be replaced. | -|`streams[].span_replacement[].replacement`|No| The text that should be inserted in place of the span. Use `{}` to represent the original text. | +|`streams[].span_replacement[].replacement`|No| The text that should be inserted in place of the span. Use `{}` to represent the original text. Field selection from the document is also supported by prefixing a jq selector with `$`. Note: Escape a leading $ if you do not with to use jq selector pattern. | |`work_dir.input`|No| Path to a local scratch directory where temporary input files can be placed. If not provided, Dolma will make one for you and delete it upon completion. | |`work_dir.output`|No| Path to a local scratch directory where temporary output files can be placed. If not provided, Dolma will make one for you and delete it upon completion. | |`processes`|No| Number of processes to use for mixing. By default 1 process is used. | diff --git a/pyproject.toml b/pyproject.toml index 13ff1ff7..cc48c430 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dolma" -version = "1.0.12" +version = "1.0.13" description = "Data filters" license = { text = "Apache-2.0" } readme = "README.md" diff --git a/python/dolma/cli/mixer.py b/python/dolma/cli/mixer.py index 8be8a0d9..943d7f74 100644 --- a/python/dolma/cli/mixer.py +++ b/python/dolma/cli/mixer.py @@ -40,7 +40,7 @@ class SpanReplacementConfig: default=None, help="Maximum score for the span to be replaced. Either min_score or max_score must be specified.", ) - replacement: str = field(default="", help="Replacement for the span") + replacement: str = field(default="", help="Replacement config for the span(s).") syntax: str = field( default="jsonpath", help="Syntax to use for filter expressions. Currently only JSONPath is supported. Defaults to JSONPath.", diff --git a/setup.sh b/setup.sh new file mode 100755 index 00000000..0b4212a7 --- /dev/null +++ b/setup.sh @@ -0,0 +1,40 @@ +#!/bin/bash +set -e + +UNAME="$(uname)" +PLATFORM="$(uname -m)" + +if [[ $UNAME == "Darwin" ]]; then + echo "MacOS detected..." + which cmake || brew install cmake + which protoc || brew install protobuf + which openssl || brew install openssl +elif [[ $UNAME == "Linux" ]]; then + echo "Linux detected..." + which cmake || sudo apt-get install --yes build-essential cmake + which protoc || sudo apt-get install --yes protobuf-compiler + which openssl || sudo apt-get install --yes libssl-dev +else + echo "Unsupported OS; please install rust, cmake, protobuf, maturin and openssl manually!" + exit 1 +fi + +which cargo || curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + +if [[ $PLATFORM == "x86_64" ]]; then + echo "x86_64 detected..." + which maturin || pip install maturin[patchelf] +fi + +if [[ $PLATFORM = "aarch64" ]]; then + echo "aarch64 detected..." + which maturin || pip install maturin +fi + +if [[ $PLATFORM = arm* ]]; then + echo "arm detected..." + which maturin || pip install maturin +else + echo "Unsupported platform; please install maturin manually" + exit 0 +fi diff --git a/src/shard.rs b/src/shard.rs index 526f4f4f..d5992caa 100644 --- a/src/shard.rs +++ b/src/shard.rs @@ -399,9 +399,10 @@ impl Shard { ); new_text.push_str(&replacement_text); } + data["text"] = Value::String(new_text); } - // } + for f in self.discard_fields.iter().flatten() { data.as_object_mut().unwrap().remove(f); } @@ -467,7 +468,7 @@ impl Shard { } pub mod shard_config { - use crate::filters::Selector; + use crate::filters::{JqSelector, Selector}; use jsonpath_rust::JsonPathFinder; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -543,7 +544,34 @@ pub mod shard_config { selector: Selector, min_score: f64, max_score: f64, - replacement: String, + replacement: Replacement, + } + + pub enum Replacement { + Selectors(JqSelector), + String(String), + } + + impl Replacement { + pub fn new(string: &str) -> Result { + // Note: Users should escape leading $ in replacement strings + if string.starts_with("$") { + // Strip leading $ and create a selector + let selector = JqSelector::new(&string[1..])?; + Ok(Replacement::Selectors(selector)) + } else { + Ok(Replacement::String(string.to_string())) + } + } + + pub fn get(&self, json: &Value) -> Result { + match self { + Replacement::Selectors(selector) => { + Ok(serde_json::from_value(selector.select(json)?.to_owned()).unwrap()) + } + Replacement::String(s) => Ok(s.clone()), + } + } } impl SpanReplacer { @@ -553,7 +581,7 @@ pub mod shard_config { selector: Selector::new(&config).unwrap(), min_score: config.min_score.unwrap_or(f64::NEG_INFINITY), max_score: config.max_score.unwrap_or(f64::INFINITY), - replacement: config.replacement.clone(), + replacement: Replacement::new(&config.replacement).unwrap(), } } @@ -575,7 +603,7 @@ pub mod shard_config { let replacement = SpanReplacement { start: start as usize, end: end as usize, - replacement: self.replacement.clone(), + replacement: self.replacement.get(json).unwrap(), }; Some(replacement) } else {