diff --git a/README.md b/README.md index 9ff135de..ba8563a6 100644 --- a/README.md +++ b/README.md @@ -33,3 +33,4 @@ Through `pip install deepctr` get the package and [**Get Started!**](https://d |Neural Factorization Machine|[SIGIR 2017][Neural Factorization Machines for Sparse Predictive Analytics](https://arxiv.org/pdf/1708.05027.pdf)| |Deep Interest Network|[KDD 2018][Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1706.06978.pdf)| |xDeepFM|[KDD 2018][xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems](https://arxiv.org/pdf/1803.05170.pdf)| +| AutoInt|[arxiv 2018][AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/abs/1810.11921)| diff --git a/deepctr/__init__.py b/deepctr/__init__.py index 4aadf50e..c5e03447 100644 --- a/deepctr/__init__.py +++ b/deepctr/__init__.py @@ -3,5 +3,5 @@ from .import sequence from . import models from .utils import check_version -__version__ = '0.2.0post1' +__version__ = '0.2.1' check_version(__version__) diff --git a/deepctr/layers.py b/deepctr/layers.py index 53b705e9..c163ba90 100644 --- a/deepctr/layers.py +++ b/deepctr/layers.py @@ -286,7 +286,8 @@ def call(self, inputs, **kwargs): def compute_output_shape(self, input_shape): if self.split_half: - featuremap_num = sum(self.layer_size[:-1]) // 2 + self.layer_size[-1] + featuremap_num = sum( + self.layer_size[:-1]) // 2 + self.layer_size[-1] else: featuremap_num = sum(self.layer_size) return (None, featuremap_num) @@ -480,7 +481,6 @@ def call(self, inputs, **kwargs): col.append(j) p = tf.concat([embed_list[idx] for idx in row], axis=1) # batch num_pairs k - # Reshape([num_pairs, self.embedding_size]) q = tf.concat([embed_list[idx] for idx in col], axis=1) inner_product = p * q if self.reduce_sum: @@ -504,6 +504,87 @@ def get_config(self,): return dict(list(base_config.items()) + list(config.items())) + +class InteractingLayer(Layer): + """A Layer used in AutoInt that model the correlations between different feature fields by multi-head self-attention mechanism. + + Input shape + - A 3D tensor with shape: ``(batch_size,field_size,embedding_size)``. + + Output shape + - 3D tensor with shape:``(batch_size,field_size,att_embedding_size * head_num)``. + + + Arguments + - **att_embedding_size**: int.The embedding size in multi-head self-attention network. + - **head_num**: int.The head number in multi-head self-attention network. + - **use_res**: bool.Whether or not use standard residual connections before output. + - **seed**: A Python integer to use as random seed. + + References + - [Song W, Shi C, Xiao Z, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks[J]. arXiv preprint arXiv:1810.11921, 2018.](https://arxiv.org/abs/1810.11921) + """ + def __init__(self, att_embedding_size=8, head_num=2, use_res=True, seed=1024, **kwargs): + if head_num <= 0: + raise ValueError('head_num must be a int > 0') + self.att_embedding_size = att_embedding_size + self.head_num = head_num + self.use_res = use_res + self.seed = seed + super(InteractingLayer, self).__init__(**kwargs) + + def build(self, input_shape): + if len(input_shape) != 3: + raise ValueError("Unexpected inputs dimensions %d, expect to be 3 dimensions" % (len(input_shape))) + embedding_size = input_shape[-1].value + self.W_Query = self.add_weight(name='query', shape=[embedding_size, self.att_embedding_size * self.head_num], dtype=tf.float32, + initializer=tf.keras.initializers.glorot_uniform(seed=self.seed)) + self.W_key = self.add_weight(name='key', shape=[embedding_size, self.att_embedding_size * self.head_num], dtype=tf.float32, + initializer=tf.keras.initializers.glorot_uniform(seed=self.seed)) + self.W_Value = self.add_weight(name='value', shape=[embedding_size, self.att_embedding_size * self.head_num], dtype=tf.float32, + initializer=tf.keras.initializers.glorot_uniform(seed=self.seed)) + if self.use_res: + self.W_Res = self.add_weight(name='res', shape=[embedding_size, self.att_embedding_size * self.head_num], dtype=tf.float32, + initializer=tf.keras.initializers.glorot_uniform(seed=self.seed)) + + super(InteractingLayer, self).build(input_shape) # Be sure to call this somewhere! + + def call(self, inputs, **kwargs): + if K.ndim(inputs) != 3: + raise ValueError("Unexpected inputs dimensions %d, expect to be 3 dimensions" % (K.ndim(inputs))) + + querys = tf.tensordot(inputs, self.W_Query, axes=(-1, 0)) # None F D*head_num + keys = tf.tensordot(inputs, self.W_key, axes=(-1, 0)) + values = tf.tensordot(inputs, self.W_Value, axes=(-1, 0)) + + querys = tf.stack(tf.split(querys, self.head_num, axis=2)) # head_num None F D + keys = tf.stack(tf.split(keys, self.head_num, axis=2)) + values = tf.stack(tf.split(values, self.head_num, axis=2)) + + inner_product = tf.matmul(querys, keys, transpose_b=True) # head_num None F F + self.normalized_att_scores = tf.nn.softmax(inner_product) + + result = tf.matmul(self.normalized_att_scores, values)#head_num None F D + result = tf.concat(tf.split(result, self.head_num, ), axis=-1) + result = tf.squeeze(result, axis=0)#None F D*head_num + + if self.use_res: + result += tf.tensordot(inputs, self.W_Res, axes=(-1, 0)) + result = tf.nn.relu(result) + + return result + + def compute_output_shape(self, input_shape): + + return (None, input_shape[1], self.att_embedding_size * self.head_num) + + def get_config(self, ): + config = {'att_embedding_size': self.att_embedding_size, 'head_num': self.head_num, 'use_res': self.use_res, + 'seed': self.seed} + base_config = super(InteractingLayer, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + class LocalActivationUnit(Layer): """The LocalActivationUnit used in DIN with which the representation of user interests varies adaptively given different candidate items. diff --git a/deepctr/models/__init__.py b/deepctr/models/__init__.py index 87461abf..e933ead3 100644 --- a/deepctr/models/__init__.py +++ b/deepctr/models/__init__.py @@ -8,6 +8,7 @@ from .pnn import PNN from .wdl import WDL from .xdeepfm import xDeepFM +from .autoint import AutoInt __all__ = ["AFM", "DCN", "MLR", "DeepFM", - "MLR", "NFM", "DIN", "FNN", "PNN", "WDL", "xDeepFM"] + "MLR", "NFM", "DIN", "FNN", "PNN", "WDL", "xDeepFM", "AutoInt"] diff --git a/deepctr/models/autoint.py b/deepctr/models/autoint.py new file mode 100644 index 00000000..c98b4fa9 --- /dev/null +++ b/deepctr/models/autoint.py @@ -0,0 +1,103 @@ +# -*- coding:utf-8 -*- +""" + +Author: + Weichen Shen,wcshen1994@163.com + +Reference: + [1] Song W, Shi C, Xiao Z, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks[J]. arXiv preprint arXiv:1810.11921, 2018.(https://arxiv.org/abs/1810.11921) + +""" + +from tensorflow.python.keras.layers import Dense, Embedding, Concatenate +from tensorflow.python.keras.models import Model +from tensorflow.python.keras.initializers import RandomNormal +from tensorflow.python.keras.regularizers import l2 +import tensorflow as tf + +from ..utils import get_input +from ..layers import PredictionLayer, MLP, InteractingLayer + + +def AutoInt(feature_dim_dict, embedding_size=8, att_layer_num=3, att_embedding_size=8, att_head_num=2, att_res=True, hidden_size=(256, 256), activation='relu', + l2_reg_deep=0, l2_reg_embedding=1e-5, use_bn=False, keep_prob=1.0, init_std=0.0001, seed=1024, + final_activation='sigmoid',): + """Instantiates the AutoInt Network architecture. + + :param feature_dim_dict: dict,to indicate sparse field and dense field like {'sparse':{'field_1':4,'field_2':3,'field_3':2},'dense':['field_4','field_5']} + :param embedding_size: positive integer,sparse feature embedding_size + :param att_layer_num: int.The InteractingLayer number to be used. + :param att_embedding_size: int.The embedding size in multi-head self-attention network. + :param att_head_num: int.The head number in multi-head self-attention network. + :param att_res: bool.Whether or not use standard residual connections before output. + :param hidden_size: list,list of positive integer or empty list, the layer number and units in each layer of deep net + :param activation: Activation function to use in deep net + :param l2_reg_deep: float. L2 regularizer strength applied to deep net + :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector + :param use_bn: bool. Whether use BatchNormalization before activation or not.in deep net + :param keep_prob: float in (0,1]. keep_prob used in deep net + :param init_std: float,to use as the initialize std of embedding vector + :param seed: integer ,to use as random seed. + :param final_activation: output activation,usually ``'sigmoid'`` or ``'linear'`` + :return: A Keras model instance. + """ + + if len(hidden_size) <= 0 and att_layer_num <= 0: + raise ValueError("Either hidden_layer or att_layer_num must > 0") + if not isinstance(feature_dim_dict, dict) or "sparse" not in feature_dim_dict or "dense" not in feature_dim_dict: + raise ValueError( + "feature_dim must be a dict like {'sparse':{'field_1':4,'field_2':3,'field_3':2},'dense':['field_5',]}") + + sparse_input, dense_input = get_input(feature_dim_dict, None,) + sparse_embedding = get_embeddings( + feature_dim_dict, embedding_size, init_std, seed, l2_reg_embedding) + embed_list = [sparse_embedding[i](sparse_input[i]) + for i in range(len(sparse_input))] + + att_input = Concatenate(axis=1)(embed_list) if len( + embed_list) > 1 else embed_list[0] + + for i in range(att_layer_num): + att_input = InteractingLayer( + att_embedding_size, att_head_num, att_res)(att_input) + att_output = tf.keras.layers.Flatten()(att_input) + + deep_input = tf.keras.layers.Flatten()(Concatenate()(embed_list) + if len(embed_list) > 1 else embed_list[0]) + if len(dense_input) > 0: + if len(dense_input) == 1: + continuous_list = dense_input[0] + else: + continuous_list = Concatenate()(dense_input) + + deep_input = Concatenate()([deep_input, continuous_list]) + + if len(hidden_size) > 0 and att_layer_num > 0: # Deep & Interacting Layer + deep_out = MLP(hidden_size, activation, l2_reg_deep, keep_prob, + use_bn, seed)(deep_input) + stack_out = Concatenate()([att_output, deep_out]) + final_logit = Dense(1, use_bias=False, activation=None)(stack_out) + elif len(hidden_size) > 0: # Only Deep + deep_out = MLP(hidden_size, activation, l2_reg_deep, keep_prob, + use_bn, seed)(deep_input) + final_logit = Dense(1, use_bias=False, activation=None)(deep_out) + elif att_layer_num > 0: # Only Interacting Layer + final_logit = Dense(1, use_bias=False, activation=None)(att_output) + else: # Error + raise NotImplementedError + + output = PredictionLayer(final_activation)(final_logit) + model = Model(inputs=sparse_input + dense_input, outputs=output) + + return model + + +def get_embeddings(feature_dim_dict, embedding_size, init_std, seed, l2_rev_V): + sparse_embedding = [Embedding(feature_dim_dict["sparse"][feat], embedding_size, + embeddings_initializer=RandomNormal( + mean=0.0, stddev=init_std, seed=seed), + embeddings_regularizer=l2(l2_rev_V), + name='sparse_emb_' + str(i) + '-' + feat) for i, feat in + enumerate(feature_dim_dict["sparse"])] + + return sparse_embedding diff --git a/deepctr/models/xdeepfm.py b/deepctr/models/xdeepfm.py index 0eb87da0..03798e98 100644 --- a/deepctr/models/xdeepfm.py +++ b/deepctr/models/xdeepfm.py @@ -20,7 +20,7 @@ def xDeepFM(feature_dim_dict, embedding_size=8, hidden_size=(256, 256), cin_laye :param embedding_size: positive integer,sparse feature embedding_size :param hidden_size: list,list of positive integer or empty list, the layer number and units in each layer of deep net :param cin_layer_size: list,list of positive integer or empty list, the feature maps in each hidden layer of Compressed Interaction Network - :param cin_split_half: bool.if set to False, half of the feature maps in each hidden will connect to output unit + :param cin_split_half: bool.if set to True, half of the feature maps in each hidden will connect to output unit :param cin_activation: activation function used on feature maps :param l2_reg_linear: float. L2 regularizer strength applied to linear part :param l2_reg_embedding: L2 regularizer strength applied to embedding vector diff --git a/deepctr/utils.py b/deepctr/utils.py index 8e6036d0..329b1784 100644 --- a/deepctr/utils.py +++ b/deepctr/utils.py @@ -27,7 +27,8 @@ 'Dice': Dice, 'SequencePoolingLayer': SequencePoolingLayer, 'AttentionSequencePoolingLayer': AttentionSequencePoolingLayer, - 'CIN': CIN, } + 'CIN': CIN, + 'InteractingLayer': InteractingLayer} def get_input(feature_dim_dict, bias_feature_dim_dict=None): diff --git a/docs/pics/AutoInt.png b/docs/pics/AutoInt.png new file mode 100644 index 00000000..f2cd23c4 Binary files /dev/null and b/docs/pics/AutoInt.png differ diff --git a/docs/pics/InteractingLayer.png b/docs/pics/InteractingLayer.png new file mode 100644 index 00000000..bfb4a458 Binary files /dev/null and b/docs/pics/InteractingLayer.png differ diff --git a/docs/source/FAQ.rst b/docs/source/FAQ.rst index 01100d19..4d00eb54 100644 --- a/docs/source/FAQ.rst +++ b/docs/source/FAQ.rst @@ -1,7 +1,7 @@ FAQ ========== -1. How to save or load weights/models? - +1. Save or load weights/models +---------------------------------------- To save/load weights,you can write codes just like any other keras models. .. code-block:: python @@ -22,8 +22,26 @@ To save/load models,just a little different. from deepctr.utils import custom_objects model = load_model('DeepFM.h5',custom_objects)# load_model,just add a parameter -2. How can I get the attentional weights of feature interactions in AFM? +2. Set learning rate and use earlystopping +--------------------------------------------------- +You can use any models in DeepCTR like a keras model object. +Here is a example of how to set learning rate and earlystopping: + +.. code-block:: python + + import deepctr + from tensorflow.python.keras.optimizers import Adam,Adagrad + from tensorflow.python.keras.callbacks import EarlyStopping + model = deepctr.models.DeepFM({"sparse": sparse_feature_dict, "dense": dense_feature_list}) + model.compile(Adagrad('0.0808'),'binary_crossentropy',metrics=['binary_crossentropy']) + + es = EarlyStopping(monitor='val_binary_crossentropy') + history = model.fit(model_input, data[target].values,batch_size=256, epochs=10, verbose=2, validation_split=0.2,callbacks=[es] ) + + +3. Get the attentional weights of feature interactions in AFM +-------------------------------------------------------------------------- First,make sure that you have install the latest version of deepctr. Then,use the following code,the ``attentional_weights[:,i,0]`` is the ``feature_interactions[i]``'s attentional weight of all samples. @@ -46,7 +64,7 @@ Then,use the following code,the ``attentional_weights[:,i,0]`` is the ``feature_ -3. Does the models support multi-value input? - +4. Does the models support multi-value input? +--------------------------------------------------- Now only the `DIN `_ model support multi-value input,you can use layers in `sequence `_ to build your own models! -And I will add the feature soon~ \ No newline at end of file +And it will be supported in a future release \ No newline at end of file diff --git a/docs/source/Features.rst b/docs/source/Features.rst index 3681164a..09452e79 100644 --- a/docs/source/Features.rst +++ b/docs/source/Features.rst @@ -163,7 +163,8 @@ DIN use a local activation unit to get the activation score between candidate it User's interest are represented by weighted sum of user behaviors. user's interest vector and other embedding vectors are concatenated and fed into a MLP to get the prediction. -**DIN api** `link <./deepctr.models.din.html>`_ +**DIN api** `link <./deepctr.models.din.html>`_ **DIN demo** `link `_ .. image:: ../pics/DIN.png :align: center @@ -191,6 +192,25 @@ Finally,apply sum pooling on all the feature maps :math:`H_k` to get one vector. `Lian J, Zhou X, Zhang F, et al. xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems[J]. arXiv preprint arXiv:1803.05170, 2018. `_ +AutoInt(Automatic Feature Interaction) +>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +AutoInt use a interacting layer to model the interactions between different features. +Within each interacting layer, each feature is allowed to interact with all the other features and is able to automatically identify relevant features to form meaningful higher-order features via the multi-head attention mechanism. +By stacking multiple interacting layers,AutoInt is able to model different orders of feature interactions. + +**AutoInt api** `link <./deepctr.models.autoint.html>`_ + +.. image:: ../pics/InteractingLayer.png + :align: center + :scale: 70 % + +.. image:: ../pics/AutoInt.png + :align: center + :scale: 70 % + +`Song W, Shi C, Xiao Z, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks[J]. arXiv preprint arXiv:1810.11921, 2018. `_ + Layers -------- diff --git a/docs/source/History.md b/docs/source/History.md index d476b77c..152b0d54 100644 --- a/docs/source/History.md +++ b/docs/source/History.md @@ -1,4 +1,5 @@ # History +- 12/27/2018 : [v0.2.1](https://github.com/shenweichen/DeepCTR/releases/tag/v0.2.1) released.Add [AutoInt](./Features.html#autoint-automatic-feature-interactiont) Model. - 12/22/2018 : [v0.2.0](https://github.com/shenweichen/DeepCTR/releases/tag/v0.2.0) released.Add [xDeepFM](./Features.html#xdeepfm) and automatic check for new version. - 12/19/2018 : [v0.1.6](https://github.com/shenweichen/DeepCTR/releases/tag/v0.1.6) released.Now DeepCTR is compatible with tensorflow from `1.4-1.12` except for `1.7` and `1.8`. - 29/11/2018 : [v0.1.4](https://github.com/shenweichen/DeepCTR/releases/tag/v0.1.4) released.Add [FAQ](./FAQ.html) in docs diff --git a/docs/source/Models-API.rst b/docs/source/Models-API.rst index cd5b6c20..3be6fd7e 100644 --- a/docs/source/Models-API.rst +++ b/docs/source/Models-API.rst @@ -12,4 +12,5 @@ DeepCTR Models API AFM DCN DIN - xDeepFM \ No newline at end of file + xDeepFM + AutoInt \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 0a26ecf5..844ce894 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,7 +26,7 @@ # The short X.Y version version = '' # The full version, including alpha/beta/rc tags -release = '0.2.0' +release = '0.2.1' # -- General configuration --------------------------------------------------- diff --git a/docs/source/deepctr.models.autoint.rst b/docs/source/deepctr.models.autoint.rst new file mode 100644 index 00000000..b2452f34 --- /dev/null +++ b/docs/source/deepctr.models.autoint.rst @@ -0,0 +1,7 @@ +deepctr.models.autoint module +============================= + +.. automodule:: deepctr.models.autoint + :members: + :no-undoc-members: + :no-show-inheritance: diff --git a/docs/source/deepctr.models.rst b/docs/source/deepctr.models.rst index 55dd032b..4621dedf 100644 --- a/docs/source/deepctr.models.rst +++ b/docs/source/deepctr.models.rst @@ -7,6 +7,7 @@ Submodules .. toctree:: deepctr.models.afm + deepctr.models.autoint deepctr.models.dcn deepctr.models.deepfm deepctr.models.din diff --git a/docs/source/index.rst b/docs/source/index.rst index bc7572e7..a8d51c17 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -15,7 +15,7 @@ Welcome to DeepCTR's documentation! .. _Stars: https://github.com/shenweichen/DeepCTR .. |Forks| image:: https://img.shields.io/github/forks/shenweichen/deepctr.svg -.. _Forks: https://github.com/shenweichen/DeepCTR +.. _Forks: https://github.com/shenweichen/DeepCTR/fork .. |PyPi| image:: https://img.shields.io/pypi/v/deepctr.svg .. _PyPi: https://pypi.org/project/deepctr/ @@ -35,6 +35,7 @@ You can read the latest code at https://github.com/shenweichen/DeepCTR News ----- +12/27/2018 : Add `AutoInt <./Features.html#autoint-automatic-feature-interaction>`_ . `Changelog `_ 12/22/2018 : Add `xDeepFM <./Features.html#xdeepfm>`_ and automatic check for new version. `Changelog `_ diff --git a/examples/run_din.py b/examples/run_din.py new file mode 100644 index 00000000..3a4f2968 --- /dev/null +++ b/examples/run_din.py @@ -0,0 +1,42 @@ +import numpy as np +from deepctr.models import DIN + + +def get_xy_fd(): + + feature_dim_dict = {"sparse": {'user_age': 4, 'user_gender': 2, + 'item_id': 4, 'item_gender': 2}, "dense": []} # raw feature:single value feature + + # history behavior feature:multi-value value feature + behavior_feature_list = ["item_id", "item_gender"] + # single value feature input + user_age = np.array([1, 2, 3]) + user_gender = np.array([0, 1, 0]) + item_id = np.array([0, 1, 2]) + item_gender = np.array([0, 1, 0]) + + # multi-value feature input + hist_item_id = np.array([[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 0]]) + hist_item_gender = np.array([[0, 1, 0, 1], [0, 1, 1, 1], [0, 0, 1, 0]]) + # valid length of behavior sequence of every sample + hist_length = np.array([4, 4, 3]) + + feature_dict = {'user_age': user_age, 'user_gender': user_gender, 'item_id': item_id, 'item_gender': item_gender, + 'hist_item_id': hist_item_id, 'hist_item_gender': hist_item_gender, } + + x = [feature_dict[feat] for feat in feature_dim_dict["sparse"]] + \ + [feature_dict['hist_'+feat] + for feat in behavior_feature_list] + [hist_length] + # Notice the concatenation order: single feature + multi-value feature + length + # Since the length of the historical sequences of different features in DIN are the same(they are all extended from item_id),only one length vector is enough. + y = [1, 0, 1] + + return x, y, feature_dim_dict, behavior_feature_list + + +if __name__ == "__main__": + x, y, feature_dim_dict, behavior_feature_list = get_xy_fd() + model = DIN(feature_dim_dict, behavior_feature_list, hist_len_max=4,) + model.compile('adam', 'binary_crossentropy', + metrics=['binary_crossentropy']) + history = model.fit(x, y, verbose=1, validation_split=0.5) diff --git a/setup.py b/setup.py index d55d401b..012d6adc 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setuptools.setup( name="deepctr", - version="0.2.0.post1", + version="0.2.1", author="Weichen Shen", author_email="wcshen1994@163.com", description="Easy-to-use,Modular and Extendible package of deep learning based CTR(Click Through Rate) prediction models with tensorflow.", diff --git a/tests/layers_test.py b/tests/layers_test.py index 80a8ebe7..f1f0e41c 100644 --- a/tests/layers_test.py +++ b/tests/layers_test.py @@ -152,3 +152,17 @@ def test_test_CIN_invalid(layer_size): with CustomObjectScope({'CIN': layers.CIN}): layer_test(layers.CIN, kwargs={"layer_size": layer_size}, input_shape=( BATCH_SIZE, FIELD_SIZE, EMBEDDING_SIZE)) + + +@pytest.mark.parametrize( + 'head_num,use_res', + [(head_num, use_res,) + for head_num in [1, 2] + for use_res in [True, False] + ] +) +def test_InteractingLayer(head_num, use_res,): + with CustomObjectScope({'InteractingLayer': layers.InteractingLayer}): + layer_test(layers.InteractingLayer, kwargs={"head_num": head_num, "use_res": + use_res, }, input_shape=( + BATCH_SIZE, FIELD_SIZE, EMBEDDING_SIZE)) diff --git a/tests/models/AutoInt_test.py b/tests/models/AutoInt_test.py new file mode 100644 index 00000000..260a9ddc --- /dev/null +++ b/tests/models/AutoInt_test.py @@ -0,0 +1,37 @@ +import numpy as np +import pytest +from deepctr.models import AutoInt +from ..utils import check_model + + +@pytest.mark.parametrize( + 'att_layer_num,hidden_size,sparse_feature_num', + [(0, (4,), 2), (1, (), 1), (1, (4,), 1), (2, (4, 4,), 2)] +) +def test_AutoInt(att_layer_num, hidden_size, sparse_feature_num): + model_name = "AutoInt" + sample_size = 64 + feature_dim_dict = {"sparse": {}, 'dense': []} + for name, num in zip(["sparse", "dense"], [sparse_feature_num, sparse_feature_num]): + if name == "sparse": + for i in range(num): + feature_dim_dict[name][name + '_' + + str(i)] = np.random.randint(1, 10) + else: + for i in range(num): + feature_dim_dict[name].append(name + '_' + str(i)) + + sparse_input = [np.random.randint(0, dim, sample_size) + for dim in feature_dim_dict['sparse'].values()] + dense_input = [np.random.random(sample_size) + for name in feature_dim_dict['dense']] + y = np.random.randint(0, 2, sample_size) + x = sparse_input + dense_input + + model = AutoInt(feature_dim_dict, att_layer_num=att_layer_num, + hidden_size=hidden_size, keep_prob=0.5, ) + check_model(model, model_name, x, y) + + +if __name__ == "__main__": + test_AutoInt(True, (32, 32), 2)