This repo is a TensorFlow implementation of the MMOE + MGDA multi-task learning algorithm, which aims to resolve task dependency and reach a global optimum by approaching Pareto optimality.
This repo includes:
- MMOE (Multi-gate Mixture-of-Experts)
  - Described in the paper: Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts.
  - Our implementation is built upon the Keras MMOE implementation.
- MGDA (Multiple Gradient Descent Algorithm)
  - Described in the paper: Multiple-gradient descent algorithm (MGDA) for multiobjective optimization.
  - Our implementation is built upon Intel's MGDA implementation, which provides PyTorch and NumPy versions.
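
As a rough illustration of the MMOE idea (a set of shared experts, plus one softmax gate per task that mixes them), here is a minimal NumPy sketch of the forward pass. It is not the repo's TensorFlow/Keras code: the function names, weight shapes and ReLU activation are assumptions chosen for clarity, and the task-specific towers that sit on top of each gated output are omitted.

```python
import numpy as np


def softmax(z, axis=-1):
    # Numerically stable softmax along the given axis.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def mmoe_forward(x, expert_kernels, expert_bias, gate_kernels, gate_biases):
    """Illustrative MMOE forward pass (hypothetical shapes, not the repo's layer).

    x:              (batch, input_dim)
    expert_kernels: (input_dim, units, num_experts)
    expert_bias:    (units, num_experts)
    gate_kernels:   list of (input_dim, num_experts), one per task
    gate_biases:    list of (num_experts,), one per task
    Returns one (batch, units) array per task, i.e. the input to each task tower.
    """
    # Every expert sees the same input: result has shape (batch, units, num_experts).
    experts = np.einsum("bi,iue->bue", x, expert_kernels) + expert_bias
    experts = np.maximum(experts, 0.0)  # ReLU (assumed activation)
    outputs = []
    for w_gate, b_gate in zip(gate_kernels, gate_biases):
        # Task-specific gate: a softmax distribution over experts, (batch, num_experts).
        gate = softmax(x @ w_gate + b_gate, axis=-1)
        # Gate-weighted sum of the expert outputs, (batch, units).
        outputs.append(np.einsum("bue,be->bu", experts, gate))
    return outputs


rng = np.random.default_rng(0)
batch, input_dim, units, num_experts, num_tasks = 4, 8, 16, 3, 2
x = rng.normal(size=(batch, input_dim))
outs = mmoe_forward(
    x,
    rng.normal(size=(input_dim, units, num_experts)),
    np.zeros((units, num_experts)),
    [rng.normal(size=(input_dim, num_experts)) for _ in range(num_tasks)],
    [np.zeros(num_experts) for _ in range(num_tasks)],
)
print([o.shape for o in outs])  # [(4, 16), (4, 16)]
```

The experts are shared by all tasks, while each gate is task-specific; this is what lets MMOE learn how much each task should draw on each expert.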

The task-shared variables used in MGDA consist of:
- MMOE expert kernels and bias;
- MMOE gate kernels and bias.
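
MGDA combines the per-task gradients of these shared variables into a single descent direction by finding the minimum-norm point in their convex hull. The NumPy sketch below flattens each task's gradients into one vector, solves the two-task case in closed form, and uses a Frank-Wolfe loop for any number of tasks (as a 3-task setup needs). It follows the general min-norm formulation of MGDA rather than reproducing the repo's or Intel's code; the helper names, variable shapes and iteration limits are hypothetical.

```python
import numpy as np


def flatten_shared_grads(grads_per_task):
    """Stack each task's gradients over the shared variables into one row.

    grads_per_task: one entry per task, each a list of arrays (hypothetically the
    gradients w.r.t. expert kernels, expert bias, gate kernels and gate bias).
    Returns G with shape (num_tasks, total_shared_params).
    """
    return np.stack([np.concatenate([g.ravel() for g in grads]) for grads in grads_per_task])


def two_task_weights(g1, g2):
    """Closed-form solution of min_{a in [0, 1]} ||a*g1 + (1-a)*g2||^2."""
    diff = g1 - g2
    denom = float(diff @ diff)
    if denom < 1e-12:
        return np.array([0.5, 0.5])  # identical gradients: any convex weight works
    # Zero derivative gives a = g2.(g2 - g1) / ||g1 - g2||^2; clip to [0, 1].
    a = float(np.clip(g2 @ (g2 - g1) / denom, 0.0, 1.0))
    return np.array([a, 1.0 - a])


def min_norm_frank_wolfe(G, max_iter=250, tol=1e-6):
    """Simplex weights alpha approximately minimising ||alpha @ G||^2 for any number of tasks."""
    k = G.shape[0]
    M = G @ G.T                  # Gram matrix of task gradients
    alpha = np.full(k, 1.0 / k)  # start from uniform weights
    for _ in range(max_iter):
        Ma = M @ alpha
        t = int(np.argmin(Ma))   # task gradient with the smallest inner product with v = alpha @ G
        vv, vg, gg = float(alpha @ Ma), float(Ma[t]), float(M[t, t])
        denom = gg + vv - 2.0 * vg  # ||g_t - v||^2
        if denom < 1e-12:
            break
        # Exact line search on the segment between v and g_t.
        gamma = float(np.clip((vv - vg) / denom, 0.0, 1.0))
        new_alpha = (1.0 - gamma) * alpha
        new_alpha[t] += gamma
        converged = np.abs(new_alpha - alpha).sum() < tol
        alpha = new_alpha
        if converged:
            break
    return alpha


# Usage with hypothetical shared-variable shapes for 3 tasks.
rng = np.random.default_rng(0)
shapes = [(8, 4, 3), (4, 3), (8, 3), (3,)]
grads = [[rng.normal(size=s) for s in shapes] for _ in range(3)]
G = flatten_shared_grads(grads)
alpha = min_norm_frank_wolfe(G)
shared_direction = alpha @ G  # common descent direction for the shared variables
print(G.shape)  # (3, 135)
```

If `alpha @ G` is (close to) zero, no direction decreases all task losses at once and the shared variables are at a Pareto-stationary point. One common way to apply the weights, assumed here rather than taken from the repo, is to scale each task loss by its `alpha` entry when computing gradients for the shared variables.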

For illustration, we provide a 2-task classification demo and a 3-task classification demo. The training and evaluation data come from keras-mmoe.

Requirements:
- Python 3.5
- TensorFlow 1.9.0 and the other libraries listed in requirements.txt

Quick start:
- Clone the repository
- Install dependencies:
    pip install -r requirements.txt
- Run the demo code:
    python census_income_demo.py
    python synthetic_demo.py

Any feedback and suggestions are greatly appreciated: 1485840691@qq.com