Skip to content

Commit

Permalink
Adding rank as default required field (#153)
Browse files Browse the repository at this point in the history
  • Loading branch information
lastmansleeping authored Jan 20, 2022
1 parent 0692022 commit 02e873b
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 26 deletions.
6 changes: 6 additions & 0 deletions docs/source/misc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.1.11] - 2021-01-18

### Changed

- Adding rank feature to serving parse fn by default and removing dependence on required serving_info attribute

## [0.1.10] - 2021-12-29

### Changed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ query_key:
shape: null
serving_info:
name: queryId
required: false
default_value: ""
tfrecord_type: context
rank:
name: rank
Expand All @@ -23,7 +21,6 @@ rank:
shape: null
serving_info:
name: originalRank
required: true
default_value: 0
tfrecord_type: sequence
label:
Expand All @@ -37,8 +34,6 @@ label:
shape: null
serving_info:
name: clicked
required: false
default_value: 0
tfrecord_type: sequence
features:
- name: text_match_score
Expand All @@ -51,8 +46,6 @@ features:
shape: null
serving_info:
name: textMatchScore
required: true
default_value: 0.0
tfrecord_type: sequence
- name: page_views_score
node_name: page_views_score
Expand All @@ -77,8 +70,6 @@ features:
clip_value_max: 1000000.
serving_info:
name: pageViewsScore
required: true
default_value: 0.0
tfrecord_type: sequence
- name: quality_score
node_name: quality_score
Expand All @@ -90,7 +81,6 @@ features:
shape: null
serving_info:
name: qualityScore
required: false
tfrecord_type: sequence
- name: name_match
node_name: name_match
Expand All @@ -103,8 +93,6 @@ features:
shape: null
serving_info:
name: nameMatch
required: true
default_value: 0.0
tfrecord_type: sequence
- name: query_text
node_name: query_text
Expand All @@ -127,8 +115,6 @@ features:
to_lower: true
serving_info:
name: q
required: true
default_value: ""
tfrecord_type: context
- name: domain_id
node_name: domain_id
Expand All @@ -146,8 +132,6 @@ features:
default_value: null
serving_info:
name: domainID
required: true
default_value: 0
tfrecord_type: context
- name: domain_name
node_name: domain_name
Expand All @@ -166,7 +150,5 @@ features:
num_oov_buckets: 1
serving_info:
name: domainName
required: true
default_value: ""
tfrecord_type: context

15 changes: 9 additions & 6 deletions python/ml4ir/base/data/tfrecord_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,12 @@ def get_features_spec(self):
for feature_info in self.feature_config.get_all_features():
serving_info = feature_info["serving_info"]
if not self.required_fields_only or serving_info.get(
"required", feature_info["trainable"]) or feature_info["trainable"]:
"required", feature_info["trainable"]) or feature_info["trainable"]:
feature_name = feature_info["name"]
dtype = feature_info["dtype"]
default_value = self.feature_config.get_default_value(feature_info)
features_spec[feature_name] = io.FixedLenFeature([], dtype, default_value=default_value)
features_spec[feature_name] = io.FixedLenFeature(
[], dtype, default_value=default_value)

return features_spec

Expand Down Expand Up @@ -354,7 +355,7 @@ def get_feature(self, feature_info, extracted_features, sequence_size=0):
default_tensor = self.get_default_tensor(feature_info, sequence_size)

feature_tensor = extracted_features.get(feature_info["name"], default_tensor)

# Adjust shape
feature_tensor = tf.expand_dims(feature_tensor, axis=0)

Expand Down Expand Up @@ -455,8 +456,10 @@ def get_features_spec(self):
if feature_info.get("name") == self.feature_config.get_mask("name"):
continue
serving_info = feature_info["serving_info"]
if not self.required_fields_only or serving_info.get(
"required", feature_info["trainable"]) or feature_info["trainable"]:
if not self.required_fields_only or feature_info["trainable"] or \
(serving_info.get("required", feature_info["trainable"])) or \
(feature_info.get("name") == self.feature_config.get_rank("name")):

feature_name = feature_info["name"]
dtype = feature_info["dtype"]
default_value = self.feature_config.get_default_value(
Expand Down Expand Up @@ -585,7 +588,7 @@ def generate_and_add_mask(self, extracted_features, features_dict):
context_features, sequence_features = extracted_features
if (
self.required_fields_only
and not self.feature_config.get_rank("serving_info")["required"]
and not self.feature_config.get_rank("serving_info").get("required", True)
):
"""
Define dummy mask if the rank field is not a required field for serving
Expand Down
5 changes: 4 additions & 1 deletion python/ml4ir/base/features/feature_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,10 @@ def create_dummy_protobuf(self, num_records=1, required_only=False):
sequence_features = [
f
for f in self.get_sequence_features()
if ((not required_only) or (f["serving_info"].get("required", False)) or f["trainable"])
if ((not required_only) or \
(f["serving_info"].get("required", False)) or \
f["trainable"] or \
(f["name"] == self.get_rank("name")))
]

dummy_query = dict()
Expand Down
2 changes: 1 addition & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def getReadMe():
setup(
name="ml4ir",
packages=find_namespace_packages(include=["ml4ir.*"]),
version="0.1.10",
version="0.1.11",
description="Machine Learning libraries for Information Retrieval",
long_description=getReadMe(),
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 02e873b

Please sign in to comment.