diff --git a/flask_whooshee.py b/flask_whooshee.py index 00d250d..b799508 100644 --- a/flask_whooshee.py +++ b/flask_whooshee.py @@ -15,7 +15,10 @@ from whoosh.filedb.filestore import RamStorage from flask import current_app -from flask_sqlalchemy import BaseQuery +try: + from flask_sqlalchemy.query import Query +except ImportError: + from flask_sqlalchemy import BaseQuery as Query from sqlalchemy import text, event from sqlalchemy.inspection import inspect from sqlalchemy.orm.mapper import Mapper @@ -48,7 +51,7 @@ def _assure_dirs_exists(path): if err.errno != errno.EEXIST: raise -class WhoosheeQuery(BaseQuery): +class WhoosheeQuery(Query): """An override for SQLAlchemy query used to do fulltext search.""" def whooshee_search(self, search_string, group=whoosh.qparser.OrGroup, whoosheer=None, @@ -293,7 +296,7 @@ def register_whoosheer(self, wh): pass # ensure there can be a stable MRO - elif query_class not in (BaseQuery, SQLAQuery, WhoosheeQuery): + elif query_class not in (Query, SQLAQuery, WhoosheeQuery): query_class_name = query_class.__name__ model.query_class = type( "Whooshee{}".format(query_class_name), (query_class, self.query), {} diff --git a/test.py b/test.py index eb49012..500d769 100644 --- a/test.py +++ b/test.py @@ -9,7 +9,11 @@ import whoosh from whoosh.filedb.filestore import RamStorage from flask import Flask -from flask_sqlalchemy import SQLAlchemy, BaseQuery +from flask_sqlalchemy import SQLAlchemy +try: + from flask_sqlalchemy.query import Query +except ImportError: + from flask_sqlalchemy import BaseQuery as Query from sqlalchemy.orm import Query as SQLAQuery from flask_whooshee import AbstractWhoosheer, Whooshee, WhoosheeQuery @@ -712,14 +716,14 @@ def setUp(self): self.wh = Whooshee(self.app) def test_mixes_with_model_query(self): - class CustomQueryClass(BaseQuery): + class CustomQueryClass(Query): pass self._make_model_and_whoosheer(CustomQueryClass) self.assertEqual('WhoosheeCustomQueryClass', self.user_model.query_class.__name__) def test_doesnt_mix_with_default_query_class(self): - self._make_model_and_whoosheer(BaseQuery) + self._make_model_and_whoosheer(Query) self.assertIs(self.wh.query, self.user_model.query_class) def test_doesnt_mix_with_explicit_whooshee_query_class(self):