Update for v0.2.1
* Add AutoInt & InteractingLayer
Weichen Shen authored Dec 27, 2018
1 parent 1107a82 commit cc844f3
Showing 21 changed files with 346 additions and 17 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -33,3 +33,4 @@ Through `pip install deepctr` get the package and [**Get Started!**](https://d
|Neural Factorization Machine|[SIGIR 2017][Neural Factorization Machines for Sparse Predictive Analytics](https://arxiv.org/pdf/1708.05027.pdf)|
|Deep Interest Network|[KDD 2018][Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1706.06978.pdf)|
|xDeepFM|[KDD 2018][xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems](https://arxiv.org/pdf/1803.05170.pdf)|
|AutoInt|[arXiv 2018][AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/abs/1810.11921)|
2 changes: 1 addition & 1 deletion deepctr/__init__.py
@@ -3,5 +3,5 @@
from . import sequence
from . import models
from .utils import check_version
- __version__ = '0.2.0post1'
+ __version__ = '0.2.1'
check_version(__version__)
85 changes: 83 additions & 2 deletions deepctr/layers.py
@@ -286,7 +286,8 @@ def call(self, inputs, **kwargs):

def compute_output_shape(self, input_shape):
if self.split_half:
- featuremap_num = sum(self.layer_size[:-1]) // 2 + self.layer_size[-1]
+ featuremap_num = sum(
+     self.layer_size[:-1]) // 2 + self.layer_size[-1]
else:
featuremap_num = sum(self.layer_size)
return (None, featuremap_num)
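For instance (illustrative numbers, not from this diff): with ``layer_size=(128, 128)`` and ``split_half=True``, ``featuremap_num = 128 // 2 + 128 = 192``, so the CIN output shape is ``(None, 192)``.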
@@ -480,7 +481,6 @@ def call(self, inputs, **kwargs):
col.append(j)
p = tf.concat([embed_list[idx]
for idx in row], axis=1) # batch num_pairs k
- # Reshape([num_pairs, self.embedding_size])
q = tf.concat([embed_list[idx] for idx in col], axis=1)
inner_product = p * q
if self.reduce_sum:
@@ -504,6 +504,87 @@ def get_config(self,):
return dict(list(base_config.items()) + list(config.items()))



class InteractingLayer(Layer):
"""A Layer used in AutoInt that model the correlations between different feature fields by multi-head self-attention mechanism.
Input shape
- A 3D tensor with shape: ``(batch_size,field_size,embedding_size)``.
Output shape
- 3D tensor with shape:``(batch_size,field_size,att_embedding_size * head_num)``.
Arguments
- **att_embedding_size**: int. The embedding size in the multi-head self-attention network.
- **head_num**: int. The number of heads in the multi-head self-attention network.
- **use_res**: bool. Whether or not to use standard residual connections before output.
- **seed**: A Python integer to use as random seed.
References
- [Song W, Shi C, Xiao Z, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks[J]. arXiv preprint arXiv:1810.11921, 2018.](https://arxiv.org/abs/1810.11921)
"""
def __init__(self, att_embedding_size=8, head_num=2, use_res=True, seed=1024, **kwargs):
if head_num <= 0:
raise ValueError('head_num must be an int > 0')
self.att_embedding_size = att_embedding_size
self.head_num = head_num
self.use_res = use_res
self.seed = seed
super(InteractingLayer, self).__init__(**kwargs)

def build(self, input_shape):
if len(input_shape) != 3:
raise ValueError("Unexpected inputs dimensions %d, expect to be 3 dimensions" % (len(input_shape)))
embedding_size = input_shape[-1].value
self.W_Query = self.add_weight(name='query', shape=[embedding_size, self.att_embedding_size * self.head_num], dtype=tf.float32,
initializer=tf.keras.initializers.glorot_uniform(seed=self.seed))
self.W_key = self.add_weight(name='key', shape=[embedding_size, self.att_embedding_size * self.head_num], dtype=tf.float32,
initializer=tf.keras.initializers.glorot_uniform(seed=self.seed))
self.W_Value = self.add_weight(name='value', shape=[embedding_size, self.att_embedding_size * self.head_num], dtype=tf.float32,
initializer=tf.keras.initializers.glorot_uniform(seed=self.seed))
if self.use_res:
self.W_Res = self.add_weight(name='res', shape=[embedding_size, self.att_embedding_size * self.head_num], dtype=tf.float32,
initializer=tf.keras.initializers.glorot_uniform(seed=self.seed))

super(InteractingLayer, self).build(input_shape) # Be sure to call this somewhere!

def call(self, inputs, **kwargs):
if K.ndim(inputs) != 3:
raise ValueError("Unexpected inputs dimensions %d, expect to be 3 dimensions" % (K.ndim(inputs)))

querys = tf.tensordot(inputs, self.W_Query, axes=(-1, 0)) # None F D*head_num
keys = tf.tensordot(inputs, self.W_key, axes=(-1, 0))
values = tf.tensordot(inputs, self.W_Value, axes=(-1, 0))

querys = tf.stack(tf.split(querys, self.head_num, axis=2)) # head_num None F D
keys = tf.stack(tf.split(keys, self.head_num, axis=2))
values = tf.stack(tf.split(values, self.head_num, axis=2))

inner_product = tf.matmul(querys, keys, transpose_b=True) # head_num None F F
self.normalized_att_scores = tf.nn.softmax(inner_product)

result = tf.matmul(self.normalized_att_scores, values)  # head_num None F D
result = tf.concat(tf.split(result, self.head_num), axis=-1)
result = tf.squeeze(result, axis=0)  # None F D*head_num

if self.use_res:
result += tf.tensordot(inputs, self.W_Res, axes=(-1, 0))
result = tf.nn.relu(result)

return result

def compute_output_shape(self, input_shape):
    return (None, input_shape[1], self.att_embedding_size * self.head_num)

def get_config(self, ):
config = {'att_embedding_size': self.att_embedding_size, 'head_num': self.head_num, 'use_res': self.use_res,
'seed': self.seed}
base_config = super(InteractingLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
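A minimal sanity check for the new layer — not part of this commit, a hedged sketch assuming TensorFlow 1.x with the bundled ``tf.keras`` (the versions DeepCTR supported at the time):

# Hypothetical smoke test: push a dummy (batch, field_size, embedding_size)
# tensor through InteractingLayer and confirm the documented output shape.
import tensorflow as tf
from deepctr.layers import InteractingLayer

inputs = tf.keras.layers.Input(shape=(5, 16))  # field_size=5, embedding_size=16
output = InteractingLayer(att_embedding_size=8, head_num=2, use_res=True)(inputs)
model = tf.keras.models.Model(inputs=inputs, outputs=output)
print(model.output_shape)  # (None, 5, 16) == (None, field_size, att_embedding_size * head_num)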


class LocalActivationUnit(Layer):
"""The LocalActivationUnit used in DIN with which the representation of
user interests varies adaptively given different candidate items.
3 changes: 2 additions & 1 deletion deepctr/models/__init__.py
@@ -8,6 +8,7 @@
from .pnn import PNN
from .wdl import WDL
from .xdeepfm import xDeepFM
from .autoint import AutoInt

__all__ = ["AFM", "DCN", "MLR", "DeepFM",
"MLR", "NFM", "DIN", "FNN", "PNN", "WDL", "xDeepFM"]
"MLR", "NFM", "DIN", "FNN", "PNN", "WDL", "xDeepFM", "AutoInt"]
103 changes: 103 additions & 0 deletions deepctr/models/autoint.py
@@ -0,0 +1,103 @@
# -*- coding:utf-8 -*-
"""
Author:
Weichen Shen,[email protected]
Reference:
[1] Song W, Shi C, Xiao Z, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks[J]. arXiv preprint arXiv:1810.11921, 2018.(https://arxiv.org/abs/1810.11921)
"""

from tensorflow.python.keras.layers import Dense, Embedding, Concatenate
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.initializers import RandomNormal
from tensorflow.python.keras.regularizers import l2
import tensorflow as tf

from ..utils import get_input
from ..layers import PredictionLayer, MLP, InteractingLayer


def AutoInt(feature_dim_dict, embedding_size=8, att_layer_num=3, att_embedding_size=8, att_head_num=2, att_res=True, hidden_size=(256, 256), activation='relu',
l2_reg_deep=0, l2_reg_embedding=1e-5, use_bn=False, keep_prob=1.0, init_std=0.0001, seed=1024,
final_activation='sigmoid',):
"""Instantiates the AutoInt Network architecture.
:param feature_dim_dict: dict, to indicate sparse field and dense field like {'sparse':{'field_1':4,'field_2':3,'field_3':2},'dense':['field_4','field_5']}
:param embedding_size: positive integer, sparse feature embedding_size
:param att_layer_num: int. The number of InteractingLayers to be used.
:param att_embedding_size: int. The embedding size in the multi-head self-attention network.
:param att_head_num: int. The number of heads in the multi-head self-attention network.
:param att_res: bool. Whether or not to use standard residual connections before output.
:param hidden_size: list of positive integers or empty list, the layer number and units in each layer of the deep net
:param activation: activation function to use in the deep net
:param l2_reg_deep: float. L2 regularizer strength applied to the deep net
:param l2_reg_embedding: float. L2 regularizer strength applied to the embedding vectors
:param use_bn: bool. Whether to use BatchNormalization before activation in the deep net
:param keep_prob: float in (0,1]. keep_prob used in the deep net
:param init_std: float, standard deviation used to initialize the embedding vectors
:param seed: integer, to use as random seed.
:param final_activation: output activation, usually ``'sigmoid'`` or ``'linear'``
:return: A Keras model instance.
"""

if len(hidden_size) <= 0 and att_layer_num <= 0:
    raise ValueError("Either hidden_size or att_layer_num must be > 0")
if not isinstance(feature_dim_dict, dict) or "sparse" not in feature_dim_dict or "dense" not in feature_dim_dict:
    raise ValueError(
        "feature_dim_dict must be a dict like {'sparse':{'field_1':4,'field_2':3,'field_3':2},'dense':['field_5',]}")

sparse_input, dense_input = get_input(feature_dim_dict, None,)
sparse_embedding = get_embeddings(
feature_dim_dict, embedding_size, init_std, seed, l2_reg_embedding)
embed_list = [sparse_embedding[i](sparse_input[i])
for i in range(len(sparse_input))]

att_input = Concatenate(axis=1)(embed_list) if len(
embed_list) > 1 else embed_list[0]

for i in range(att_layer_num):
att_input = InteractingLayer(
att_embedding_size, att_head_num, att_res)(att_input)
att_output = tf.keras.layers.Flatten()(att_input)

deep_input = tf.keras.layers.Flatten()(Concatenate()(embed_list)
if len(embed_list) > 1 else embed_list[0])
if len(dense_input) > 0:
if len(dense_input) == 1:
continuous_list = dense_input[0]
else:
continuous_list = Concatenate()(dense_input)

deep_input = Concatenate()([deep_input, continuous_list])

if len(hidden_size) > 0 and att_layer_num > 0: # Deep & Interacting Layer
deep_out = MLP(hidden_size, activation, l2_reg_deep, keep_prob,
use_bn, seed)(deep_input)
stack_out = Concatenate()([att_output, deep_out])
final_logit = Dense(1, use_bias=False, activation=None)(stack_out)
elif len(hidden_size) > 0: # Only Deep
deep_out = MLP(hidden_size, activation, l2_reg_deep, keep_prob,
use_bn, seed)(deep_input)
final_logit = Dense(1, use_bias=False, activation=None)(deep_out)
elif att_layer_num > 0: # Only Interacting Layer
final_logit = Dense(1, use_bias=False, activation=None)(att_output)
else: # Error
raise NotImplementedError

output = PredictionLayer(final_activation)(final_logit)
model = Model(inputs=sparse_input + dense_input, outputs=output)

return model


def get_embeddings(feature_dim_dict, embedding_size, init_std, seed, l2_rev_V):
sparse_embedding = [Embedding(feature_dim_dict["sparse"][feat], embedding_size,
embeddings_initializer=RandomNormal(
mean=0.0, stddev=init_std, seed=seed),
embeddings_regularizer=l2(l2_rev_V),
name='sparse_emb_' + str(i) + '-' + feat) for i, feat in
enumerate(feature_dim_dict["sparse"])]

return sparse_embedding
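A hedged usage sketch for the new model (the field names and sizes are illustrative, not from this diff):

# Build AutoInt on a toy feature_dim_dict as described in the docstring above,
# then compile it like any other Keras model.
from deepctr.models import AutoInt

feature_dim_dict = {"sparse": {"field_1": 4, "field_2": 3, "field_3": 2}, "dense": []}
model = AutoInt(feature_dim_dict, embedding_size=8,
                att_layer_num=3, att_embedding_size=8, att_head_num=2)
model.compile("adam", "binary_crossentropy", metrics=["binary_crossentropy"])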
2 changes: 1 addition & 1 deletion deepctr/models/xdeepfm.py
@@ -20,7 +20,7 @@ def xDeepFM(feature_dim_dict, embedding_size=8, hidden_size=(256, 256), cin_laye
:param embedding_size: positive integer,sparse feature embedding_size
:param hidden_size: list,list of positive integer or empty list, the layer number and units in each layer of deep net
:param cin_layer_size: list,list of positive integer or empty list, the feature maps in each hidden layer of Compressed Interaction Network
- :param cin_split_half: bool.if set to False, half of the feature maps in each hidden will connect to output unit
+ :param cin_split_half: bool. If set to True, half of the feature maps in each hidden layer will connect to the output unit
:param cin_activation: activation function used on feature maps
:param l2_reg_linear: float. L2 regularizer strength applied to linear part
:param l2_reg_embedding: L2 regularizer strength applied to embedding vector
3 changes: 2 additions & 1 deletion deepctr/utils.py
@@ -27,7 +27,8 @@
'Dice': Dice,
'SequencePoolingLayer': SequencePoolingLayer,
'AttentionSequencePoolingLayer': AttentionSequencePoolingLayer,
- 'CIN': CIN, }
+ 'CIN': CIN,
+ 'InteractingLayer': InteractingLayer}
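Because ``InteractingLayer`` is now registered in ``custom_objects``, a saved AutoInt model can be reloaded the same way the FAQ shows for DeepFM — a sketch under the assumption that a model was saved earlier (the file name is illustrative):

# Hypothetical reload: 'AutoInt.h5' assumes a prior model.save('AutoInt.h5').
from tensorflow.python.keras.models import load_model
from deepctr.utils import custom_objects

model = load_model('AutoInt.h5', custom_objects)  # custom_objects now includes InteractingLayer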


def get_input(feature_dim_dict, bias_feature_dim_dict=None):
Binary file added docs/pics/AutoInt.png
Binary file added docs/pics/InteractingLayer.png
30 changes: 24 additions & 6 deletions docs/source/FAQ.rst
@@ -1,7 +1,7 @@
FAQ
==========
- 1. How to save or load weights/models?

+ 1. Save or load weights/models
+ ----------------------------------------
To save/load weights, you can write code just like with any other Keras model.

.. code-block:: python
@@ -22,8 +22,26 @@ To save/load models,just a little different.
from deepctr.utils import custom_objects
model = load_model('DeepFM.h5', custom_objects)  # load_model, just add a parameter
- 2. How can I get the attentional weights of feature interactions in AFM?
+ 2. Set learning rate and use early stopping
+ ---------------------------------------------------
You can use any model in DeepCTR like a Keras model object.
Here is an example of how to set the learning rate and use early stopping:

.. code-block:: python

    import deepctr
    from tensorflow.python.keras.optimizers import Adam, Adagrad
    from tensorflow.python.keras.callbacks import EarlyStopping

    model = deepctr.models.DeepFM({"sparse": sparse_feature_dict, "dense": dense_feature_list})
    model.compile(Adagrad(0.0808), 'binary_crossentropy', metrics=['binary_crossentropy'])
    es = EarlyStopping(monitor='val_binary_crossentropy')
    history = model.fit(model_input, data[target].values, batch_size=256, epochs=10, verbose=2, validation_split=0.2, callbacks=[es])
+ 3. Get the attentional weights of feature interactions in AFM
+ --------------------------------------------------------------------------
First, make sure that you have installed the latest version of deepctr.

Then, use the following code; ``attentional_weights[:,i,0]`` is ``feature_interactions[i]``'s attentional weight over all samples.
@@ -46,7 +64,7 @@
- 3. Does the models support multi-value input?

+ 4. Does the models support multi-value input?
---------------------------------------------------
Now only the `DIN <Features.html#din-deep-interest-network>`_ model supports multi-value input; you can use layers in `sequence <deepctr.sequence.html>`_ to build your own models!
- And I will add the feature soon~
+ And it will be supported in a future release.
22 changes: 21 additions & 1 deletion docs/source/Features.rst
@@ -163,7 +163,8 @@ DIN use a local activation unit to get the activation score between candidate it
User's interests are represented by a weighted sum of user behaviors.
The user's interest vector and the other embedding vectors are concatenated and fed into an MLP to get the prediction.

- **DIN api** `link <./deepctr.models.din.html>`_
+ **DIN api** `link <./deepctr.models.din.html>`_ **DIN demo** `link <https://github.com/shenweichen/DeepCTR/tree/master/examples/run_din.py>`_

.. image:: ../pics/DIN.png
:align: center
@@ -191,6 +192,25 @@ Finally,apply sum pooling on all the feature maps :math:`H_k` to get one vector.

`Lian J, Zhou X, Zhang F, et al. xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems[J]. arXiv preprint arXiv:1803.05170, 2018. <https://arxiv.org/pdf/1803.05170.pdf>`_

AutoInt(Automatic Feature Interaction)
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

AutoInt uses an interacting layer to model the interactions between different features.
Within each interacting layer, each feature is allowed to interact with all the other features, and the layer can automatically identify relevant features to form meaningful higher-order features via the multi-head attention mechanism.
By stacking multiple interacting layers, AutoInt is able to model different orders of feature interactions.
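
A condensed view of what one interacting layer computes for the embedding :math:`e_i` of field :math:`i` with :math:`H` heads (notation ours, matching the ``InteractingLayer`` code in this commit):

.. math::
   \alpha_{i,j}^{(h)} = \mathrm{softmax}_{j}\left(\left\langle W_{Q}^{(h)} e_{i},\; W_{K}^{(h)} e_{j}\right\rangle\right),\qquad
   \tilde{e}_{i} = \mathrm{ReLU}\left(\Big\Vert_{h=1}^{H}\sum_{j}\alpha_{i,j}^{(h)} W_{V}^{(h)} e_{j} \;+\; W_{Res}\, e_{i}\right)

where :math:`\Vert` denotes concatenation over heads.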

**AutoInt api** `link <./deepctr.models.autoint.html>`_

.. image:: ../pics/InteractingLayer.png
:align: center
:scale: 70 %

.. image:: ../pics/AutoInt.png
:align: center
:scale: 70 %

`Song W, Shi C, Xiao Z, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks[J]. arXiv preprint arXiv:1810.11921, 2018. <https://arxiv.org/abs/1810.11921>`_

Layers
--------

1 change: 1 addition & 0 deletions docs/source/History.md
@@ -1,4 +1,5 @@
# History
- 12/27/2018 : [v0.2.1](https://github.com/shenweichen/DeepCTR/releases/tag/v0.2.1) released. Add [AutoInt](./Features.html#autoint-automatic-feature-interaction) Model.
- 12/22/2018 : [v0.2.0](https://github.com/shenweichen/DeepCTR/releases/tag/v0.2.0) released. Add [xDeepFM](./Features.html#xdeepfm) and automatic check for new version.
- 12/19/2018 : [v0.1.6](https://github.com/shenweichen/DeepCTR/releases/tag/v0.1.6) released. Now DeepCTR is compatible with tensorflow from `1.4-1.12` except for `1.7` and `1.8`.
- 11/29/2018 : [v0.1.4](https://github.com/shenweichen/DeepCTR/releases/tag/v0.1.4) released. Add [FAQ](./FAQ.html) in docs.
3 changes: 2 additions & 1 deletion docs/source/Models-API.rst
@@ -12,4 +12,5 @@ DeepCTR Models API
AFM<deepctr.models.afm>
DCN<deepctr.models.dcn>
DIN<deepctr.models.din>
- xDeepFM<deepctr.models.xdeepfm>
+ xDeepFM<deepctr.models.xdeepfm>
+ AutoInt<deepctr.models.autoint>
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -26,7 +26,7 @@
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
- release = '0.2.0'
+ release = '0.2.1'


# -- General configuration ---------------------------------------------------
7 changes: 7 additions & 0 deletions docs/source/deepctr.models.autoint.rst
@@ -0,0 +1,7 @@
deepctr.models.autoint module
=============================

.. automodule:: deepctr.models.autoint
:members:
:no-undoc-members:
:no-show-inheritance:
1 change: 1 addition & 0 deletions docs/source/deepctr.models.rst
@@ -7,6 +7,7 @@ Submodules
.. toctree::

deepctr.models.afm
deepctr.models.autoint
deepctr.models.dcn
deepctr.models.deepfm
deepctr.models.din
3 changes: 2 additions & 1 deletion docs/source/index.rst
@@ -15,7 +15,7 @@ Welcome to DeepCTR's documentation!
.. _Stars: https://github.com/shenweichen/DeepCTR

.. |Forks| image:: https://img.shields.io/github/forks/shenweichen/deepctr.svg
- .. _Forks: https://github.com/shenweichen/DeepCTR
+ .. _Forks: https://github.com/shenweichen/DeepCTR/fork

.. |PyPi| image:: https://img.shields.io/pypi/v/deepctr.svg
.. _PyPi: https://pypi.org/project/deepctr/
@@ -35,6 +35,7 @@ You can read the latest code at https://github.com/shenweichen/DeepCTR

News
-----
12/27/2018 : Add `AutoInt <./Features.html#autoint-automatic-feature-interaction>`_. `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.2.1>`_

12/22/2018 : Add `xDeepFM <./Features.html#xdeepfm>`_ and automatic check for new version. `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.2.0>`_
