Skip to content

Commit

Permalink
search size
Browse files Browse the repository at this point in the history
  • Loading branch information
pdasigi committed Oct 25, 2024
1 parent b3c19d0 commit a7b608f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
4 changes: 3 additions & 1 deletion decontamination/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,6 @@ You can specify a `--match_threshold` here as well, and the behavior is similar

### Decontamination

If you need to remove instances from the training sets that match any of the test instances, just pass a `--decontaminate` option to `search.py`. The output directory will contain one decontaminated `jsonl` file per training dataset. If you pass a `--match_threshold`, only those train instances that have a matching score greater than the threshold with *any* of the test instances will be removed.
If you need to remove instances from the training sets that match any of the test instances, just pass a `--decontaminate` option to `search.py`. The output directory will contain one decontaminated `jsonl` file per training dataset. If you pass a `--match_threshold`, only those train instances that have a matching score greater than the threshold with *any* of the test instances will be removed.

Note that Elasticsearch returns only a limited number of hits per search. You can raise this limit by passing a larger value to `--search_size` (default is 100). Setting it to a larger number (e.g. 10000) is a good idea if you are decontaminating datasets. Because Elasticsearch does not necessarily retrieve every document that matches, decontamination is not guaranteed to remove all the matching training instances. You can always check for contamination again after decontaminating a dataset to see how effective it was.
19 changes: 11 additions & 8 deletions decontamination/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_ngram_mapping(string: str, n: int):
return mapping


def exact_match(es, index_name, query_dataset, fields):
def exact_match(es, index_name, query_dataset, fields, search_size):
match_scores = []
output_data = []
matching_train_indices = set()
Expand All @@ -48,6 +48,7 @@ def exact_match(es, index_name, query_dataset, fields):
index=index_name,
search_type="query_then_fetch",
rest_total_hits_as_int=True,
size=search_size,
query={
"bool": {
"filter": [
Expand Down Expand Up @@ -79,7 +80,7 @@ def exact_match(es, index_name, query_dataset, fields):
return match_scores, output_data, matching_train_indices


def ngram_match(es, index_name, query_dataset, fields, ngram_size):
def ngram_match(es, index_name, query_dataset, fields, ngram_size, search_size):
match_scores = []
output_data = []
# Maps ids in the HF dataset ("original_id") to the list of matching scores with test instances, so that we can compute the max score for
Expand Down Expand Up @@ -107,6 +108,7 @@ def ngram_match(es, index_name, query_dataset, fields, ngram_size):
index=index_name,
search_type="query_then_fetch",
rest_total_hits_as_int=True,
size=search_size,
query={
"bool": {
"filter": [
Expand All @@ -132,7 +134,7 @@ def ngram_match(es, index_name, query_dataset, fields, ngram_size):

if matching_doc_ids:
# Averaging the match scores of training documents over all query strings.
aggregated_match_scores = {doc_id: sum([x.get(doc_id, 0.0) for x in query_string_match_scores]) / len(query_string_match_scores) for doc_id in matching_doc_ids}
aggregated_match_scores = {doc_id: sum([x.get(doc_id, 0.0) for x in query_string_match_scores]) / len(query_strings) for doc_id in matching_doc_ids}
sorted_matches = sorted(aggregated_match_scores.items(), key=lambda x: x[1], reverse=True)
match_info = []
for doc_id, score in sorted_matches:
Expand Down Expand Up @@ -160,7 +162,7 @@ def ngram_match(es, index_name, query_dataset, fields, ngram_size):
return match_scores, output_data, max_train_match_scores


def vector_match(es, index_name, query_dataset, fields, model, tokenizer, max_batch_tokens):
def vector_match(es, index_name, query_dataset, fields, model, tokenizer, max_batch_tokens, search_size):
match_scores = []
output_data = []
# Maps ids in the HF dataset ("original_id") to the list of matching scores with test instances, so that we can compute the max score for
Expand All @@ -179,7 +181,7 @@ def vector_match(es, index_name, query_dataset, fields, model, tokenizer, max_ba
for query, embedding in zip(batch_inputs, question_embeddings):
sem_search = es.search(
index=index_name,
knn={"field": "vector", "query_vector": embedding.cpu().numpy(), "k": 10, "num_candidates": 100},
knn={"field": "vector", "query_vector": embedding.cpu().numpy(), "k": search_size, "num_candidates": 10 * search_size},
)
results = sem_search["hits"]["hits"][:5]
match_scores.append(results[0]["_score"])
Expand Down Expand Up @@ -219,6 +221,7 @@ def main():
parser.add_argument("--train_dataset_names", type=str, nargs="+")
parser.add_argument("--dataset_mixer_config", type=str, help="Path to a train config file in yml format with a `dataset_mixer` field.")
parser.add_argument("--index_type", type=str, choices=["text", "vector"], default="text")
parser.add_argument("--search_size", type=int, default=100, help="Number of search results to retrieve from elasticsearch. Increasing this makes decontamination more accurate and search slower.")
parser.add_argument("--ngram_size", type=int, help="If `index_type` is `text`, will use n-gram matches of this size if this field is set. Default is full match.")
parser.add_argument("--match_threshold", type=float, help="For ngram and vector matching, transform match scores to 0/1 based on this threshold.")
parser.add_argument("--model", type=str, default="nvidia/NV-Embed-v2")
Expand Down Expand Up @@ -291,17 +294,17 @@ def main():

if args.index_type == "text":
if args.ngram_size is None:
match_scores, output_data, train_indices = exact_match(es, index_name, query_dataset, fields)
match_scores, output_data, train_indices = exact_match(es, index_name, query_dataset, fields, args.search_size)
contaminated_ids.update(train_indices)
else:
match_scores, output_data, train_indices_with_scores = ngram_match(es, index_name, query_dataset, fields, args.ngram_size)
match_scores, output_data, train_indices_with_scores = ngram_match(es, index_name, query_dataset, fields, args.ngram_size, args.search_size)
if args.match_threshold is not None:
match_scores = [1 if score > args.match_threshold else 0 for score in match_scores]
contaminated_ids.update([_id for _id, score in train_indices_with_scores.items() if score > args.match_threshold])

else:
model, tokenizer = prepare_embedding_model(args.model)
match_scores, output_data, train_indices_with_scores = vector_match(es, index_name, query_dataset, fields, model, tokenizer, args.max_batch_tokens)
match_scores, output_data, train_indices_with_scores = vector_match(es, index_name, query_dataset, fields, model, tokenizer, args.max_batch_tokens, args.search_size)
if args.match_threshold is not None:
match_scores = [1 if score > args.match_threshold else 0 for score in match_scores]
contaminated_ids.update([_id for _id, score in train_indices_with_scores.items() if score > args.match_threshold])
Expand Down

0 comments on commit a7b608f

Please sign in to comment.