Skip to content

Commit

Permalink
Merge pull request #367 from FederatedAI/develop-1.10.0
Browse files Browse the repository at this point in the history
Develop 1.10.0
  • Loading branch information
zhihuiwan authored Dec 28, 2022
2 parents 6feffaf + 996eab5 commit cee0306
Show file tree
Hide file tree
Showing 15 changed files with 349 additions and 184 deletions.
7 changes: 7 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# Release 1.10.0
## Major Features and Improvements
* Add connection test API
* May configure gRPC message size limit
## Bug Fixes
* Fix module duplication issue in model

# Release 1.9.1
## Bug Fixes
* Fix parameter inheritance when loading non-model modules from ModelLoader
Expand Down
4 changes: 3 additions & 1 deletion conf/component_registry.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
"homo_model_convert": "protobuf.homo_model_convert.homo_model_convert",
"anonymous_generator": "util.anonymous_generator_util.Anonymous",
"data_format": "util.data_format_preprocess.DataFormatPreProcess",
"hetero_model_merge": "protobuf.model_merge.merge_hetero_models.hetero_model_merge"
"hetero_model_merge": "protobuf.model_merge.merge_hetero_models.hetero_model_merge",
"extract_woe_array_dict": "protobuf.model_migrate.binning_model_migrate.extract_woe_array_dict",
"merge_woe_array_dict": "protobuf.model_migrate.binning_model_migrate.merge_woe_array_dict"
}
}
}
99 changes: 51 additions & 48 deletions python/fate_flow/apps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,17 @@
from werkzeug.wrappers.request import Request

from fate_arch.common.base_utils import CustomJSONEncoder

from fate_flow.entity import RetCode
from fate_flow.hook import HookManager
from fate_flow.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters
from fate_flow.settings import (API_VERSION, access_logger, stat_logger, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION)
from fate_flow.utils.api_utils import server_error_response, get_json_result
from fate_flow.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger
from fate_flow.utils.api_utils import get_json_result, server_error_response


__all__ = ['app']


logger = logging.getLogger('flask.app')
for h in access_logger.handlers:
logger.addHandler(h)
Expand All @@ -38,75 +41,75 @@

app = Flask(__name__)
app.url_map.strict_slashes = False
app.errorhandler(Exception)(server_error_response)
app.json_encoder = CustomJSONEncoder
app.errorhandler(Exception)(server_error_response)

pages_dir = [
Path(__file__).parent,
Path(__file__).parent.parent / 'scheduling_apps'
]
pages_path = [j for i in pages_dir for j in i.glob('*_app.py')]
scheduling_url_prefix = []
client_url_prefix = []
for path in pages_path:
page_name = path.stem.rstrip('_app')
module_name = '.'.join(path.parts[path.parts.index('fate_flow'):-1] + (page_name, ))

spec = spec_from_file_location(module_name, path)

def search_pages_path(pages_dir):
return [path for path in pages_dir.glob('*_app.py') if not path.name.startswith('.')]


def register_page(page_path):
page_name = page_path.stem.rstrip('_app')
module_name = '.'.join(page_path.parts[page_path.parts.index('fate_flow'):-1] + (page_name, ))

spec = spec_from_file_location(module_name, page_path)
page = module_from_spec(spec)
page.app = app
page.manager = Blueprint(page_name, module_name)
sys.modules[module_name] = page
spec.loader.exec_module(page)

if not isinstance(page.manager, Blueprint):
raise TypeError(f'page.manager should be {Blueprint!r}, got {type(page.manager)}. filepath: {path!s}')

api_version = getattr(page, 'api_version', API_VERSION)
page_name = getattr(page, 'page_name', page_name)
url_prefix = f'/{API_VERSION}/{page_name}'

app.register_blueprint(page.manager, url_prefix=f'/{api_version}/{page_name}')
if 'scheduling_apps' in path.parts:
scheduling_url_prefix.append(f'/{api_version}/{page_name}')
else:
client_url_prefix.append(f'/{api_version}/{page_name}')

app.register_blueprint(page.manager, url_prefix=url_prefix)
return url_prefix

stat_logger.info('imported pages: %s', ' '.join(str(path) for path in pages_path))


@app.before_request
def authentication_before_request():
if CLIENT_AUTHENTICATION:
_result = client_authentication_before_request()
if _result:
return _result
if SITE_AUTHENTICATION:
_result = site_authentication_before_request()
if _result:
return _result
client_urls_prefix = [
register_page(path)
for path in search_pages_path(Path(__file__).parent)
]
scheduling_urls_prefix = [
register_page(path)
for path in search_pages_path(Path(__file__).parent.parent / 'scheduling_apps')
]


def client_authentication_before_request():
for url_prefix in scheduling_url_prefix:
for url_prefix in scheduling_urls_prefix:
if request.path.startswith(url_prefix):
return
parm = ClientAuthenticationParameters(full_path=request.full_path, headers=request.headers, form=request.form,
data=request.data, json=request.json)
result = HookManager.client_authentication(parm)

result = HookManager.client_authentication(ClientAuthenticationParameters(
request.full_path, request.headers,
request.form, request.data, request.json,
))

if result.code != RetCode.SUCCESS:
return get_json_result(result.code, result.message)


def site_authentication_before_request():
from flask import request
for url_prefix in client_url_prefix:
for url_prefix in client_urls_prefix:
if request.path.startswith(url_prefix):
return
body = request.json
headers = request.headers
site_signature = headers.get("site_signature")
result = HookManager.site_authentication(
AuthenticationParameters(site_signature=site_signature, src_party_id=headers.get("src_party_id"), body=body))

result = HookManager.site_authentication(AuthenticationParameters(
request.headers.get('src_party_id'),
request.headers.get('site_signature'),
request.json,
))

if result.code != RetCode.SUCCESS:
return get_json_result(result.code, result.message)


@app.before_request
def authentication_before_request():
if CLIENT_AUTHENTICATION:
return client_authentication_before_request()

if SITE_AUTHENTICATION:
return site_authentication_before_request()
150 changes: 146 additions & 4 deletions python/fate_flow/apps/component_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

from fate_flow.component_env_utils.env_utils import get_class_object
from fate_flow.db.component_registry import ComponentRegistry
from fate_flow.db.db_models import PipelineComponentMeta
from fate_flow.model.sync_model import SyncComponent
from fate_flow.pipelined_model.pipelined_model import PipelinedModel
from fate_flow.settings import ENABLE_MODEL_STORE
from fate_flow.utils.api_utils import error_response, get_json_result, validate_request
from fate_flow.utils.detect_utils import check_config
from fate_flow.utils.job_utils import generate_job_id
from fate_flow.utils.model_utils import gen_party_model_id
from fate_flow.utils.schedule_utils import get_dsl_parser_by_version

Expand All @@ -41,7 +43,7 @@ def get_component(component_name):
@manager.route('/validate', methods=['POST'])
def validate_component_param():
if not request.json or not isinstance(request.json, dict):
return error_response(400, 'bad request')
return error_response(400)

required_keys = [
'component_name',
Expand Down Expand Up @@ -79,7 +81,7 @@ def validate_component_param():
'model_id', 'model_version', 'guest_party_id', 'host_party_ids',
'component_name', 'model_type', 'output_format',
)
def hetero_merge():
def hetero_model_merge():
request_data = request.json

if ENABLE_MODEL_STORE:
Expand All @@ -91,7 +93,7 @@ def hetero_merge():
component_name=request_data['component_name'],
)
if not sync_component.local_exists() and sync_component.remote_exists():
sync_component.download(True)
sync_component.download()

for party_id in request_data['host_party_ids']:
sync_component = SyncComponent(
Expand All @@ -102,7 +104,7 @@ def hetero_merge():
component_name=request_data['component_name'],
)
if not sync_component.local_exists() and sync_component.remote_exists():
sync_component.download(True)
sync_component.download()

model = PipelinedModel(
gen_party_model_id(
Expand Down Expand Up @@ -167,3 +169,143 @@ def hetero_merge():
request_data.get('include_guest_coef', False),
)
return get_json_result(data=data)


@manager.route('/woe_array/extract', methods=['POST'])
@validate_request(
'model_id', 'model_version', 'role', 'party_id', 'component_name',
)
def woe_array_extract():
if request.json['role'] != 'guest':
return error_response(400, 'Only support guest role.')

if ENABLE_MODEL_STORE:
sync_component = SyncComponent(
role=request.json['role'],
party_id=request.json['party_id'],
model_id=request.json['model_id'],
model_version=request.json['model_version'],
component_name=request.json['component_name'],
)
if not sync_component.local_exists() and sync_component.remote_exists():
sync_component.download()

model = PipelinedModel(
gen_party_model_id(
request.json['model_id'],
request.json['role'],
request.json['party_id'],
),
request.json['model_version'],
).read_component_model(
request.json['component_name'],
output_json=True,
)

param = None
meta = None

for k, v in model.items():
if k.endswith('Param'):
param = v
elif k.endswith('Meta'):
meta = v
else:
return error_response(400, f'Unknown model key: "{k}".')

if param is None or meta is None:
return error_response(400, 'Invalid model.')

data = get_class_object('extract_woe_array_dict')(param)
return get_json_result(data=data)


@manager.route('/woe_array/merge', methods=['POST'])
@validate_request(
'model_id', 'model_version', 'role', 'party_id', 'component_name', 'woe_array',
)
def woe_array_merge():
if request.json['role'] != 'host':
return error_response(400, 'Only support host role.')

pipelined_model = PipelinedModel(
gen_party_model_id(
request.json['model_id'],
request.json['role'],
request.json['party_id'],
),
request.json['model_version'],
)

query = pipelined_model.pipelined_component.get_define_meta_from_db(
PipelineComponentMeta.f_component_name == request.json['component_name'],
)
if not query:
return error_response(404, 'Component not found.')
query = query[0]

if ENABLE_MODEL_STORE:
sync_component = SyncComponent(
role=query.f_role,
party_id=query.f_party_id,
model_id=query.f_model_id,
model_version=query.f_model_version,
component_name=query.f_component_name,
)
if not sync_component.local_exists() and sync_component.remote_exists():
sync_component.download()

model = pipelined_model._read_component_model(
query.f_component_name,
query.f_model_alias,
)

for model_name, (
buffer_name,
buffer_string,
buffer_dict,
) in model.items():
if model_name.endswith('Param'):
string_merged, dict_merged = get_class_object('merge_woe_array_dict')(
buffer_name,
buffer_string,
buffer_dict,
request.json['woe_array'],
)
model[model_name] = (
buffer_name,
string_merged,
dict_merged,
)
break

pipelined_model = PipelinedModel(
pipelined_model.party_model_id,
generate_job_id()
)

pipelined_model.save_component_model(
query.f_component_name,
query.f_component_module_name,
query.f_model_alias,
model,
query.f_run_parameters,
)

if ENABLE_MODEL_STORE:
sync_component = SyncComponent(
role=query.f_role,
party_id=query.f_party_id,
model_id=query.f_model_id,
model_version=pipelined_model.model_version,
component_name=query.f_component_name,
)
sync_component.upload()

return get_json_result(data={
'role': query.f_role,
'party_id': query.f_party_id,
'model_id': query.f_model_id,
'model_version': pipelined_model.model_version,
'component_name': query.f_component_name,
})
Loading

0 comments on commit cee0306

Please sign in to comment.