This repository contains the code to replicate the experiments of the paper "Distributionally Robust Self-supervised Learning for Tabular Data" with an FT-Transformer backbone.
The experiments are performed on the "bank" and "census" datasets from the UCI Machine Learning Repository. Download them from the following links:
Bank Marketing: https://archive.ics.uci.edu/dataset/222/bank+marketing
Census Income: https://archive.ics.uci.edu/dataset/20/census+income
After downloading the data, install the required packages with pip:
pip install -r requirements.txt
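Each of the commands below redirects its output to ft_transformer-mr-0.05-upwt-50.out, so rename the output file between runs to keep earlier logs.
To run the JTT experiment on the bank dataset:
CUDA_LAUNCH_BLOCKING=1 python experiments_tab_transformer_JTT.py \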
--full_csv="bank-additional-full.csv" \
--model_type="FTTransformer" \
--categories job marital education default housing loan contact month day_of_week poutcome \
--num_cols age duration \
--seed=43 \
--max_epoch_phase1A=35 \
--max_epoch_phase1B=100 \
--batch_size=1024 \
--output_col="y" \
--dim_out=192 \
--mask_val=0.05 \
--upweight_factor=50 \
--dataset="bank" > ft_transformer-mr-0.05-upwt-50.out
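The --categories, --num_cols and --output_col flags name columns of the downloaded data. A minimal bash sketch for checking, before launching a run, that those columns exist in the file's header: check_columns is a hypothetical helper, not part of this repository, and it assumes the file starts with a header row whose field names may be double-quoted, as in the semicolon-separated bank-additional-full.csv.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of this repository): check that the columns
# passed via --categories, --num_cols and --output_col appear in a CSV header.
# Assumes the first line is a header whose field names may be double-quoted.
set -euo pipefail

# check_columns FILE DELIM COLUMN...
# Prints every missing column to stderr; returns 1 if any is missing, else 0.
check_columns() {
  local file="$1" delim="$2"
  shift 2

  # Take the first line only, removing double quotes and carriage returns.
  local header
  header="$(head -n 1 "$file" | tr -d '"\r')"

  # Split the header into an array of field names on DELIM.
  local -a fields
  IFS="$delim" read -r -a fields <<< "$header"

  local col f found missing=0
  for col in "$@"; do
    found=0
    for f in "${fields[@]}"; do
      if [[ "$f" == "$col" ]]; then
        found=1
        break
      fi
    done
    if (( found == 0 )); then
      echo "missing column: $col" >&2
      missing=1
    fi
  done
  return "$missing"
}
```

For the bank commands, a check would look like: check_columns bank-additional-full.csv ';' job marital education default housing loan contact month day_of_week poutcome age duration y. It prints nothing and returns 0 when every listed column is present. The census commands pass --full_csv="adult" rather than a CSV file name, so the check may not apply to them as written.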
To run the JTT experiment on the census dataset:
CUDA_LAUNCH_BLOCKING=1 python experiments_tab_transformer_JTT.py \
--full_csv="adult" \
--model_type="FTTransformer" \
--categories workclass education marital-status occupation relationship race sex native-country \
--num_cols age education-num \
--seed=43 \
--max_epoch_phase1A=100 \
--max_epoch_phase1B=200 \
--batch_size=1024 \
--output_col="income(>=50k)" \
--dim_out=192 \
--mask_val=0.05 \
--upweight_factor=50 \
--dataset="census" > ft_transformer-mr-0.05-upwt-50.out
To run the DFR experiment on the bank dataset:
CUDA_LAUNCH_BLOCKING=1 python experiments_tab_transformer_DFR.py \
--full_csv="bank-additional-full.csv" \
--model_type="FTTransformer" \
--categories job marital education default housing loan contact month day_of_week poutcome \
--num_cols age duration \
--seed=43 \
--max_epoch_phase1A=35 \
--max_epoch_phase1B=100 \
--max_epoch_phase2B=100 \
--batch_size=1024 \
--output_col="y" \
--dim_out=192 \
--mask_val=0.05 \
--upweight_factor=50 \
--dataset="bank" > ft_transformer-mr-0.05-upwt-50.out
To run the DFR experiment on the census dataset:
CUDA_LAUNCH_BLOCKING=1 python experiments_tab_transformer_DFR.py \
--full_csv="adult" \
--model_type="FTTransformer" \
--categories workclass education marital-status occupation relationship race sex native-country \
--num_cols age education-num \
--seed=43 \
--max_epoch_phase1A=100 \
--max_epoch_phase1B=200 \
--max_epoch_phase2B=200 \
--batch_size=1024 \
--output_col="income(>=50k)" \
--dim_out=192 \
--mask_val=0.05 \
--upweight_factor=50 \
--dataset="census" > ft_transformer-mr-0.05-upwt-50.out
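All four commands above share the same --model_type, --seed, --batch_size, --dim_out, --mask_val and --upweight_factor values, and all redirect their output to ft_transformer-mr-0.05-upwt-50.out. A minimal bash sketch of a wrapper that factors out the shared flags and writes one log per method and dataset; run_experiment, log_name, BANK_FLAGS, CENSUS_FLAGS and the DRY_RUN switch are hypothetical names introduced here, not part of the repository, and the log naming scheme is an assumption.

```bash
#!/usr/bin/env bash
# Hypothetical wrapper (not part of this repository): runs the JTT/DFR
# experiments with the shared flags from this README and one log per run.
set -euo pipefail

MASK_VAL=0.05
UPWEIGHT=50

# Dataset-specific flags, copied from the commands above.
BANK_FLAGS=(--full_csv="bank-additional-full.csv" --output_col="y"
  --categories job marital education default housing loan contact month day_of_week poutcome
  --num_cols age duration)
CENSUS_FLAGS=(--full_csv="adult" --output_col="income(>=50k)"
  --categories workclass education marital-status occupation relationship race sex native-country
  --num_cols age education-num)

# log_name METHOD DATASET: per-run log file name (assumed naming scheme).
log_name() {
  echo "ft_transformer-$1-$2-mr-${MASK_VAL}-upwt-${UPWEIGHT}.out"
}

# run_experiment METHOD DATASET [FLAG...]
# METHOD (JTT or DFR) selects experiments_tab_transformer_<METHOD>.py.
# With DRY_RUN=1 the command is printed (unquoted, for inspection) instead of run.
run_experiment() {
  local method="$1" dataset="$2"
  shift 2
  local log
  log="$(log_name "$method" "$dataset")"
  local -a cmd=(python "experiments_tab_transformer_${method}.py"
    --model_type="FTTransformer" --seed=43 --batch_size=1024 --dim_out=192
    --mask_val="$MASK_VAL" --upweight_factor="$UPWEIGHT" --dataset="$dataset" "$@")
  if [[ "${DRY_RUN:-0}" == 1 ]]; then
    echo "${cmd[*]} > $log"
  else
    CUDA_LAUNCH_BLOCKING=1 "${cmd[@]}" > "$log"
  fi
}

# Example calls reproducing the four commands above:
# run_experiment JTT bank   "${BANK_FLAGS[@]}"   --max_epoch_phase1A=35  --max_epoch_phase1B=100
# run_experiment JTT census "${CENSUS_FLAGS[@]}" --max_epoch_phase1A=100 --max_epoch_phase1B=200
# run_experiment DFR bank   "${BANK_FLAGS[@]}"   --max_epoch_phase1A=35  --max_epoch_phase1B=100 --max_epoch_phase2B=100
# run_experiment DFR census "${CENSUS_FLAGS[@]}" --max_epoch_phase1A=100 --max_epoch_phase1B=200 --max_epoch_phase2B=200
```

Setting DRY_RUN=1 prints each assembled command and its log file, so it can be compared with the commands above before starting a run.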
The TabTransformer and FT-Transformer backbones were adapted from https://github.com/lucidrains/tab-transformer-pytorch