From 5caff7cb0b368f892373d4c6d63112339b772f4a Mon Sep 17 00:00:00 2001 From: Fallen_Breath Date: Mon, 11 Dec 2023 22:01:18 +0800 Subject: [PATCH] Add `!!pb database convert_hash_method` command to convert database no need to pin on a hash method forever --- lang/en_us.yml | 13 +++ lang/zh_cn.yml | 13 +++ .../action/convert_hash_method_action.py | 98 +++++++++++++++++++ prime_backup/action/export_backup_action.py | 6 +- prime_backup/config/config.py | 4 +- prime_backup/db/access.py | 16 +-- prime_backup/db/session.py | 41 +++++++- prime_backup/exceptions.py | 4 + prime_backup/mcdr/command/commands.py | 29 ++++-- prime_backup/mcdr/task/__init__.py | 3 +- .../mcdr/task/backup/restore_backup_task.py | 6 +- .../mcdr/task/db/convert_hash_method_task.py | 50 ++++++++++ .../mcdr/task/general/show_help_task.py | 2 + prime_backup/mcdr/text_components.py | 7 ++ prime_backup/utils/blob_utils.py | 3 +- prime_backup/utils/bypass_io.py | 14 ++- prime_backup/utils/hash_utils.py | 32 +++--- 17 files changed, 296 insertions(+), 45 deletions(-) create mode 100644 prime_backup/action/convert_hash_method_action.py create mode 100644 prime_backup/mcdr/task/db/convert_hash_method_task.py diff --git a/lang/en_us.yml b/lang/en_us.yml index 0bea0df..5437318 100644 --- a/lang/en_us.yml +++ b/lang/en_us.yml @@ -141,6 +141,17 @@ prime_backup: pause.hover: Click to pause job {} resume: resume resume.hover: Click to resume job {} + db_convert_hash_method: + name: convert hash method + hash_method_unchanged: The hash method is already {} + missing_library: + Failed to import the target hasher, please make sure you have installed the required python library for {}. + Search hash_method in the document {} for more help + show_conversion: Prepare for hash method conversion, from {} to {} + confirm_target: convert + no_confirm: No choice has been made, convert hash method task aborted + aborted: Convert hash method task aborted + done: Converted the hash method from {} to {} db_overview: name: overview database title: Database overview @@ -238,8 +249,10 @@ prime_backup: §7{prefix} database overview§r: Report an overview of the database §7{prefix} database validate §a§r: Validate the correctness of contents in the database. Might take a long time §7{prefix} database vacuum§r: Compact the SQLite database manually, to reduce the size of the database file + §7{prefix} database convert_hash_method §r: Convert the currently used hash method to another. This will affect all data {scheduled_compact_notes} §d[Arguments]§r + §d§r: Available options: {hash_methods} §a§r: - §ablobs§r: Validate the correctness of blobs, e.g. data size, hash value - §afiles§r: Validate the correctness of file objects, e.g. the association between files and blobs diff --git a/lang/zh_cn.yml b/lang/zh_cn.yml index 2491e9f..9426f43 100644 --- a/lang/zh_cn.yml +++ b/lang/zh_cn.yml @@ -141,6 +141,17 @@ prime_backup: pause.hover: 点击以暂停运行作业{} resume: 继续 resume.hover: 点击以继续运行作业{} + db_convert_hash_method: + name: 哈希算法转换 + hash_method_unchanged: 哈希算法已经是{}了 + missing_library: + 无法导入目标哈希算法, 请确保你已经安装了算法{}所需的Python依赖库。 + 在文档{}中搜索hash_method以获得更多帮助 + show_conversion: 准备把哈希算法从{}转换为{} + confirm_target: 转换 + no_confirm: 未做出选择, 哈希算法转换任务中止 + aborted: 哈希算法转换任务中止 + done: 已将哈希算法从{}转换为{} db_overview: name: 概览数据库 title: 数据库概览 @@ -238,8 +249,10 @@ prime_backup: §7{prefix} database overview§r: 查看数据库信息概览 §7{prefix} database validate §a<组件>§r: 验证数据库内容的正确性。耗时可能较长 §7{prefix} database vacuum§r: 手动执行SQLite数据库的精简操作,减少数据库文件的体积 + §7{prefix} database convert_hash_method <哈希算法>§r: 转换当前使用的哈希算法为另一种算法。这将影响所有的数据 {scheduled_compact_notes} §d【参数帮助】§r + §d<哈希算法>§r: 可以选项: {hash_methods} §a<组件>§r: - §ablobs§r: 验证数据对象的正确性,如数据大小、哈希值 - §afiles§r: 验证文件对象的正确性,如文件与数据的关联 diff --git a/prime_backup/action/convert_hash_method_action.py b/prime_backup/action/convert_hash_method_action.py new file mode 100644 index 0000000..83585a4 --- /dev/null +++ b/prime_backup/action/convert_hash_method_action.py @@ -0,0 +1,98 @@ +import shutil +import time +from pathlib import Path +from typing import List, Dict, Set + +from prime_backup.action import Action +from prime_backup.compressors import Compressor +from prime_backup.db.access import DbAccess +from prime_backup.db.session import DbSession +from prime_backup.exceptions import PrimeBackupError +from prime_backup.types.hash_method import HashMethod +from prime_backup.utils import blob_utils, hash_utils, collection_utils + + +class HashCollisionError(PrimeBackupError): + """ + Same hash value, between 2 hash methods + """ + pass + + +class ConvertHashMethodAction(Action[None]): + def __init__(self, new_hash_method: HashMethod): + super().__init__() + self.new_hash_method = new_hash_method + + def __convert_blobs(self, session: DbSession, blob_hashes: List[str], old_hashes: Set[str], processed_hash_mapping: Dict[str, str]): + hash_mapping: Dict[str, str] = {} + blobs = list(session.get_blobs(blob_hashes).values()) + + # calc blob hashes + for blob in blobs: + blob_path = blob_utils.get_blob_path(blob.hash) + with Compressor.create(blob.compress).open_decompressed(blob_path) as f: + sah = hash_utils.calc_reader_size_and_hash(f, hash_method=self.new_hash_method) + hash_mapping[blob.hash] = sah.hash + if sah.hash in old_hashes: + raise HashCollisionError(sah.hash) + + # update the objects + for blob in blobs: + old_hash, new_hash = blob.hash, hash_mapping[blob.hash] + old_path = blob_utils.get_blob_path(old_hash) + new_path = blob_utils.get_blob_path(new_hash) + old_path.rename(new_path) + + processed_hash_mapping[old_hash] = new_hash + blob.hash = new_hash + + for file in session.get_file_by_blob_hashes(list(hash_mapping.keys())): + file.blob_hash = hash_mapping[file.blob_hash] + + def __replace_blob_store(self, old_store: Path, new_store: Path): + trash_bin = self.config.storage_path / 'temp' / 'old_blobs' + trash_bin.parent.mkdir(parents=True, exist_ok=True) + + old_store.rename(trash_bin) + new_store.rename(old_store) + shutil.rmtree(trash_bin) + + def run(self): + processed_hash_mapping: Dict[str, str] = {} # old -> new + try: + t = time.time() + with DbAccess.open_session() as session: + meta = session.get_db_meta() + if meta.hash_method == self.new_hash_method.name: + self.logger.info('Hash method of the database is already {}, no need to convert'.format(self.new_hash_method.name)) + return + + total_blob_count = session.get_blob_count() + all_hashes = session.get_all_blob_hashes() + all_hash_set = set(all_hashes) + cnt = 0 + for blob_hashes in collection_utils.slicing_iterate(all_hashes, 1000): + blob_hashes: List[str] = list(blob_hashes) + cnt += len(blob_hashes) + self.logger.info('Converting blobs {} / {}'.format(cnt, total_blob_count)) + + self.__convert_blobs(session, blob_hashes, all_hash_set, processed_hash_mapping) + session.flush_and_expunge_all() + + meta = session.get_db_meta() # get the meta again, cuz expunge_all() was called + meta.hash_method = self.new_hash_method.name + + self.logger.info('Syncing config and variables') + DbAccess.sync_hash_method() + self.config.backup.hash_method = self.new_hash_method.name + + self.logger.info('Conversion done, cost {}s'.format(round(time.time() - t, 2))) + + except Exception: + self.logger.info('Error occurs during convertion, applying rollback') + for old_hash, new_hash in processed_hash_mapping.items(): + old_path = blob_utils.get_blob_path(old_hash) + new_path = blob_utils.get_blob_path(new_hash) + new_path.rename(old_path) + raise diff --git a/prime_backup/action/export_backup_action.py b/prime_backup/action/export_backup_action.py index cab5576..6f593ac 100644 --- a/prime_backup/action/export_backup_action.py +++ b/prime_backup/action/export_backup_action.py @@ -17,7 +17,7 @@ from prime_backup.db import schema from prime_backup.db.access import DbAccess from prime_backup.db.session import DbSession -from prime_backup.exceptions import PrimeBackupError +from prime_backup.exceptions import PrimeBackupError, VerificationError from prime_backup.types.backup_meta import BackupMeta from prime_backup.types.export_failure import ExportFailures from prime_backup.types.tar_format import TarFormat @@ -25,10 +25,6 @@ from prime_backup.utils.bypass_io import BypassReader -class VerificationError(PrimeBackupError): - pass - - class _ExportInterrupted(PrimeBackupError): pass diff --git a/prime_backup/config/config.py b/prime_backup/config/config.py index 0a31dce..ab5bbf8 100644 --- a/prime_backup/config/config.py +++ b/prime_backup/config/config.py @@ -35,11 +35,11 @@ def get(cls) -> 'Config': return cls.__get_default() return _config - @functools.cached_property + @property def storage_path(self) -> Path: return Path(self.storage_root) - @functools.cached_property + @property def source_path(self) -> Path: return Path(self.backup.source_root) diff --git a/prime_backup/db/access.py b/prime_backup/db/access.py index fb77302..304ee0c 100644 --- a/prime_backup/db/access.py +++ b/prime_backup/db/access.py @@ -39,12 +39,7 @@ def init(cls, auto_migrate: bool = True): else: migration.ensure_version() - with cls.open_session() as session: - hash_method_str = str(session.get_db_meta().hash_method) - try: - cls.__hash_method = HashMethod[hash_method_str] - except KeyError: - raise ValueError('invalid hash method {!r} in db meta'.format(hash_method_str)) from None + cls.sync_hash_method() @classmethod def shutdown(cls): @@ -52,6 +47,15 @@ def shutdown(cls): for hdr in list(logger.handlers): logger.removeHandler(hdr) + @classmethod + def sync_hash_method(cls): + with cls.open_session() as session: + hash_method_str = str(session.get_db_meta().hash_method) + try: + cls.__hash_method = HashMethod[hash_method_str] + except KeyError: + raise ValueError('invalid hash method {!r} in db meta'.format(hash_method_str)) from None + @classmethod def __ensure_not_none(cls, value): if value is None: diff --git a/prime_backup/db/session.py b/prime_backup/db/session.py index f1cf806..f25606d 100644 --- a/prime_backup/db/session.py +++ b/prime_backup/db/session.py @@ -1,6 +1,6 @@ import contextlib import time -from typing import Optional, Sequence, Dict, ContextManager +from typing import Optional, Sequence, Dict, ContextManager, Iterator from typing import TypeVar, List from sqlalchemy import select, delete, desc, func, Select, JSON, text @@ -40,9 +40,22 @@ def __init__(self, session: Session): def add(self, obj: schema.Base): self.session.add(obj) + def expunge(self, obj: schema.Base): + self.session.expunge(obj) + + def expunge_all(self): + self.session.expunge_all() + def flush(self): self.session.flush() + def flush_and_expunge_all(self): + self.flush() + self.expunge_all() + + def commit(self): + self.session.commit() + @contextlib.contextmanager def no_auto_flush(self) -> ContextManager[None]: with self.session.no_autoflush: @@ -94,6 +107,14 @@ def list_blobs(self, limit: Optional[int] = None, offset: Optional[int] = None) s = s.offset(offset) return _list_it(self.session.execute(s).scalars().all()) + def iterate_blob_batch(self, *, batch_size: int = 3000) -> Iterator[List[schema.Blob]]: + limit, offset = batch_size, 0 + while True: + blobs = self.list_blobs(limit=limit, offset=offset) + if len(blobs) == 0: + break + yield blobs + def get_all_blob_hashes(self) -> List[str]: return _list_it(self.session.execute(select(schema.Blob.hash)).scalars().all()) @@ -161,6 +182,16 @@ def get_file(self, backup_id: int, path: str) -> schema.File: def get_file_raw_size_sum(self) -> int: return _int_or_0(self.session.execute(func.sum(schema.File.blob_raw_size).select()).scalar_one()) + def get_file_by_blob_hashes(self, hashes: List[str]) -> List[schema.File]: + hashes = collection_utils.deduplicated_list(hashes) + result = [] + for view in collection_utils.slicing_iterate(hashes, self.__safe_var_limit): + result.extend(self.session.execute( + select(schema.File). + where(schema.File.blob_hash.in_(view)) + ).scalars().all()) + return result + def get_file_count_by_blob_hashes(self, hashes: List[str]) -> int: cnt = 0 for view in collection_utils.slicing_iterate(hashes, self.__safe_var_limit): @@ -179,6 +210,14 @@ def list_files(self, limit: Optional[int] = None, offset: Optional[int] = None) s = s.offset(offset) return _list_it(self.session.execute(s).scalars().all()) + def iterate_file_batch(self, *, batch_size: int = 3000) -> Iterator[List[schema.File]]: + limit, offset = batch_size, 0 + while True: + files = self.list_files(limit=limit, offset=offset) + if len(files) == 0: + break + yield files + def delete_file(self, file: schema.File): self.session.delete(file) diff --git a/prime_backup/exceptions.py b/prime_backup/exceptions.py index 3b5a29f..9cb1962 100644 --- a/prime_backup/exceptions.py +++ b/prime_backup/exceptions.py @@ -18,3 +18,7 @@ def __init__(self, backup_id: int, path: str): class UnsupportedFileFormat(PrimeBackupError): def __init__(self, mode: int): self.mode = mode + + +class VerificationError(PrimeBackupError): + pass diff --git a/prime_backup/mcdr/command/commands.py b/prime_backup/mcdr/command/commands.py index 728298a..6e6b4e1 100644 --- a/prime_backup/mcdr/command/commands.py +++ b/prime_backup/mcdr/command/commands.py @@ -23,6 +23,7 @@ from prime_backup.mcdr.task.crontab.list_crontab_task import ListCrontabJobTask from prime_backup.mcdr.task.crontab.operate_crontab_task import OperateCrontabJobTask from prime_backup.mcdr.task.crontab.show_crontab_task import ShowCrontabJobTask +from prime_backup.mcdr.task.db.convert_hash_method_task import ConvertHashMethodTask from prime_backup.mcdr.task.db.show_db_overview_task import ShowDbOverviewTask from prime_backup.mcdr.task.db.vacuum_sqlite_task import VacuumSqliteTask from prime_backup.mcdr.task.db.validate_db_task import ValidateDbTask, ValidateParts @@ -31,6 +32,7 @@ from prime_backup.mcdr.task_manager import TaskManager from prime_backup.types.backup_filter import BackupFilter from prime_backup.types.backup_tags import BackupTagName +from prime_backup.types.hash_method import HashMethod from prime_backup.types.operator import Operator from prime_backup.types.standalone_backup_format import StandaloneBackupFormat from prime_backup.utils import misc_utils @@ -72,6 +74,10 @@ def cmd_db_validate(self, source: CommandSource, _: CommandContext, parts: Valid def cmd_db_vacuum(self, source: CommandSource, _: CommandContext): self.task_manager.add_task(VacuumSqliteTask(source)) + def cmd_db_convert_hash_method(self, source: CommandSource, context: CommandContext): + new_hash_method = context['hash_method'] + self.task_manager.add_task(ConvertHashMethodTask(source, new_hash_method)) + def cmd_make(self, source: CommandSource, context: CommandContext): def callback(_, err): if err is None: @@ -217,9 +223,13 @@ def create_backup_id(arg: str = 'backup_id', clazz: Type[Integer] = Integer) -> builder = SimpleCommandBuilder() + # help + builder.command('help', self.cmd_help) builder.command('help ', self.cmd_help) + builder.arg('what', Text).suggests(lambda: ShowHelpTask.COMMANDS_WITH_DETAILED_HELP) + # backup builder.command('make', self.cmd_make) builder.command('make ', self.cmd_make) @@ -228,32 +238,33 @@ def create_backup_id(arg: str = 'backup_id', clazz: Type[Integer] = Integer) -> builder.command('delete_range ', self.cmd_delete_range) builder.command('prune', self.cmd_prune) + builder.arg('backup_id', create_backup_id) + builder.arg('backup_id_range', IdRangeNode) + builder.arg('comment', GreedyText) + # crontab builder.command('crontab', self.cmd_crontab_show) builder.command('crontab ', self.cmd_crontab_show) builder.command('crontab pause', self.cmd_crontab_pause) builder.command('crontab resume', self.cmd_crontab_resume) + builder.arg('job_id', lambda n: Enumeration(n, CrontabJobId)) + # db builder.command('database overview', self.cmd_db_overview) builder.command('database validate all', functools.partial(self.cmd_db_validate, parts=ValidateParts.all())) builder.command('database validate blobs', functools.partial(self.cmd_db_validate, parts=ValidateParts.blobs)) builder.command('database validate files', functools.partial(self.cmd_db_validate, parts=ValidateParts.files)) builder.command('database vacuum', self.cmd_db_vacuum) + builder.command('database convert_hash_method ', self.cmd_db_convert_hash_method) + + builder.arg('hash_method', lambda n: Enumeration(n, HashMethod)) # operations builder.command('confirm', self.cmd_confirm) builder.command('abort', self.cmd_abort) - # node defs - builder.arg('backup_id', create_backup_id) - builder.arg('backup_id_range', IdRangeNode) - builder.arg('comment', GreedyText) - builder.arg('job_id', lambda n: Enumeration(n, CrontabJobId)) - builder.arg('page', lambda n: Integer(n).at_min(1)) - builder.arg('per_page', lambda n: Integer(n).at_min(1)) - builder.arg('what', Text).suggests(lambda: ShowHelpTask.COMMANDS_WITH_DETAILED_HELP) - + # subcommand permissions for name, level in permissions.items(): builder.literal(name).requires(get_permission_checker(name), get_permission_denied_text) diff --git a/prime_backup/mcdr/task/__init__.py b/prime_backup/mcdr/task/__init__.py index 73b518f..b2a27f0 100644 --- a/prime_backup/mcdr/task/__init__.py +++ b/prime_backup/mcdr/task/__init__.py @@ -19,8 +19,9 @@ class TaskEvent(enum.Enum): class Task(Generic[T], mcdr_utils.TranslationContext, ABC): def __init__(self, source: CommandSource): super().__init__(f'task.{self.id}') + from prime_backup.mcdr import mcdr_globals self.source = source - self.server = source.get_server() + self.server = mcdr_globals.server def get_name_text(self) -> RTextBase: return self.tr('name').set_color(RColor.aqua) diff --git a/prime_backup/mcdr/task/backup/restore_backup_task.py b/prime_backup/mcdr/task/backup/restore_backup_task.py index f4270b0..e8f05b4 100644 --- a/prime_backup/mcdr/task/backup/restore_backup_task.py +++ b/prime_backup/mcdr/task/backup/restore_backup_task.py @@ -59,11 +59,11 @@ def run(self): if self.needs_confirm: self.broadcast(self.tr('show_backup', TextComponents.backup_brief(backup))) - cr = self.wait_confirm(self.tr('confirm_target')) - if not cr.is_set(): + wr = self.wait_confirm(self.tr('confirm_target')) + if not wr.is_set(): self.broadcast(self.tr('no_confirm')) return - elif cr.get().is_cancelled(): + elif wr.get().is_cancelled(): self.broadcast(self.tr('aborted')) return diff --git a/prime_backup/mcdr/task/db/convert_hash_method_task.py b/prime_backup/mcdr/task/db/convert_hash_method_task.py new file mode 100644 index 0000000..a3660fc --- /dev/null +++ b/prime_backup/mcdr/task/db/convert_hash_method_task.py @@ -0,0 +1,50 @@ +from mcdreforged.api.all import * + +from prime_backup import constants +from prime_backup.action.convert_hash_method_action import ConvertHashMethodAction +from prime_backup.action.get_db_meta_action import GetDbMetaAction +from prime_backup.mcdr.task.basic_task import HeavyTask +from prime_backup.mcdr.text_components import TextComponents +from prime_backup.types.hash_method import HashMethod + + +class ConvertHashMethodTask(HeavyTask[None]): + def __init__(self, source: CommandSource, new_hash_method: HashMethod): + super().__init__(source) + self.new_hash_method = new_hash_method + + @property + def id(self) -> str: + return 'db_convert_hash_method' + + def run(self): + try: + self.new_hash_method.value.create_hasher() + except ImportError as e: + self.logger.warning('Failed to create hasher of {} due to ImportError: {}'.format(self.new_hash_method, e)) + self.reply(self.tr( + 'missing_library', + TextComponents.hash_method(self.new_hash_method), + TextComponents.url(constants.DOCUMENTATION_URL, click=True), + str(e) + )) + return + + db_meta = self.run_action(GetDbMetaAction()) + if db_meta.hash_method == self.new_hash_method.name: + self.reply(self.tr('hash_method_unchanged', TextComponents.hash_method(self.new_hash_method))) + return + + self.reply(self.tr('show_conversion', TextComponents.hash_method(db_meta.hash_method), TextComponents.hash_method(self.new_hash_method))) + wr = self.wait_confirm(self.tr('confirm_target')) + if not wr.is_set(): + self.reply(self.tr('no_confirm')) + return + elif wr.get().is_cancelled(): + self.reply(self.tr('aborted')) + return + + self.run_action(ConvertHashMethodAction(self.new_hash_method)) + self.server.save_config_simple(self.config) + + self.reply(self.tr('done', TextComponents.hash_method(db_meta.hash_method), TextComponents.hash_method(self.new_hash_method))) diff --git a/prime_backup/mcdr/task/general/show_help_task.py b/prime_backup/mcdr/task/general/show_help_task.py index f6fb70b..d81150e 100644 --- a/prime_backup/mcdr/task/general/show_help_task.py +++ b/prime_backup/mcdr/task/general/show_help_task.py @@ -8,6 +8,7 @@ from prime_backup.mcdr.task.basic_task import ImmediateTask from prime_backup.mcdr.task.general import help_message_utils from prime_backup.mcdr.text_components import TextComponents, TextColors +from prime_backup.types.hash_method import HashMethod from prime_backup.utils.mcdr_utils import mkcmd @@ -84,6 +85,7 @@ def run(self) -> None: elif self.what == 'database': name = mcdr_globals.metadata.name kwargs['name'] = name + kwargs['hash_methods'] = ', '.join([f'§d{hm.name}§r' for hm in HashMethod]) if self.config.database.compact.enabled: kwargs['scheduled_compact_notes'] = self.tr( f'node_help.{self.what}.scheduled_compact.on', diff --git a/prime_backup/mcdr/text_components.py b/prime_backup/mcdr/text_components.py index 92d3b13..5a7b396 100644 --- a/prime_backup/mcdr/text_components.py +++ b/prime_backup/mcdr/text_components.py @@ -8,6 +8,7 @@ from prime_backup.types.backup_info import BackupInfo from prime_backup.types.backup_tags import BackupTagName from prime_backup.types.blob_info import BlobListSummary +from prime_backup.types.hash_method import HashMethod from prime_backup.types.operator import Operator from prime_backup.types.units import ByteCount, Duration from prime_backup.utils import conversion_utils, misc_utils, backup_utils @@ -210,6 +211,12 @@ def file_size(cls, byte_cnt: Union[int, ByteCount], *, ndigits: int = 2, color: byte_cnt = ByteCount(byte_cnt) return RText(byte_cnt.auto_str(ndigits=ndigits), color=color) + @classmethod + def hash_method(cls, hash_method: Union[str, HashMethod]) -> RTextBase: + if isinstance(hash_method, HashMethod): + hash_method = hash_method.name + return RText(hash_method, RColor.light_purple) + @classmethod def number(cls, value: Any) -> RTextBase: return RText(value, TextColors.number) diff --git a/prime_backup/utils/blob_utils.py b/prime_backup/utils/blob_utils.py index bbe6b5f..bbae3b5 100644 --- a/prime_backup/utils/blob_utils.py +++ b/prime_backup/utils/blob_utils.py @@ -14,6 +14,7 @@ def get_blob_path(h: str) -> Path: def prepare_blob_directories(): + blob_store = get_blob_store() for i in range(0, 256): - p = get_blob_store() / hex(i)[2:].rjust(2, '0') + p = blob_store / hex(i)[2:].rjust(2, '0') p.mkdir(parents=True, exist_ok=True) diff --git a/prime_backup/utils/bypass_io.py b/prime_backup/utils/bypass_io.py index f39a80c..7a74b50 100644 --- a/prime_backup/utils/bypass_io.py +++ b/prime_backup/utils/bypass_io.py @@ -1,15 +1,21 @@ import io -from typing import Union +from typing import Union, TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from prime_backup.types.hash_method import HashMethod class BypassReader(io.BytesIO): - def __init__(self, file_obj, calc_hash: bool): + def __init__(self, file_obj, calc_hash: bool, *, hash_method: Optional['HashMethod'] = None): super().__init__() self.file_obj: io.BytesIO = file_obj self.read_len = 0 - from prime_backup.utils import hash_utils - self.hasher = hash_utils.create_hasher() if calc_hash else None + if calc_hash: + from prime_backup.utils import hash_utils + self.hasher = hash_utils.create_hasher(hash_method=hash_method) + else: + self.hasher = None def read(self, *args, **kwargs): data = self.file_obj.read(*args, **kwargs) diff --git a/prime_backup/utils/hash_utils.py b/prime_backup/utils/hash_utils.py index 05ac615..a4452f9 100644 --- a/prime_backup/utils/hash_utils.py +++ b/prime_backup/utils/hash_utils.py @@ -1,13 +1,15 @@ from pathlib import Path -from typing import NamedTuple, IO +from typing import NamedTuple, IO, Optional -from prime_backup.types.hash_method import Hasher +from prime_backup.types.hash_method import Hasher, HashMethod from prime_backup.utils.bypass_io import BypassReader -def create_hasher() -> 'Hasher': - from prime_backup.db.access import DbAccess - return DbAccess.get_hash_method().value.create_hasher() +def create_hasher(*, hash_method: Optional[HashMethod] = None) -> 'Hasher': + if hash_method is None: + from prime_backup.db.access import DbAccess + hash_method = DbAccess.get_hash_method() + return hash_method.value.create_hasher() _READ_BUF_SIZE = 128 * 1024 @@ -18,24 +20,28 @@ class SizeAndHash(NamedTuple): hash: str -def calc_reader_size_and_hash(file_obj: IO[bytes], *, buf_size: int = _READ_BUF_SIZE) -> SizeAndHash: - reader = BypassReader(file_obj, True) +def calc_reader_size_and_hash( + file_obj: IO[bytes], *, + buf_size: int = _READ_BUF_SIZE, + hash_method: Optional[HashMethod] = None, +) -> SizeAndHash: + reader = BypassReader(file_obj, True, hash_method=hash_method) while reader.read(buf_size): pass return SizeAndHash(reader.get_read_len(), reader.get_hash()) -def calc_file_size_and_hash(path: Path, *, buf_size: int = _READ_BUF_SIZE) -> SizeAndHash: +def calc_file_size_and_hash(path: Path, **kwargs) -> SizeAndHash: with open(path, 'rb') as f: - return calc_reader_size_and_hash(f, buf_size=buf_size) + return calc_reader_size_and_hash(f, **kwargs) -def calc_reader_hash(file_obj: IO[bytes], *, buf_size: int = _READ_BUF_SIZE) -> str: - return calc_reader_size_and_hash(file_obj, buf_size=buf_size).hash +def calc_reader_hash(file_obj: IO[bytes], **kwargs) -> str: + return calc_reader_size_and_hash(file_obj, **kwargs).hash -def calc_file_hash(path: Path, *, buf_size: int = _READ_BUF_SIZE) -> str: - return calc_file_size_and_hash(path, buf_size=buf_size).hash +def calc_file_hash(path: Path, **kwargs) -> str: + return calc_file_size_and_hash(path, **kwargs).hash def calc_bytes_hash(buf: bytes) -> str: