-
Notifications
You must be signed in to change notification settings - Fork 0
/
splinter_model.py
326 lines (264 loc) · 9.62 KB
/
splinter_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
# coding: utf-8
__all__ = ['BaseFetcherModel', 'CSSField', 'XPathField', 'RedisCache']
import json
import logging
import requests
from collections import Sequence
from redis import Redis
from redis.exceptions import ConnectionError
from scrapy.selector import Selector
logger = logging.getLogger(__name__)
class NoCache(object):
    """Do-nothing cache backend used when caching is disabled.

    Presents the same ``get``/``set`` interface as ``RedisCache`` but
    never stores anything: every lookup is a miss.
    """

    def __init__(self, *args, **kwargs):
        # Swallow any constructor arguments so it is drop-in compatible
        # with real cache backends.
        pass

    def get(self, key):
        """Always a cache miss."""
        return None

    def set(self, key, value, expire=None):
        """Silently discard the value."""
        pass
class RedisCache(object):
    """Thin wrapper around ``redis.Redis`` that degrades gracefully.

    A failed connection is logged and treated as a cache miss instead of
    propagating the exception to the caller.
    """

    def __init__(self, *args, **kwargs):
        # All arguments are forwarded untouched to the Redis client.
        self.cache = Redis(*args, **kwargs)

    def get(self, key):
        """Return the cached value for *key*, or None on miss or failure."""
        try:
            return self.cache.get(key)
        except ConnectionError as exc:
            logger.error("Cant connect to Redis server %s", exc)
        return None

    def set(self, key, value, expire=None):
        """Store *key* -> *value*; *expire* is passed through to Redis."""
        try:
            self.cache.set(key, value, expire)
        except ConnectionError as exc:
            logger.error("Cant connect to Redis server %s", exc)
class Storage(dict):
    """A dict whose items are also reachable as attributes.

    Missing attributes resolve to None (via ``dict.get``) instead of
    raising AttributeError.

    >>> obj = Storage()
    >>> obj["name"] = "Bruno"
    >>> obj.company = "ACME"
    >>> obj.name == obj["name"]
    True
    >>> obj["company"] == obj.company
    True
    """

    def __getattr__(self, attr):
        # Plain dict lookup with a None default.
        return self.get(attr)

    def __setattr__(self, attr, value):
        # Attribute assignment writes straight into the mapping.
        self[attr] = value
class BaseField(object):
    """
    Base for other selector fields
    """

    def __init__(self,
                 query,
                 auto_extract=False,
                 takes_first=False,
                 processor=None,
                 query_validator=None,
                 default=None):
        # Accept a single query string or a list of fallback queries;
        # always store a list.  (``basestring`` => this module targets
        # Python 2.)
        self.query = [query] if isinstance(query, basestring) else query
        # Called on a non-empty selector result; must return truthy to
        # accept it.  Defaults to "accept anything".
        self.query_validator = query_validator or (lambda data: True)
        self.default = default
        self.auto_extract = auto_extract
        self.takes_first = takes_first
        # Post-processing callable, or a sequence of callables to chain;
        # identity function by default.
        self.processor = processor or (lambda untouched_data: untouched_data)
        self._data = self.selector = self._raw_data = None

    @property
    def value(self):
        # The parsed (and processed) data for this field.
        return self._data

    def _parse(self, selector):
        # Run the child's parse() and apply extraction/processing rules.
        parsed = self.parse(selector)
        if hasattr(parsed, 'extract'):
            extracted = parsed.extract()
            if self.takes_first and len(extracted) > 0:
                # Return the first non-empty extracted value, processed.
                for value in extracted:
                    if value is not None and value != '':
                        return self._processor(value)
            elif self.auto_extract:
                return self._processor(extracted)
        # Fallthrough: hand back the raw parse() result (also reached
        # when takes_first found only empty values).
        return self._processor(parsed)

    def _processor(self, data):
        """
        runs the processor if defined
        processor can be a list of functions to be chained
        or a single function
        """
        # NOTE(review): ``from collections import Sequence`` is deprecated
        # and removed in Python 3.10 (use collections.abc).  A single
        # callable is assumed not to be a Sequence here.
        if isinstance(self.processor, Sequence):
            for function in self.processor:
                data = function(data)
        else:
            data = self.processor(data)
        return data

    def parse(self, selector):
        # Child classes implement the actual selector query here.
        raise NotImplementedError("Must be implemented in child class")

    def get_identifier(self):
        # ``identifier`` is assigned externally (BaseFetcherModel.load_fields).
        return getattr(self, 'identifier', "")

    def __repr__(self):
        return u"<{} - {} - {}>".format(
            self.__class__.__name__, self.get_identifier(), self._data
        )

    def __str__(self):
        # ``unicode`` builtin: __str__/__unicode__ are Python 2 only.
        return unicode(self._data)

    def __unicode__(self):
        return unicode(self._data)
class GenericField(BaseField):
    """A field created at runtime to hold an already-parsed value.

    Used for data that came from mappings rather than from a declared
    CSSField/XPathField; there is no query to execute.
    """

    def __init__(self, identifier=None, value=None):
        # Initialise the base with an empty query and store the value
        # directly instead of parsing it from a selector.
        super(GenericField, self).__init__("")
        self._data = value
        self.identifier = identifier

    def parse(self, selector):
        # Nothing to select; the value was supplied up front.
        return None
class CSSField(BaseField):
    """Field selected with one or more CSS queries (first hit wins)."""

    def parse(self, selector):
        """Return the first non-empty, validated CSS match; otherwise the
        configured default or an intentionally-empty selection."""
        for css_query in self.query:
            matched = selector.css(css_query)
            if len(matched) and self.query_validator(matched):
                return matched
        return self.default or selector.css("__empty_selector__")
class XPathField(BaseField):
    """Field selected with one or more XPath queries (first hit wins)."""

    def parse(self, selector):
        """Return the first non-empty, validated XPath match; otherwise the
        configured default or an intentionally-empty selection."""
        for xpath_query in self.query:
            matched = selector.xpath(xpath_query)
            if len(matched) and self.query_validator(matched):
                return matched
        return self.default or selector.css("__empty_selector__")
class BaseFetcherModel(object):
    """
    Declarative web-scraping model.

    fields example:
        name = CSSField("div.perfil > div > div.perf.col-md-12 >"
                        " div.col-md-10.desc > h1::text")
    mappings example:
        mappings = {
            'name': {'css': 'div#test'},
            'phone': {'xpath': '//phone'},
            'location': '.location'  # assumes css
        }
    Any method named parse_<field_name> will run after the data is collected
    """

    mappings = {}

    def __init__(self, url=None, mappings=None,
                 cache_fetch=False,
                 cache=NoCache,
                 cache_args=None,
                 cache_expire=None):
        """
        :param url: default page to fetch (fetch() may override it)
        :param mappings: optional dict overriding class-level ``mappings``
        :param cache_fetch: when True, raw page bodies go through ``cache``
        :param cache: cache class (instantiated with ``cache_args``) or an
            already-built cache instance
        :param cache_args: kwargs for the cache constructor
        :param cache_expire: expiry forwarded to ``cache.set``
        """
        self.load_fields()
        self.url = url
        self.refresh = False
        self._data = Storage()
        self._selector = None
        # copy() so instance-level mappings never mutate the class attribute
        self.mappings = mappings or self.mappings.copy()
        self.cache_fetch = cache_fetch
        self.cache_expire = cache_expire
        if isinstance(cache, type):
            self.cache = cache(**(cache_args or {}))
        else:
            self.cache = cache

    def load_fields(self):
        """Collect BaseField class attributes, tagging each with its name."""
        self._fields = []
        for name, field in self.__class__.__dict__.items():
            if isinstance(field, BaseField):
                field.identifier = name
                self._fields.append(field)

    def fetch(self, url=None):
        """Return the raw body of ``url`` (default: ``self.url``).

        BUGFIX: an explicitly passed ``url`` now takes precedence; the
        original ``self.url or url`` silently ignored the argument
        whenever the instance already had a url.
        """
        url = url or self.url
        if self.cache_fetch:
            # only touch the cache backend when caching is enabled
            cached = self.cache.get(url)
            if cached:
                return cached
        response = requests.get(url)
        if self.cache_fetch:
            self.cache.set(url, response.content, expire=self.cache_expire)
        return response.content

    @property
    def selector(self):
        # Lazily-built Selector over the fetched page; set ``self.refresh``
        # to force a re-fetch on the next access.
        if not self._selector or self.refresh:
            self._selector = Selector(text=self.fetch())
            self.refresh = False
        return self._selector

    def pre_parse(self, selector=None):
        """
        To be implemented optionally in child classes.
        Example: in this method it is possible to validate
        that a parse_ method is written for each field in a model

            class MyFetcherModel(BaseFetcherModel):
                model_class = AModelFromAnyORM

                def pre_parse(self, selector=None):
                    # considering model_class as Django or MongoEngine model
                    model_fields = self.model_class._meta.field_names
                    parse_methods = [
                        k for k, v in self.__dict__.items()
                        if k.startswith('parse_') and callable(v)
                    ]
                    for field_name in model_fields:
                        if not field_name in parse_methods:
                            raise Exception(
                                "parse method for %s is mandatory!" % field_name
                            )
        """

    def parse(self, selector=None):
        """
        The entry point:
            fetcher = Fetcher(url="http://...")
            fetcher.parse()
        """
        # NOTE: pre_parse receives the *argument* selector, which may be
        # None even when self.selector is used below -- kept as-is.
        self.pre_parse(selector)
        selector = selector or self.selector
        for field in self._fields:
            data = field._parse(selector)
            self._data[field.identifier] = field._raw_data = data
        # mappings has always the priority
        for field_name, query in self.mappings.items():
            if isinstance(query, dict):
                # {'phone': {'xpath': '//phone'}} -> ('xpath', '//phone')
                # (works on py2 and py3, unlike query.keys()[0])
                method, path = next(iter(query.items()))
            else:
                method = 'css'
                path = query
            self._data[field_name] = getattr(selector, method)(path)
        self.run_field_parsers()
        for field in self._fields:
            field._data = field.selector = self._data.get(field.identifier)
        self.post_parse()
        self.load_generic_fields()

    def load_generic_fields(self):
        """Expose mapping-only results as GenericField attributes.

        BUGFIX: membership used to be tested against the field *objects*
        (``k not in self._fields``), which is always true for a string
        key, so a duplicate GenericField was appended for every key on
        every parse().  Compare against the known identifiers instead.
        """
        known_identifiers = set(
            field.identifier for field in self._fields
        )
        for k, v in self._data.items():
            if k not in known_identifiers:
                field = GenericField(k, v)
                self._fields.append(field)
                setattr(self, k, field)

    def post_parse(self):
        """
        To be implemented optionally in child classes
        """

    def run_field_parsers(self):
        """Run every ``parse_<field_name>`` hook over the collected data."""
        self._raw_data = self._data.copy()
        for field_name, raw_selector in self._data.items():
            field_parser = getattr(self, 'parse_%s' % field_name, None)
            if field_parser:
                try:
                    parsed_data = field_parser(raw_selector)
                except Exception as e:
                    logger.error(
                        "Exception ocurred in parse_%s: %s", field_name, e
                    )
                    # keep the unparsed selector when the hook fails
                    self._data[field_name] = raw_selector
                else:
                    self._data[field_name] = parsed_data

    def populate(self, obj, fields=None):
        """Copy collected values onto ``obj`` (all collected keys by default)."""
        fields = fields or self._data.keys()
        for field in fields:
            setattr(obj, field, self._data.get(field))

    def load_mappings_from_file(self, path_or_file):
        """
        Will take a JSON file object, string or path
        and loads on to self.mappings
            {
                "name": {"css": "div#test"},
                "phone": {"xpath": "//phone"},
                "location": ".location"
            }
        """
        if isinstance(path_or_file, basestring):
            try:
                # ``with`` closes the handle (the original leaked it)
                with open(path_or_file) as fobj:
                    data = fobj.read()
            except IOError:
                # not a readable path: treat the string as raw JSON
                data = path_or_file
        elif hasattr(path_or_file, 'read'):
            # any file-like object; the py2-only ``isinstance(x, file)``
            # check was redundant with this hasattr test
            data = path_or_file.read()
        else:
            # previously this fell through to a confusing NameError
            raise TypeError("path_or_file must be a string or file object")
        self.mappings.update(json.loads(data))