0%

模型服务与 RPC

业界主流模型服务方法

预存推荐结果或 Embedding 结果

离线结果预存到 Redis 数据库中,线上环境直接取出预存数据推荐给用户
优点:
1、线上服务平台与线下离线模型训练完全解耦
2、线上没有复杂计算,推荐系统线上延迟低
缺点:
1、用户、物品数量规模过大时,线上数据库无法支撑大规模结果的存储
2、无法引入线上场景类特征,推荐结果灵活性、效果受限
适用场景:
用户规模小、冷启动、热门榜单

用户规模大,存储embedding,线上通过相似度运算得到推算结果(Item2vec、Graph Embedding -> Redis)

预训练 Embedding + 轻量级线上模型

复杂深度学习网络离线训练生成 embedding,存入内存数据库 + 线上实现逻辑回归或浅层神经网络等轻量级模型来拟合优化目标
eg. MIMN 模型:“多通道用户兴趣向量”,这些 Embedding 向量就是连接离线模型和线上模型部分的接口

利用 PMML 转换和部署模型

  • END2END 训练 + END2END部署
    PMML (Predictive Model Markup Language, PMML) 预测模型标记语言,连接离线训练平台和模型服务模块 - XML形式
    eg. spark MLlib
    Spark: 完成 Spark MLlib 模型的序列化,生成PMML文件,并且把它保存到线上服务器能够触达的数据库或文件系统中
    Java Server: 解析模型,生成预估模型,完成了与业务逻辑的整合。
    JPMML: 在 Java Server 部分只进行推断,不考虑模型训练、分布式部署等一系列问题

    代码参考

    JPMML[https://github.com/jpmml]
    MLeap[https://github.com/combust/mleap]

Tensorflow Serving

模型存储(序列化) -> 模型载入还原 -> 提供模型服务(HTTP/gRPC)

实操(Docker中)

请求模型服务API
1
2
# 从docker仓库中下载tensorflow/serving镜像
docker pull tensorflow/serving
把tensorflow/serving的测试代码clone到本地
1
2
3
git clone https://github.com/tensorflow/serving
# 指定测试数据的地址
TESTDATA="$(pwd)/serving/tensorflow_serving/servables/tensorflow/testdata"
启动TensorFlow Serving容器
1
2
3
4
5
# 在8100端口运行模型服务API
docker run -t --rm -p 8100:8100 \
-v "$TESTDATA/saved_model_half_plus_two_cpu:/models/half_plus_two" \
-e MODEL_NAME=half_plus_two \
tensorflow/serving &
调用API
1
2
curl -d '{"instances": [1.0, 2.0, 5.0]}' \
-X POST http://localhost:8501/v1/models/half_plus_two:predict

构建 API - RPC(远程调用)、REST(表征状态传输)

接口
1
2
3
4
5
6
7
8
9
10
11
12
13
14
@app.route('/predict', methods=['POST'])  # Your API endpoint URL would consist /predict
def predict():
if lr:
try:
json_ = request.json
query = pd.get_dummies(pd.DataFrame(json_))
query = query.reindex(columns=model_columns, fill_value=0)
prediction = list(lr.predict(query))
return jsonify({'prediction': str(prediction)})
except:
return jsonify({'trace': traceback.format_exc()})
else:
print('Train the model first')
return 'No model here to use'
1
2
3
4
5
6
7
8
9
10
11
12
13
from sklearn.externals import joblib

if __name__ == '__main__':

try:
port = int(sys.argv[1])
except:
port = 8000
lr = joblib.load('model.pkl') # Load "model.pkl"
print('Model loaded')
model_columns = joblib.load('model_columns.pkl') # Load "model_columns.pkl"
print('Model columns loaded')
app.run(host='192.168.100.162', port=port, debug=True)

Welcome to my other publishing channels