
king821221/tf-mmoe-mgda


TF-MMOE-MGDA

This repo is a TensorFlow implementation of the MMOE + MGDA multi-task learning algorithm, which aims to resolve task dependencies and reach a global optimum by approaching Pareto optimality.
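MGDA approaches Pareto optimality by finding the minimum-norm convex combination of the per-task gradients of the shared variables. If that norm is zero, the current point is Pareto-stationary; otherwise, stepping against the combined direction decreases every task loss to first order. The following is a minimal, framework-free sketch using the Frank-Wolfe min-norm solver described by Sener and Koltun (2018). It is not taken from this repo, which may solve the problem differently (for example, inside the TensorFlow graph). The names `min_norm_weights`, `combine` and `dot` are hypothetical.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def min_norm_weights(grads, max_iter=250, tol=1e-10):
    """Frank-Wolfe search for alpha on the simplex minimising ||sum_i alpha_i * g_i||^2.

    grads: one flattened gradient vector per task, taken w.r.t. the task-shared
    variables (for MMOE: expert and gate kernels and biases, concatenated).
    """
    n = len(grads)
    gram = [[dot(gi, gj) for gj in grads] for gi in grads]  # M[i][j] = g_i . g_j
    alpha = [1.0 / n] * n                                   # start at the simplex centre
    for _ in range(max_iter):
        m_alpha = [dot(row, alpha) for row in gram]  # (M alpha)_r = g_r . v
        t = min(range(n), key=lambda r: m_alpha[r])  # vertex with smallest g_r . v
        vv = dot(alpha, m_alpha)                     # v . v
        vt = m_alpha[t]                              # v . g_t
        denom = vv - 2.0 * vt + gram[t][t]           # ||v - g_t||^2
        if denom <= tol:
            break
        # Exact line search for min ||(1 - gamma) v + gamma g_t||^2, clipped to [0, 1].
        gamma = min(1.0, max(0.0, (vv - vt) / denom))
        new_alpha = [(1.0 - gamma) * a for a in alpha]
        new_alpha[t] += gamma
        converged = sum(abs(p - q) for p, q in zip(new_alpha, alpha)) < tol
        alpha = new_alpha
        if converged:
            break
    return alpha


def combine(grads, alpha):
    """The alpha-weighted sum of task gradients, applied to the shared variables."""
    return [sum(a * g[k] for a, g in zip(alpha, grads)) for k in range(len(grads[0]))]


grads = [[1.0, 0.0], [-0.5, 1.0]]  # two conflicting toy task gradients
alpha = min_norm_weights(grads)
direction = combine(grads, alpha)
print([round(a, 4) for a in alpha])  # [0.5385, 0.4615]
```

With two tasks, the solver reaches the closed-form optimum in one step. With three or more tasks (as in the 3-task demo), Frank-Wolfe only approaches the optimum iteratively. A typical use is to update the shared variables with `-learning_rate * direction` and each task tower with its own gradient. Whether this repo normalises or rescales the gradients first is not stated here.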

This repo includes:

  1. MMOE (Multi-gate Mixture-of-Experts)

  2. MGDA (Multiple Gradient Descent Algorithm)

  3. The task-shared variables used in MGDA consist of:

    • MMOE expert kernels and biases;
    • MMOE gate kernels and biases.
  4. For illustration purposes, we provide a 2-task classification demo and a 3-task classification demo. The training and evaluation data are taken from keras-mmoe.
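To make the list above concrete, here is a minimal pure-Python sketch of an MMOE forward pass for a single input vector. It is not code from this repo (which uses TensorFlow 1.9). The helper names, the ReLU experts, and the random initialiser are assumptions for illustration. The expert and gate (kernel, bias) pairs correspond to the task-shared variables that MGDA operates on. In a full model, each task's mixed representation would then feed a task-specific tower.

```python
import math
import random


def relu(v):
    return [max(0.0, x) for x in v]


def softmax(v):
    # Subtract the max logit for numerical stability.
    m = max(v)
    exps = [math.exp(x - m) for x in v]
    total = sum(exps)
    return [e / total for e in exps]


def dense(x, kernel, bias):
    # kernel has shape [in_dim][out_dim]; bias has shape [out_dim].
    return [
        sum(x[i] * kernel[i][j] for i in range(len(x))) + bias[j]
        for j in range(len(bias))
    ]


def mmoe_forward(x, expert_params, gate_params):
    """Return one mixed representation per task for a single input vector x.

    expert_params: (kernel, bias) per expert; these are shared by all tasks.
    gate_params:   (kernel, bias) per task; each gate emits one logit per expert.
    """
    expert_outputs = [relu(dense(x, k, b)) for k, b in expert_params]
    units = len(expert_outputs[0])
    task_outputs = []
    for k, b in gate_params:
        weights = softmax(dense(x, k, b))  # one weight per expert, summing to 1
        mixed = [
            sum(w * out[u] for w, out in zip(weights, expert_outputs))
            for u in range(units)
        ]
        task_outputs.append(mixed)
    return task_outputs


def random_layer(rng, in_dim, out_dim):
    # Hypothetical initialiser for this sketch: uniform kernel, zero bias.
    kernel = [[rng.uniform(-1.0, 1.0) for _ in range(out_dim)] for _ in range(in_dim)]
    return kernel, [0.0] * out_dim


rng = random.Random(0)
x = [0.5, -1.0, 2.0]
experts = [random_layer(rng, 3, 4) for _ in range(2)]  # 2 experts, 4 units each
gates = [random_layer(rng, 3, 2) for _ in range(2)]    # 2 tasks, 2 logits each
outputs = mmoe_forward(x, experts, gates)
print(len(outputs), len(outputs[0]))  # 2 4
```

Because each task has its own softmax gate over the same experts, tasks can weight the shared experts differently. This is what lets MMOE model task relationships instead of forcing one shared bottom.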

Getting Started

Requirements

  • Python 3.5
  • TensorFlow 1.9.0 and the other libraries listed in requirements.txt

Installation and Usage

  1. Clone the repository.
  2. Install the dependencies:
     pip install -r requirements.txt
  3. Run the demo code:
     python census_income_demo.py
     python synthetic_demo.py

Any feedback and suggestions are greatly appreciated: 1485840691@qq.com
