MLflow

  1. mlflow.set_tracking_uri specifies where the MLflow tracking server stores its data. MLflow supports several storage backends, including:

    • Local filesystem: store data in a local directory
    • Remote server: store data on a remote MLflow tracking server
      • Start an MLflow tracking server
        mlflow server --host 0.0.0.0 --port 5000
        # --host 0.0.0.0: allow access from any IP address.
        # --port 5000: the port the server listens on (5000 is the default).
        
      • Connect to the remote server over HTTP
        import mlflow
        # Set the remote server URI (HTTP)
        mlflow.set_tracking_uri("http://<server-ip>:5000")
        
    • Database: store data in a supported database (such as SQLite, MySQL, or PostgreSQL); see the sketch after this list
    • Example (local filesystem; assumes os and mlflow are imported and project_dir is defined):
      mlruns_dir = os.path.join(project_dir, "mlruns")
      mlflow.set_tracking_uri(f"file://{mlruns_dir}")
      mlflow.set_experiment("Model Training Experiment gao")
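
    For the database backend, a minimal sketch of starting the server (assuming a local SQLite file mlflow.db; swap in a MySQL or PostgreSQL URI as needed):
      # Metadata goes to SQLite; artifacts still go to the local filesystem
      mlflow server \
        --backend-store-uri sqlite:///mlflow.db \
        --default-artifact-root ./mlruns \
        --host 0.0.0.0 --port 5000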
      
  2. Set the remote server URI in client code and log experiment data

    import mlflow
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import accuracy_score
    
    # Set the remote server URI
    mlflow.set_tracking_uri("http://<server-ip>:5000")
    
    # Set the experiment name
    mlflow.set_experiment("Remote Experiment")
    
    # Load the data
    iris = load_iris()
    X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
    
    # Train the model
    model = RandomForestClassifier(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)
    
    # Start an MLflow run
    with mlflow.start_run():
       # Log parameters
       mlflow.log_param("n_estimators", 100)
       mlflow.log_param("random_state", 42)
       
       # Log metrics
       y_pred = model.predict(X_test)
       accuracy = accuracy_score(y_test, y_pred)
       mlflow.log_metric("accuracy", accuracy)
       
       # Log the model
       mlflow.sklearn.log_model(model, "model")
    
    • Overview of logging methods
      1. Logging parameters
        Single parameter: mlflow.log_param("param_name", value)
        Multiple parameters: mlflow.log_params({"param1": value1, "param2": value2})
      2. Logging metrics
        Single metric: mlflow.log_metric("metric_name", value)
        Multiple metrics: mlflow.log_metrics({"metric1": value1, "metric2": value2})
      3. Setting tags
        Single tag: mlflow.set_tag("tag_name", value)
        Multiple tags: mlflow.set_tags({"tag1": value1, "tag2": value2})
      4. Logging models
        Scikit-learn model: mlflow.sklearn.log_model(model, "model_path")
        TensorFlow model: mlflow.tensorflow.log_model(model, "model_path")
      5. Logging files
        Single file: mlflow.log_artifact("file_path")
        Directory of files: mlflow.log_artifacts("directory_path")
      6. Logging text
        Text content: mlflow.log_text("text_content", "file_name.txt")
      7. Logging images
        In-memory image (a numpy array or PIL image, not a file path): mlflow.log_image(image, "image_name.png")
        Image file on disk: mlflow.log_artifact("image_path")
      8. Logging dictionaries
        Dictionary content: mlflow.log_dict({"key": "value"}, "file_name.json")
      9. Logging tables
        Tabular data: mlflow.log_table(df, "file_name.json")
      10. Logging a model signature
        There is no standalone mlflow.log_model_signature; a signature is attached when the model is logged, e.g. mlflow.sklearn.log_model(model, "model", signature=signature), with the signature built via mlflow.models.infer_signature
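
      Putting several of these together, a minimal sketch inside a run (reusing model, X_train, and accuracy from the example above; the tag values are placeholders):
      
      import mlflow
      from mlflow.models import infer_signature
      
      with mlflow.start_run():
          mlflow.log_params({"n_estimators": 100, "random_state": 42})
          mlflow.log_metrics({"accuracy": accuracy})
          mlflow.set_tags({"team": "ml", "stage": "experiment"})
          mlflow.log_dict({"target_names": ["setosa", "versicolor", "virginica"]}, "labels.json")
          # Infer the input/output signature and attach it to the logged model
          signature = infer_signature(X_train, model.predict(X_train))
          mlflow.sklearn.log_model(model, "model", signature=signature)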

3. MLflow Model Registry

The MLflow Model Registry is a key MLflow component for centrally managing the versions and lifecycle of machine learning models. It provides a systematic way to register, version, deploy, and monitor models.

A detailed look at MLflow model registration:

  1. Key features
    Model registration: register trained models in a central repository.
    Version control: manage the different versions of a model.
    Lifecycle management: support stage transitions (e.g. Staging → Production).
    Collaboration: let teams share and manage models.
    Deployment: simplify the deployment workflow.

  2. Core concepts
    Registered Model: a named model entity that holds multiple versions.
    Model Version: a specific version of a registered model.
    Stage: a lifecycle state of the model, such as Staging, Production, or Archived.
    Annotations: descriptive information attached to a model or version.
    Tags: metadata attached to a model or version.

  3. Basic usage
    3.1 Registering a model
    Register a trained model in the Model Registry:

mlflow.register_model("runs:/<run_id>/model", "model_name")
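
Alternatively, a model can be registered in the same call that logs it, using the registered_model_name argument (a sketch; "model_name" is a placeholder):

# Log and register in one step; creates the registered model if it does not exist
mlflow.sklearn.log_model(model, "model", registered_model_name="model_name")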

3.2 Getting model information

# Registered models and versions are accessed through the client API
client = mlflow.tracking.MlflowClient()

# Get a registered model
model = client.get_registered_model("model_name")

# Get a model version
version = client.get_model_version("model_name", version=1)

3.3 Updating the model stage

# Transition the model version to Production
client = mlflow.tracking.MlflowClient()
client.transition_model_version_stage(
    name="model_name",
    version=1,
    stage="Production"
)

3.4 Adding descriptions and tags

# Add a model description
client.update_registered_model(
    name="model_name",
    description="This is a production model for customer churn prediction."
)

# Tag a model version
client.set_model_version_tag(
    name="model_name",
    version=1,
    key="model_type",
    value="RandomForest"
)

4. Model lifecycle management

The MLflow Model Registry supports the following model stages:
None: the default stage.
Staging: the model is being tested.
Production: the model is deployed to production.
Archived: the model is archived and no longer in use.

4.1 Stage transitions

# Transition a model version from Staging to Production
client.transition_model_version_stage(
    name="model_name",
    version=1,
    stage="Production"
)

4.2 Archiving a model

# Archive a model version
client.transition_model_version_stage(
    name="model_name",
    version=1,
    stage="Archived"
)

5. Querying and comparing models

5.1 Querying registered models

# Get all registered models
models = client.search_registered_models()

# Get the details of a registered model
model = client.get_registered_model("model_name")

5.2 Querying model versions

# Get all versions of a model
versions = client.search_model_versions("name='model_name'")

# Get the details of a specific version
version = client.get_model_version("model_name", version=1)
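
5.3 Comparing model versions

The registry itself does not compare models, but every model version links back to its source run, so run metrics can be fetched and compared. A sketch (assuming both versions logged an "accuracy" metric):

# Compare a metric across two versions of the same registered model
for v in (client.get_model_version("model_name", version=1),
          client.get_model_version("model_name", version=2)):
    run = client.get_run(v.run_id)
    print(f"version {v.version}: accuracy={run.data.metrics.get('accuracy')}")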

6. Deploying models

6.1 Loading a model

# Load the Production-stage model
model = mlflow.pyfunc.load_model("models:/model_name/Production")
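
A registered model can also be served as a REST endpoint straight from the MLflow CLI, without writing any server code. A sketch ("model_name" is a placeholder, the port is arbitrary, and --env-manager local skips building a fresh environment):

mlflow models serve -m "models:/model_name/Production" -p 1234 --env-manager local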

7. CI/CD

Example: deploying a model with GitHub Actions

mlops-deployment/
├── .github/
│   └── workflows/
│       └── deploy.yml           # GitHub Actions workflow
├── src/
│   ├── app.py                   # FastAPI application
│   ├── model_loader.py          # Model loader
│   └── requirements.txt         # Dependency file
├── Dockerfile                   # Docker configuration
├── docker-compose.yml           # Docker Compose configuration
├── nginx.conf                   # Nginx configuration (sketched below)
└── deploy.sh                    # Deployment script

Model loader

# src/model_loader.py
import mlflow
import mlflow.pyfunc
import os
from typing import Optional

class ProductionModelLoader:
    def __init__(self, model_name: str, tracking_uri: str):
        self.model_name = model_name
        self.tracking_uri = tracking_uri
        mlflow.set_tracking_uri(tracking_uri)
        self.model = None
        
    def load_production_model(self) -> Optional[object]:
        """Load the Production-stage model"""
        try:
            # Load the model registered in the Production stage
            model_uri = f"models:/{self.model_name}/Production"
            self.model = mlflow.pyfunc.load_model(model_uri)
            
            # Get the model version information
            client = mlflow.tracking.MlflowClient()
            latest_version = client.get_latest_versions(
                self.model_name, 
                stages=["Production"]
            )[0]
            
            print(f"Loaded model: {self.model_name}")
            print(f"Version: {latest_version.version}")
            print(f"Created at: {latest_version.creation_timestamp}")
            
            return self.model
            
        except Exception as e:
            print(f"Failed to load model: {e}")
            return None
    
    def predict(self, data):
        """Run a prediction with the loaded model"""
        if self.model is None:
            raise ValueError("Model not loaded; call load_production_model() first")
        return self.model.predict(data)
    
    def get_model_info(self) -> dict:
        """Get information about the model"""
        try:
            client = mlflow.tracking.MlflowClient()
            model_info = client.get_registered_model(self.model_name)
            latest_version = client.get_latest_versions(
                self.model_name, 
                stages=["Production"]
            )[0]
            
            return {
                "model_name": self.model_name,
                "version": latest_version.version,
                "stage": latest_version.current_stage,
                "description": model_info.description,
                "creation_timestamp": latest_version.creation_timestamp
            }
        except Exception as e:
            return {"error": str(e)}
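
Example usage of the loader (a sketch; the model name and tracking URI are placeholders):

loader = ProductionModelLoader("customer_churn_classifier", "http://localhost:5000")
if loader.load_production_model() is not None:
    print(loader.get_model_info())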

FastAPI application

# src/app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import pandas as pd
import numpy as np
import os
from model_loader import ProductionModelLoader

# Environment-variable configuration
MLFLOW_TRACKING_URI = os.getenv("MLFLOW_TRACKING_URI", "http://mlflow-server:5000")
MODEL_NAME = os.getenv("MODEL_NAME", "customer_churn_classifier")

# Initialize the FastAPI application
app = FastAPI(title="Production Model API", version="1.0.0")

# Initialize the model loader
model_loader = ProductionModelLoader(MODEL_NAME, MLFLOW_TRACKING_URI)

# Load the model at startup
@app.on_event("startup")
async def startup_event():
    """Load the production model when the application starts"""
    global model_loader
    model = model_loader.load_production_model()
    if model is None:
        raise RuntimeError("Failed to load the production model")

# Input data models
class PredictionInput(BaseModel):
    Gender: str
    Age: int
    HasDrivingLicense: int
    RegionID: float
    Switch: int
    PastAccident: str
    AnnualPremium: float

class BatchPredictionInput(BaseModel):
    data: list[PredictionInput]

# Health-check endpoint
@app.get("/health")
async def health_check():
    """Health check"""
    try:
        model_info = model_loader.get_model_info()
        return {
            "status": "healthy",
            "model_info": model_info
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Health check failed: {e}")

# Model-info endpoint
@app.get("/model/info")
async def get_model_info():
    """Get information about the current production model"""
    return model_loader.get_model_info()

# Single-prediction endpoint
@app.post("/predict")
async def predict(input_data: PredictionInput):
    """Predict for a single sample"""
    try:
        # Convert the input to a DataFrame
        df = pd.DataFrame([input_data.model_dump()])
        
        # Run the prediction (pyfunc models return labels here, not probabilities)
        prediction = model_loader.predict(df)
        
        return {
            "prediction": int(prediction[0]),
            "model_info": model_loader.get_model_info()
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}")

# Batch-prediction endpoint
@app.post("/predict/batch")
async def batch_predict(input_data: BatchPredictionInput):
    """Batch prediction"""
    try:
        # Convert the inputs to a DataFrame
        data_list = [item.model_dump() for item in input_data.data]
        df = pd.DataFrame(data_list)
        
        # Run the predictions
        predictions = model_loader.predict(df)
        
        return {
            "predictions": [int(pred) for pred in predictions],
            "count": len(predictions),
            "model_info": model_loader.get_model_info()
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch prediction failed: {e}")

# Model-reload endpoint
@app.post("/model/reload")
async def reload_model():
    """Reload the production model"""
    try:
        model = model_loader.load_production_model()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to reload the model: {e}")
    
    if model is None:
        raise HTTPException(status_code=500, detail="Failed to reload the model")
    
    return {
        "status": "success",
        "message": "Model reloaded",
        "model_info": model_loader.get_model_info()
    }

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

Dependencies

# src/requirements.txt
fastapi==0.104.1
uvicorn==0.24.0
mlflow==2.14.3
pandas==2.2.2
numpy==1.26.4
scikit-learn==1.5.1
pydantic==2.5.0

Dockerfile

# Dockerfile
FROM python:3.9-slim

# Set the working directory
WORKDIR /app

# Install curl for the container health check (not included in the slim image)
RUN apt-get update && \
    apt-get install -y --no-install-recommends curl && \
    rm -rf /var/lib/apt/lists/*

# Copy the dependency file
COPY src/requirements.txt .

# Install dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code
COPY src/ .

# Expose the port
EXPOSE 8000

# Health check
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Start the application
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]

Docker Compose

# docker-compose.yml
version: '3.8'

services:
  # MLflow tracking server
  mlflow-server:
    image: python:3.9-slim
    container_name: mlflow-server
    ports:
      - "5000:5000"
    volumes:
      - mlflow_data:/mlflow
    command: |
      sh -c "
        pip install mlflow==2.14.3 &&
        mlflow server --host 0.0.0.0 --port 5000 --backend-store-uri file:///mlflow/mlruns
      "
    environment:
      - MLFLOW_BACKEND_STORE_URI=file:///mlflow/mlruns
    healthcheck:
      # The python:3.9-slim image has no curl, so probe with the stdlib instead
      test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:5000/health')"]
      interval: 30s
      timeout: 10s
      retries: 3

  # Production model API
  model-api:
    build: .
    container_name: model-api
    ports:
      - "8000:8000"
    environment:
      - MLFLOW_TRACKING_URI=http://mlflow-server:5000
      - MODEL_NAME=customer_churn_classifier
    depends_on:
      mlflow-server:
        condition: service_healthy
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 10s
      retries: 3

  # Nginx load balancer
  nginx:
    image: nginx:alpine
    container_name: nginx-lb
    ports:
      - "80:80"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf
    depends_on:
      model-api:
        condition: service_healthy

volumes:
  mlflow_data:
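
Nginx configuration

docker-compose.yml mounts ./nginx.conf, so the repository needs to provide it. A minimal sketch that proxies port 80 to the model API (assumed layout; adjust as needed):

# nginx.conf
events {}

http {
    upstream model_api {
        server model-api:8000;
    }

    server {
        listen 80;

        location / {
            proxy_pass http://model_api;
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
        }
    }
}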

GitHub Actions CI/CD

# .github/workflows/deploy.yml
name: Deploy Production Model

on:
  workflow_dispatch:
    inputs:
      model_name:
        description: 'MLflow model name'
        required: true
        default: 'customer_churn_classifier'
      environment:
        description: 'Deployment environment'
        required: true
        default: 'production'
        type: choice
        options:
        - staging
        - production

jobs:
  check-model:
    runs-on: ubuntu-latest
    outputs:
      model_version: ${{ steps.check.outputs.version }}
      model_ready: ${{ steps.check.outputs.ready }}
    
    steps:
    - name: Checkout code
      uses: actions/checkout@v4
    
    - name: Set up Python
      uses: actions/setup-python@v4
      with:
        python-version: '3.9'
    
    - name: Install MLflow
      run: |
        pip install mlflow==2.14.3
    
    - name: Check Production Model
      id: check
      run: |
        python << EOF
        import mlflow
        import os
        
        # Set the MLflow tracking URI
        mlflow.set_tracking_uri("${{ secrets.MLFLOW_TRACKING_URI }}")
        
        def set_output(key, value):
            # Write step outputs via the GITHUB_OUTPUT file
            # (the ::set-output workflow command is deprecated)
            with open(os.environ["GITHUB_OUTPUT"], "a") as f:
                f.write(f"{key}={value}\n")
        
        try:
            client = mlflow.tracking.MlflowClient()
            
            # Check for a model version in the Production stage
            production_versions = client.get_latest_versions(
                "${{ github.event.inputs.model_name }}", 
                stages=["Production"]
            )
            
            if production_versions:
                version = production_versions[0].version
                print(f"Found production model version: {version}")
                set_output("version", version)
                set_output("ready", "true")
            else:
                print("No production model found")
                set_output("ready", "false")
                exit(1)
                
        except Exception as e:
            print(f"Model check failed: {e}")
            set_output("ready", "false")
            exit(1)
        EOF

  build-and-test:
    needs: check-model
    runs-on: ubuntu-latest
    if: needs.check-model.outputs.model_ready == 'true'
    
    steps:
    - name: Checkout code
      uses: actions/checkout@v4
    
    - name: Set up Docker Buildx
      uses: docker/setup-buildx-action@v3
    
    - name: Login to Docker Hub
      uses: docker/login-action@v3
      with:
        username: ${{ secrets.DOCKER_USERNAME }}
        password: ${{ secrets.DOCKER_PASSWORD }}
    
    - name: Build Docker image
      run: |
        docker build -t ${{ secrets.DOCKER_USERNAME }}/model-api:${{ needs.check-model.outputs.model_version }} .
        docker build -t ${{ secrets.DOCKER_USERNAME }}/model-api:latest .
    
    - name: Test Docker image
      run: |
        # Start a test container
        docker run -d --name test-api \
          -p 8000:8000 \
          -e MLFLOW_TRACKING_URI=${{ secrets.MLFLOW_TRACKING_URI }} \
          -e MODEL_NAME=${{ github.event.inputs.model_name }} \
          ${{ secrets.DOCKER_USERNAME }}/model-api:latest
        
        # Wait for the service to start
        sleep 30
        
        # Health check
        curl -f http://localhost:8000/health || exit 1
        
        # Test the prediction API
        curl -X POST "http://localhost:8000/predict" \
          -H "Content-Type: application/json" \
          -d '{
            "Gender": "Male",
            "Age": 30,
            "HasDrivingLicense": 1,
            "RegionID": 28.0,
            "Switch": 0,
            "PastAccident": "No",
            "AnnualPremium": 40000.0
          }' || exit 1
        
        # Clean up the test container
        docker stop test-api
        docker rm test-api
    
    - name: Push Docker image
      run: |
        docker push ${{ secrets.DOCKER_USERNAME }}/model-api:${{ needs.check-model.outputs.model_version }}
        docker push ${{ secrets.DOCKER_USERNAME }}/model-api:latest

  deploy-staging:
    needs: [check-model, build-and-test]
    runs-on: ubuntu-latest
    if: github.event.inputs.environment == 'staging'
    
    steps:
    - name: Deploy to Staging
      run: |
        echo "Deploying to the staging environment"
        echo "Model version: ${{ needs.check-model.outputs.model_version }}"
        
        # Actual deployment commands go here
        # e.g.: kubectl apply -f staging-deployment.yaml

  deploy-production:
    needs: [check-model, build-and-test]
    runs-on: ubuntu-latest
    if: github.event.inputs.environment == 'production'
    
    steps:
    - name: Deploy to Production
      run: |
        echo "Deploying to the production environment"
        echo "Model version: ${{ needs.check-model.outputs.model_version }}"
        
        # Actual deployment commands go here
        # e.g.:
        # kubectl set image deployment/model-api \
        #   model-api=${{ secrets.DOCKER_USERNAME }}/model-api:${{ needs.check-model.outputs.model_version }}
    
    - name: Notify Deployment
      run: |
        echo "✅ Production deployment complete"
        echo "📦 Image: ${{ secrets.DOCKER_USERNAME }}/model-api:${{ needs.check-model.outputs.model_version }}"
        echo "🔄 Model version: ${{ needs.check-model.outputs.model_version }}"

Deployment script

#!/bin/bash
# deploy.sh

set -e

MODEL_NAME=${1:-"customer_churn_classifier"}
ENVIRONMENT=${2:-"production"}

echo "🚀 Starting deployment..."
echo "📦 Model name: $MODEL_NAME"
echo "🌍 Deployment environment: $ENVIRONMENT"

# Check the Production model
echo "🔍 Checking the production model..."
python << EOF
import mlflow
import sys

mlflow.set_tracking_uri("$MLFLOW_TRACKING_URI")
client = mlflow.tracking.MlflowClient()

try:
    production_versions = client.get_latest_versions("$MODEL_NAME", stages=["Production"])
    if production_versions:
        print(f"✅ Found production model version: {production_versions[0].version}")
    else:
        print("❌ No production model found")
        sys.exit(1)
except Exception as e:
    print(f"❌ Model check failed: {e}")
    sys.exit(1)
EOF

# Build the Docker images
echo "🏗️ Building Docker images..."
docker-compose build

# Start the services
echo "🚀 Starting services..."
docker-compose up -d

# Wait for the services to start
echo "⏳ Waiting for services to start..."
sleep 30

# Health check
echo "🏥 Running health check..."
curl -f http://localhost:8000/health || {
    echo "❌ Health check failed"
    docker-compose logs
    exit 1
}

echo "✅ Deployment complete!"
echo "🌐 API endpoint: http://localhost:8000"
echo "📊 MLflow UI: http://localhost:5000"

Usage

Local deployment

# Clone the project
git clone <repository-url>
cd mlops-deployment

# Set environment variables
export MLFLOW_TRACKING_URI="http://localhost:5000"
export MODEL_NAME="customer_churn_classifier"

# Run the deployment script
chmod +x deploy.sh
./deploy.sh

CI/CD deployment

  • Push the code to GitHub
  • Configure the following GitHub Secrets:
    MLFLOW_TRACKING_URI
    DOCKER_USERNAME
    DOCKER_PASSWORD
  • Manually trigger the workflow from GitHub Actions (or from the CLI, as sketched below)
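
The manual trigger can also be driven from the GitHub CLI. A sketch (assuming the workflow file is named deploy.yml, as in the tree above):

# Trigger the deployment workflow with its two inputs
gh workflow run deploy.yml \
  -f model_name=customer_churn_classifier \
  -f environment=production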

Test the API

# Health check
curl http://localhost:8000/health

# Single prediction
curl -X POST "http://localhost:8000/predict" \
  -H "Content-Type: application/json" \
  -d '{
    "Gender": "Male",
    "Age": 30,
    "HasDrivingLicense": 1,
    "RegionID": 28.0,
    "Switch": 0,
    "PastAccident": "No",
    "AnnualPremium": 40000.0
  }'

# Model info
curl http://localhost:8000/model/info