机器学习项目的目标并不仅仅是在Jupyter笔记本中训练出一个模型。真正困难的部分在于将这个模型转化为能够可靠运行、安全更新且经得起时间考验的系统。

大多数机器学习系统在实际应用中会因为一些看似琐碎但却关键的问题而失败:训练代码与部署代码之间出现差异,输入数据发生变化,一个微小的预处理调整就会影响预测结果,或者由于现实世界环境的变化导致模型性能逐渐下降。这些问题都不是通过改进算法就能解决的,而是需要通过工程化的手段来应对——比如建立可复制的开发流程、进行验证测试、实施版本控制、进行监控以及设置自动检查机制。

在这本实践指南中,你将在自己的本地机器上构建一个完整的微型机器学习平台。这个端到端的项目会带你从模型训练一直走到部署阶段,并且会涵盖所有关键的“最后一步”实现细节。我们会以欺诈检测为例来讲解整个流程,但同样的工作方式也适用于客户流失预测或任何二元分类问题。所有的操作都在本地完成(无需使用云服务),而且每一步骤都可以直接复制执行,因此你可以边学习边验证实验结果。

最终,你将拥有一个可以在本地机器上运行的、具备生产环境要求的机器学习平台——从模型训练到预测结果的输出,整个系统都配备了用于测试、监控和持续优化的基础设施。当然,我们会通过提供可以直接复制并运行的代码片段来帮助你完成这些步骤。让我们开始吧!

📦 获取完整代码
本书中的所有代码都存储在一个可以直接运行的代码仓库中:
代码仓库链接: https://github.com/sandeepmb/freecodecamp-local-ml-platform
你可以克隆这个仓库并跟随步骤进行操作,或者将其作为参考实现来使用。

目录

  1. 项目概述与准备

  2. 构建简单的模型和API(朴素方法)

  3. 朴素方法的局限性

  4. 使用MLflow添加实验跟踪和模型注册功能

  5. 使用Feast确保特征数据的一致性

  6. 添加数据验证机制,并设定合理的预期值

  7. 监控模型性能及数据的变化趋势

  8. 使用CI/CD自动化测试和部署流程

  9. 事件响应方案

  10. 如何将所有内容整合起来

  11. 下一步行动:将系统部署到生产环境中

  12. 结论

  13. 参考资料

项目概述与环境搭建

在开始编写代码之前,我们先来明确项目的整体框架。我们的应用场景是信用卡欺诈检测——这是一个二分类问题,我们需要判断某笔交易是欺诈性的(is_fraud = 1)还是合法的(is_fraud = 0)。这种任务在机器学习领域非常常见,同时也是评估实际生产环境中的机器学习挑战的一个良好范例,因为欺诈模式会随时间发生变化(这有助于我们研究模型性能的漂移问题),而错误的输入数据(比如格式不正确的交易信息)如果处理不当,也会引发严重的问题。

技术栈

我们将使用一些在机器学习运维中广泛流行、同时又适合初学者的Python工具:

工具名称 用途 选择理由
MLflow 实验跟踪与模型管理工具 开源项目,应用范围广泛,用户界面友好
Feast 特征存储系统,用于统一提供训练所需的特征数据 适用于生产环境,可在本地运行,离线/在线接口一致
FastAPI 高性能Web框架,用于快速构建预测服务 使用起来简单高效,支持自动生成交互式文档
Great Expectations 数据质量检测工具 允许用户定义数据验证规则,并能生成详细的报告
Evidently 用于监控数据漂移及模型性能下降情况 提供直观的报告,易于集成到项目中
Docker 容器化技术,确保开发环境与生产环境的一致性 行业标准工具,在任何环境中都能正常使用
GitHub Actions 持续集成/持续部署自动化工具 公共仓库可免费使用,与GitHub集成紧密

下面我来简要介绍这些工具的功能:

MLflow是一个开源平台,专门用于管理机器学习的整个生命周期。它提供了实验跟踪功能(可以记录参数、指标及相关输出结果),模型注册机制(允许为模型设置别名并管理不同版本的模型),以及模型部署服务。我们使用它来确保实验结果的可复现性,并对模型进行版本管理。

Feast是一个开源的特征存储系统,能够帮助我们在训练和推理阶段之间统一、一致地提供所需的特征数据。这一工具可以有效避免“训练阶段使用的特征与生产环境中的特征不一致”这种常见问题,从而防止模型准确率出现隐性下降。

FastAPI是一个现代的、高效的Python Web框架,专门用于构建API服务。它使用起来非常简单,运行速度很快,而且还能自动生成交互式文档。我们将利用它来提供模型的预测结果。

Great Expectations是一个开源的数据质量检测工具。用户可以通过它为数据定义各种验证规则(比如“交易金额必须为正数”或“时间必须在0到23小时之间”),然后对比实际输入的数据是否符合这些规则来进行检查。

Evidently是一个开源库,用于持续监控数据以及模型随时间变化的表现。它能够检测到数据分布的变化以及模型性能的下降。

Docker能确保在开发和部署环境中使用相同的依赖项,从而避免“在我的机器上可以运行,但在其他环境下却出问题”这类常见问题。

GitHub Actions提供了CI/CD自动化功能。高效的CI/CD流程能够帮助我们更快地集成代码并进行部署,同时减少错误的发生。

💡 思维模型:可以把这些工具看作是为你的机器学习模型搭建的“安全网”。我们使用的每一种工具都能捕捉到不同类型的问题,就像驾驶中的防御性措施一样,有助于确保机器学习的稳定运行。

先决条件

你需要准备以下内容:

  • 你的机器上已经安装了Python 3.9+

  • Docker Desktop也已安装并处于运行状态

  • 如果你想尝试CI/CD流程,那么需要一个GitHub账户

  • 你需要对Python以及机器学习的基本概念有基本的了解(比如训练和预测的含义)

你不需要具备MLOps或Kubernetes的相关经验。所有操作都可以在本地完成,只需要使用Python和Docker即可——完全不需要云服务或Kubernetes

项目结构

让我们在本地机器上建立一个基本的项目结构。打开终端并运行以下命令:

# 创建项目目录及子文件夹
mkdir ml-platform-tutorial && cd ml-platform-tutorial
mkdir -p data models src tests feature_repo

# 设置虚拟环境(推荐)
python -m venv venv
source venv/bin/activate   # 在Windows系统中:venv\Scripts\activate

你的项目结构应该如下所示:

ml-platform-tutorial/
├── data/              # 训练数据和测试数据集
├ ├── models/            # 保存的模型文件
├ ├── src/               # 源代码目录
├ ├── tests/             # 测试文件
├ ├── feature_repo/      # 特征存储库
├ ├── venv/              # 虚拟环境
└── requirements.txt   # 依赖项列表

接下来,创建一个requirements.txt文件,列出所有必需的库:

# requirements.txt

# 核心机器学习库
pandas==2.2.0
numpy==1.26.3
scikit-learn==1.4.0

# 实验跟踪与模型注册工具
mlflow==2.10.0

# 特征存储库
feast==0.36.0

# API框架
fastapi==0.109.0
uvicorn==0.27.0
httpx==0.26.0

# 数据验证工具
great-expectations==0.18.8

# 监控工具
evidently==0.7.20

# 测试工具
pytest==8.0.0
pytest-cov==4.1.0

# 其他实用工具
pyarrow==15.0.0
pydantic==2.6.0

📌 版本说明:为了确保实验结果的可复现性,我们指定了具体的版本号。虽然使用更新的版本也可能能够正常运行,但所有示例都是使用这里列出的版本进行测试的。

现在来安装这些依赖项吧:

pip install -r requirements.txt

这个过程可能需要几分钟时间,因为系统需要安装所有的依赖包。一旦安装完成,我们就可以开始逐步构建我们的项目了。

注意:你应该会看到一个项目文件夹,其中包含data/models/src/tests/feature_repo/这几个目录,同时还需要一个已经激活的虚拟环境,其中已经安装了所有必要的依赖包。你可以通过运行python -c "import mlflow; import feast; import fastapi; print('All imports successful!')"来验证这些设置是否正确。

图1:我们将要构建的完整机器学习平台

如果这个过程看起来很复杂,不用担心——我们会一步一步地构建每个组件,从最简单的部分开始,然后逐步将它们连接起来。

一个用于欺诈检测的本地端到端机器学习平台的架构图。交易数据会依次经过MLflow中的模型训练、实验跟踪和模型注册系统,Feast中的特征管理模块,Great Expectations进行数据验证,然后通过FastAPI提供预测服务,Evidently负责监控,而Docker和GitHub Actions则用于自动化测试和部署。

1. 构建一个简单的模型和API(基础方法)

为了说明为什么我们需要这些工具,让我们先从构建一个不依赖任何机器学习运维基础设施的简单系统开始吧。我们会快速训练一个简单的模型并部署它,然后观察会出现哪些问题。这种“基础方法”正是大多数机器学习项目开始的起点——了解它的局限性会帮助我们找到后续需要解决的方案。

1.1 快速训练一个模型

首先,我们需要一些数据。为了简化操作,我们会生成一个专门用于欺诈检测的合成数据集,这样就不需要依赖任何外部数据文件了。这个数据集会包含以下这些特征:

  • amount:交易金额(以美元为单位)

  • hour:交易发生的具体时间(0-23小时制)

  • day_of_week:一周中的哪一天(0代表周一,6代表周日)

  • merchant_category:商家类型(食品杂货店、餐厅、零售店、在线商店、旅游相关商家等)

  • is_fraud:标记指示这笔交易是欺诈性的(1)还是合法的(0)

我们会假设只有大约2%的交易属于欺诈行为,这种不平衡在真实的欺诈数据中非常常见。这种不平衡情况非常重要,因为它会直接影响我们评估模型效果的方式。

创建src/generate_data.py文件:

# src/generate_data.py
"""
生成用于欺诈检测的合成数据集。

该脚本会创建外观真实的交易数据,其中欺诈性交易的模式与合法交易有所不同:
- 欺诈性交易的金额通常更高
- 欺诈行为往往发生在深夜
- 在线商家和旅游相关商家的交易更容易被操纵成欺诈行为
"""
import pandas as pd
import numpy as np

def generate_transactions(n_samples=10000, fraud_ratio=0.02, seed=42):
"""
生成用于欺诈检测的合成数据集。

参数:
n_samples:要生成的交易记录总数
fraud_ratio:欺诈性交易的比例(默认为2%)
seed:用于保证结果可复现的随机种子

返回值:
包含交易特征和欺诈标签的DataFrame

欺诈性交易的特征如下:
- 金额通常更高(平均金额为245美元,而合法交易的平均金额为33美元)
- 交易时间多在深夜(0点至5点、23点)
- 涉及在线商家或旅游相关商家的交易更为常见
"""
np.random.seed(seed)
n_fraud = int(n_samples * fraud_ratio)
n_legit = n_samples - n_fraud

# 合法交易:具有常规的购物模式
# - 金额遵循对数正态分布(大部分金额较小,部分金额较大)
# - 交易时间在一天中的各个时段都有出现
# - 商家类别以日常购物相关的商家为主
legit = pd.DataFrame({
"amount": np.random.lognormal(mean=3.5, sigma=1.2, size=n_legit), # 平均金额约为33美元
"hour": np.random.randint(0, 24, size=n_legit),
"day_of_week": np.random.randint(0, 7, size=n_legit),
"merchant_category": np.random.choice(
["grocery", "restaurant", "retail", "online", "travel"],
size=n_legit,
p=[0.30, 0.25, 0.25, 0.15, 0.05] # 优先选择日常购物相关的商家
),
"is_fraud": 0
})

# 欺诈性交易:具有可疑的交易模式
# - 金额通常更高(欺诈者往往会进行大额交易)
# - 交易时间多在深夜(此时监管力度较弱)
# |涉及在线商家或旅游相关商家的交易更为常见(这些领域的交易更容易被操纵成欺诈行为)
fraud = pd.DataFrame({
"amount": np.random.lognormal(mean=5.5, sigma=1.5, size=n_fraud), # 平均金额约为245美元
"hour": np.random.choice([0, 1, 2, 3, 4, 5, 23], size=n_fraud), # 交易时间在深夜
"day_of_week": np.random.randint(0, 7, size=n_fraud),
"merchant_category": np.random.choice(
["grocery", "restaurant", "retail", "online", "travel"],
size=n_fraud,
p=[0.05, 0.05, 0.10, 0.60, 0.20] # 优先选择在线商家或旅游相关商家
),
"is_fraud": 1
})

# 将合法交易和欺诈性交易合并并打乱顺序
df = pd.concat([legit, fraud], ignore_index=True)
df = df.sample(frac=1, random_state=seed).reset_index(drop=True)

return df

if __name__ == "__main__":
# 生成数据集
print("正在生成用于欺诈检测的合成数据集...")
df = generate_transactions(n_samples=10000, fraud_ratio=0.02)

# 将数据集分为训练集(80%)和测试集(20%)
train_df = df.sample(frac=0.8, random_state=42)
test_df = df.drop(train_df.index)

# 将数据集保存为CSV文件
train_df.to_csv("data/train.csv", index=False)
test_df.to_csv("data/test.csv", index=False)

# 打印统计信息
print(f"\n数据集生成成功!")
print(f"训练集包含{len(train_df):,}条交易记录")
print(f"测试集包含{len(test_df):,}条交易记录")
print(f"整体欺诈比例为:{df['is_fraud'].mean():.2%}")
print(f"合法交易的平均金额为:${df[df['is_fraud']==0]['amount'].mean():.2f}")
print(f"欺诈性交易的平均金额为:${df[df['is_fraud]==1]['amount'].mean():.2f}")
print(f"欺诈性交易中各商家类别的分布情况为:")
print(df[df['is_fraud']==1]['merchant_category'].value_counts(normalize=True))

运行数据生成脚本:

python src/generate_data.py

你应该会看到如下输出:

正在生成伪造交易检测数据集……

数据集生成成功!
训练集:8,000笔交易
测试集:2,000笔交易
整体欺诈比例:2.00%

合法交易——平均金额:33.45美元
欺诈交易——平均金额:245.67美元

商家类别分布(欺诈交易):
在线购物        0.60%
旅游          0.20%
零售          0.10%
餐厅          0.05%
食品杂货       0.05%

现在你已经得到了`data/train.csv`和`data/test.csv`文件,其中训练集包含约8,000笔交易,测试集包含约2,000笔交易。

为什么这很重要:这些合成数据具有真实的模式——欺诈交易较为罕见(占比仅为2%),且通常发生在高金额交易、深夜进行的交易中,同时这些欺诈行为多集中在某些特定类型的商家类别中。这样的模式能为我们的模型提供学习依据。

现在,让我们快速训练一个模型吧。我们将使用scikit-learn库中的简单随机森林分类器来预测`is_fraud`字段的值。在这个初步的版本中,我们不会进行太多的特征工程处理——只需对分类变量`merchant_category`进行标签编码,然后将所有数据输入模型即可。

创建`src/train_naive.py`文件:

# src/train_naive.py
"""
训练一个欺诈检测模型——初级版本。

这个脚本演示了一种“快速简便”的机器学习方法:
- 不进行实验跟踪
- 不进行模型版本管理
- 只需要训练后将模型保存为pickle文件即可

我们会在后续章节中对这种方法进行改进。
"""
import pandas as pd
import pickle
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    accuracy_score, 
    f1_score, 
    precision_score, 
    recall_score,
    confusion_matrix,
    classification_report
)

def main():
    print("正在加载数据...")
    train_df = pd.read_csv("data/train.csv")
    test_df = pd.read_csv("data/test.csv")
    
    print(f"训练样本数量:{len(train_df):,}")
    print(f"测试样本数量:{len(test_df):,}")
    print(f"训练集的欺诈比例:{train_df['is_fraud'].mean():.2%}")
    
    # 对分类特征进行编码
    # 我们需要保存这个编码器,以便在后续推理时使用相同的映射关系
    print("\n正在对分类特征进行编码...")
    encoder = LabelEncoder()
    train_df["merchant_encoded"] = encoder.fit_transform(train_df["merchant_category"])
    test_df["merchantEncoded"] = encoder.transform(test_df["merchant_category"])
    
    print(f"商家类别的编码映射关系:{dict(zipencoder.classes_, encoder.transform(encoder.classes_)))}")
    
    # 准备特征数据和标签数据
    feature_cols = ["amount", "hour", "day_of_week", "merchant_encoded"]
    X_train = train_df[featurecols]
    y_train = train_df["is_fraud"]
    X_test = test_df[feature_cols]
    y_test = test_df["is_fraud"]
    
    # 训练随机森林分类器
    print("\n正在训练随机森林模型...")
    model = RandomForestClassifier(
        n_estimators=100,      # 树的数量
        max_depth=10,          # 每棵树的最大深度
        random_state=42,       # 为了保证结果的可重复性
        n_jobs=-1              # 使用所有CPU核心进行训练
    )
    model.fit(X_train, y_train)
    print("训练完成!")
    
    # 在测试数据上评估模型性能
    print("\n" + "="*50)
    print("模型评估结果")
    print("="*50)
    
    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)[:, 1]
    
    print(f"准确率:{accuracy_score(y_test, y_pred):.4f}")
    print(f"精确度:{precision_score(y_test, y_pred):.4f}")
    print(f"召回率:{recall_score(y_test, y_pred):.4f}")
    print(f"F1分数:{f1_score(y_test, y_pred):.4f}")
    
    # 输出混淆矩阵
    cm = confusion_matrix(y_test, y_pred)
    print(f"  真负例:{cm[0][0]:,}(正确识别为合法交易)")
    print(f"  假正例:{cm[0][1]:,}(将合法交易误判为欺诈交易)")
    print(f"  假负例:{cm[1][0]:,}(遗漏了欺诈交易——这非常危险!)")
    print(f"  真正例:{cm[1][1]:,}(正确检测出欺诈交易)")
    
    # 输出分类报告
    print(classification_report(y_test, y_pred, target_names=['Legitimate', 'Fraud']))
    
    # 显示各特征的重要性
    print("\n各特征的重要性:")
    for name, importance in sorted(
        zip(feature_cols, model.feature_importances_),
        key=lambda x: x[1],
        reverse=True
    ):
        print(f"  {name}:{importance:.4f}")
    
    # 将模型和编码器一起保存
    print("\n模型已成功训练并保存!")
    print("\n警告:这种初级方法存在一些问题:
    " + "  – 没有记录超参数或评估指标
    " + "  – 没有进行模型版本管理
    " + "  – 无法重现这个精确的模型
    " + "  – 我们会在后续章节中解决这些问题!")
    
    if __name__ == "__main__":
        main()

运行训练脚本:

python src/train_naive.py

你应该会看到类似以下的输出结果:

正在加载数据…
训练样本数量:8,000个
测试样本数量:2,000个
训练样本中的欺诈比例:2.00%

正在对分类特征进行编码…
商家类别对应关系:{'grocery': 0, 'online': 1, 'restaurant': 2, 'retail': 3, 'travel': 4}

正在训练随机森林模型…
训练已完成!

==================================================
模型评估结果
==================================================

准确率:0.9820
精确度:0.7273
召回率:0.6154
F1分数:0.6667

混淆矩阵:
  真负例数量:1,956(正确识别为合法交易)
  假正例数量:4(将合法交易误判为欺诈行为)
  假负例数量:32(遗漏了欺诈交易——这非常危险!)
  真正例数量:8(成功检测出欺诈行为)

特征重要性排序:
  金额:0.5423
  时间:0.2156
  商家编码:0.1345
  星期几:0.1076

重要提示:你会看到大约98%的准确率,但F1分数较低(一般在0.5到0.7之间)。当欺诈案例仅占2%时,准确率这个指标会具有很大的误导性!因为如果一个模型总是预测“不是欺诈”,那么它的准确率确实可以达到98%,但实际上它根本没有检测出任何欺诈行为。因此,在处理不平衡分类问题时,我们更关注F1分数、精确度和召回率。

💡 如果你是刚开始接触不平衡分类问题,请记住:当正类样本非常稀少时,高准确率其实并没有太大意义。

该脚本会生成一个名为models/model.pkl的文件,其中包含了训练好的模型以及标签编码器(进行推理时这两个组件都是必不可少的)。

需要注意的事项:现在你应该已经拥有了以下文件:

  • data/train.csv(约8,000行数据)

  • data/test.csv(约2,000行数据)

  • models/model.pkl(训练好的模型 + 编码器)

该模型的准确率应该会在98%左右,但F1分数会在0.5到0.7之间。请确认这些文件确实存在:ls -la data/ models/

1.2 使用FastAPI提供预测服务

现在我们已经训练好了模型,接下来就要将其部署为API,以便客户能够使用它来获取预测结果。我们会选择使用FastAPI,因为它使用起来非常简单、运行速度很快,而且还能自动生成交互式文档。

FastAPI具有以下优点:

  • 易于使用:采用Python风格的语法,并提供类型提示功能

  • 性能优异:是目前最快的Python开发框架之一

  • 自动生成文档:可以直接使用Swagger UI生成文档

  • 数据验证功能强大:支持使用Pydantic模型进行数据校验

创建文件src/serve_naive.py

```python
# src/serve_naive.py
"""
将欺诈检测模型作为REST API提供——这是最基础的版本。

这是一个简单的API,它能够:
1. 在程序启动时加载训练好的模型;
2. 通过POST请求接收交易数据;
3. 返回欺诈检测结果。

在后续章节中,我们会添加验证机制、监控功能,并优化模型加载流程。
"""

import pickle
from fastapi import FastAPI
from pydantic import BaseModel, Field
from typing import Optional

# 在程序启动时加载训练好的模型及编码器
# 这个操作仅在服务器启动时执行一次,不会在每次请求时都进行
print("正在加载模型...")
with open("models/model.pkl", "rb") as f:
model, encoder = pickle.load(f)
print("模型加载成功!")

# 创建FastAPI应用
app = FastAPI(
title="欺诈检测API",
description="该API用于判断信用卡交易是否属于欺诈行为。
它会接收交易详情,并返回以下信息:
- 该交易被判定为欺诈的概率;
- 欺诈发生的概率值(范围为0.0到1.0)。
**注意:** 这是目前最基础的版本,不包含验证或监控功能。",
version="1.0.0"
)

# 使用Pydantic定义输入数据的结构
# 这种方式可以自动进行数据验证,并生成相应的文档
class Transaction(BaseModel):
"""用于判断交易是否属于欺诈行为的结构体。"""
amount: float = Field(
description="交易金额,单位为美元",
example=150.00
)
hour: int = Field(
description="交易发生的小时数(范围为0到23)",
example=14
)
day_of_week: int = Field(
description "星期几(0表示周一,6表示周日)",
example=3
)
merchant_category: str = Field(
description="商家类型",
example="在线商店"
)

class PredictionResponse(BaseModel):
"""用于返回预测结果的结构体。"""
is_fraud: bool = Field(description="该交易是否被判定为欺诈行为")
fraud_probability: float = Field"description:欺诈发生的概率值(范围为0.0到1.0)")

@app.post("/predict", response_model=PredictionResponse)
def predict(transaction: Transaction):
"""
判断一笔交易是否属于欺诈行为。
接收交易详情,然后返回欺诈检测结果及概率值。
"""

# 将请求数据转换为字典格式
data = transaction.dict()

# 使用训练时使用的相同编码器对商家类型进行编码
# 这样可以确保训练阶段和实际应用阶段使用的是同一套编码规则
try:
data["merchant_encoded"] = encoder.transform([data["merchant_category"]])[0]
except ValueError:
# 处理未知的商家类型
# 在生产环境中,应该有更完善的处理机制
data["merchantEncoded"] = 0

# 按照训练时相同的顺序准备特征数据
X = [[
data["amount"],
data["hour"],
data["day_of_week"],
data["merchant_encoded"]
]]

# 获取预测结果及概率值
prediction = model.predict(X)[0]
probability = model.predict_proba(X)[0][1] # 第一类结果的概率值

return PredictionResponse(
is_fraud,bool(prediction),
fraud_probability=round(float(probability), 4)
)

@app.get("/health")
def health_check():
"""
健康检查端点。
返回API的当前运行状态。适用于以下场景:
- 负载均衡器的健康检查
- Kubernetes中的存活检测
- 监控系统等"""
return {
"status": "正常运行",
"model_loaded": model is not None
}

@app.get("/")
def root():
"""首页端点,提供API的相关信息。"""
return {
"message": "欺诈检测API",
"version": "1.0.0",
"docs": "/docs",
"health": "/health"
}
```
```

关于这段代码,有几点需要注意:

  1. Pydantic模型:我们使用`BaseModel`来定义预期的输入JSON结构。FastAPI会自动根据这个结构验证传入的请求。

  2. 类型提示:类型提示(如`float`、`int`、`str`)既可用于文档编写,也能在运行时进行验证。

  3. 特征编码:对于每个请求,我们都会使用训练过程中保存的同一`LabelEncoder`对商家类别进行编码。这样就能确保训练阶段和实际服务阶段使用的一致性。

  4. 健康检查端点:`/health`端点是生产环境API的标准配置——它能让负载均衡器和监控系统检测服务是否正在运行。

要运行这个API,可以使用Uvicorn(一个ASGI服务器):

uvicorn src.serve_naive:app --reload --host 0.0.0.0 --port 8000

`--reload`选项可以在开发过程中实现自动重启服务器(当你修改代码时,服务器会自动重新启动)。

运行后你应该会看到如下输出:

正在加载模型…
模型加载成功!
INFO: Uvicorn正在http://0.0.0.0:8000上运行(按CTRL+C可退出)
INFO: 重启器进程已启动

现在打开浏览器,访问`http://localhost:8000/docs`,你将会看到Swagger UI——这个自动生成的交互式文档允许你直接在浏览器中测试API功能!

可以在另一个终端中使用curl来测试API:

# 使用一个看似正常的交易进行测试
curl -X POST "http://localhost:8000/predict" \
  -H "Content-Type: application/json" \
  -d '{"amount": 50.0, "hour": 14, "day_of_week": 3, "merchant_category": "grocery"}'

预期的响应结果为:

{"is_fraud": false, "fraud_probability": 0.02}

# 使用一笔可疑交易进行测试(金额较大,时间较晚,且是在线交易)
curl -X POST "http://localhost:8000/predict" \
-H "Content-Type: application/json" \
-d '{"amount": 500.0, "hour": 3, "day_of_week": 1, "merchant_category": "online"}'

预期的响应结果为:

{"is_fraud": true, "fraud_probability": 0.78}

我们成功地将这个模型开发成了一个可使用的API!在实际应用中,我们可以将这个API与支付处理系统、移动应用程序或任何需要欺诈检测功能的系统进行集成。

但在庆祝之前,让我们先分析一下这种简单实现方法可能存在的缺陷……

注意事项:你的API应该运行在`http://localhost:8000`地址上。访问`/docs`页面后,应该能看到`/predict`和`/health`这两个端点。你可以使用curl或Swagger UI来验证API是否能够正常返回预测结果。

2. 天真方法为何会失败

我们这套简化的机器学习流程表面上看起来是可以正常运行的:它能够训练模型并生成预测结果。然而,如果试图在生产环境中维护或扩展这个系统,隐藏的问题就会暴露出来

这一部分内容非常重要:理解这些问题将有助于我们在后续章节中找到相应的解决方案。让我们逐一分析这些问题吧。

问题1:缺乏实验追踪机制(可复现性)

不妨做一个这样的思考实验:再次运行train_naive.py,但这次使用不同的超参数设置(比如将n_estimators改为200,或将max_depth改为15)。如果有人要求你重现之前的实验结果,你能做到吗?

很可能做不到。目前,我们完全没有记录以下这些信息:

  • 我们使用了哪些超参数

  • 实验最终得到了哪些评估指标

  • 训练所用数据的版本是什么

  • 安装了哪些编程库及其版本

  • 训练是在什么时候进行的

  • 是谁负责执行了这次训练

三个月后,如果你的上级问“这个模型是怎么训练出来的?你能重现当时的实验结果吗?”——你肯定会陷入麻烦。虽然你有源代码,但你并不知道是哪个版本的代码、使用了哪些参数,也不知道是用哪份数据训练出了当前正在生产环境中使用的模型。

实验追踪就是记录所有这些细节的做法(包括代码版本、参数设置、评估指标、数据版本等),这样就可以方便地对比和复现实验结果。而我们采用的天真方法完全缺乏这种机制,因此我们的实验结果既不可靠,也不适合被后续的开发工作所利用。

问题2:模型版本管理混乱,部署过程一团糟

我们训练了一个模型,并将其保存为model.pkl文件。现在来看这样一种情况:

  1. 你使用不同的超参数重新训练了一个新的模型

  2. 你用新模型覆盖了原来的model.pkl文件

  3. 你将这个新模型部署到了生产环境中

  4. 用户开始反馈模型的准确率降低了,出现了更多的误报情况

  5. 你想恢复到之前的旧模型

  6. 问题在于:原来的模型已经被覆盖,再也找不回来了

由于没有系统的版本管理机制,我们无法回答以下这些问题:

  • 目前生产环境中使用的是哪个模型的版本?

  • 模型v1和模型v2的评估指标分别是什么?

  • 每个模型是在什么时候由谁训练出来的?

  • 如果新模型的表现更差,我们能立即恢复到旧模型吗?

  • 不同版本之间有哪些差异?

如果没有对模型进行版本控制,那就等于在盲目地操作。想象一下,如果不使用Git来管理代码,那我们的模型开发过程不也就像这样吗?

问题3:没有数据验证机制,输入错误导致输出结果错误

目前,我们的API会接受任何形式的输入,然后尝试生成预测结果。那么,当输入的数据质量较差时,会发生什么呢?让我们来看一看吧。

创建一个测试脚本 src/test_bad_data.py

# src/test_bad_data.py
"""测试当我们向API发送错误数据时会发生什么。"""
import requests

BASE_URL = "http://localhost:8000"

print("正在使用各种错误数据测试API……\n")

# 测试1:负数金额
print("测试1:负数金额")
response = requests.post(f"{BASE_URL}/predict", json={
    "amount": -500.0,        # 负数金额——这是不可能的!
    "hour": 14,
    "day_of_week": 3,
    "merchant_category": "online"
})
print(f"  状态码:{response.status_code}")
print(f"  返回结果:{response.json()}\n")

# 测试2:无效的时间值
print("测试2:时间值为25(应为0-23)")
response = requests.post(f"{BASE_URL}/predict", json={
    "amount": 100.0,
    "hour": 25,              # 时间值无效!
    "day_of_week": 3,
    "merchant_category": "online"
})
print(f"  状态码:{response.status_code}")
print(f"  返回结果:{response.json()}\n")

# 测试3:无效的星期几
print("测试3:星期几为10(应为0-6)")
response = requests.post(f"{BASE_URL}/predict", json={
    "amount": 100.0,
    "hour": 14,
    "day_of_week": 10,       # 星期几无效!
    "merchant_category": "online"
})
print(f"  状态码:{response.status_code}")
print(f"  返回结果:{response.json()}\n")

# 测试4:未知的商家类别
print("测试4:商家类别未知")
response = requests.post(f"{BASE_URL}/predict", json={
    "amount": 100.0,
    "hour": 14,
    "day_of_week": 3,
    "merchant_category": "unknown_category"  # 这个类别并不存在于训练数据中!
})
print(f"  状态码:{response.status_code}")
print(f"  返回结果:{response.json()}\n")

# 测试5:所有参数都错误
print("测试5:所有数据都是错误的")
response = requests.post(f"{BASE_URL}/predict", json={
    "amount": -1000.0,
    "hour": 99,
    "day_of_week": 15,
    "merchant_category": "totally_fake"
})
print(f"  状态码:{response.status_code}")
print(f"  返回结果:{response.json()}\n")

print("观察结果:API能够接受所有错误数据,并依然返回预测结果!")
print("但这非常危险——错误数据会导致错误的预测结果,而且系统不会发出任何警告。"

运行这个脚本(确保你的API仍在运行中):

python src/test_bad_data.py

你将会看到类似以下的输出:

正在使用各种错误数据测试API……

测试1:负数金额
  状态码:200
  返回结果:{'is_fraud': False, 'fraud_probability': 0.15}

测试2:时间值为25(应为0-23)
  状态码:200
  返回结果:{'is_fraud': False, 'fraud_probability': 0.08}

……

观察结果:API能够接受所有错误数据,并依然返回预测结果!

API可以接受错误数据,并且会生成预测结果,而且系统不会发出任何警告!在实际情况中,这可能会导致以下问题:

  • 基于错误的数据生成的预测结果会是错误的。

  • 由于输入数据格式不正确,欺诈行为可能无法被检测出来。

  • 基于损坏的数据,合法的交易可能会被误判为非法交易而被阻止。

  • 将无法找出导致预测结果错误的原因。

俗话说:“输入垃圾,输出也是垃圾。”但更糟糕的是——我们甚至都不知道输入了哪些“垃圾”!

问题4:模型漂移——性能随时间下降

以下是每个实际应用的机器学习系统都会遇到的情况:

  1. 1月:你使用历史欺诈数据训练模型,模型的准确率达到了98%,F1分数为0.67,大家都很满意。

  2. 2月:模型正式投入使用,运行效果良好,欺诈行为确实被有效拦截了。

  3. 3月:骗子们开始改变作案手段,他们使用不同的交易模式——金额更小、涉及不同的商家类别、在一天中的不同时间进行交易。

  4. 4月:模型的准确率从98%下降到了85%,F1分数也从0.67降到了0.35,欺诈行为又开始猖獗起来。

  5. 5月:一起严重的欺诈事件发生了,调查显示该模型在过去的2个月里一直表现不佳。

问题所在:由于没有进行任何监控,因此有整整2个月的时间大家都没有发现这个问题。

这种现象被称为数据漂移(当输入数据的分布发生变化时),或者概念漂移(当输入与输出之间的关系发生变化时)。在现实世界的系统中,这两种现象都是不可避免的。

如果没有进行监控:

  • 你将无法及时发现性能下降的情况

  • 你也弄不清楚性能为何会下降

  • 直到用户提出投诉,你才能采取补救措施

  • 而到那时,可能已经造成了严重的后果

问题5:缺乏持续集成/持续交付机制或部署安全保障

我们的“部署流程”实际上就是:

  1. 通过SSH连接到服务器(或在本地运行程序)

  2. 执行命令python src/train_naive.py

  3. 将模型文件复制到指定位置

  4. 重启API服务

  5. 然后就只能寄希望于一切都能顺利运行了

这种部署方式存在以下问题:

  • 没有自动化测试:哪怕是一个拼写错误,都可能导致整个系统崩溃

  • 没有测试环境:我们直接在生产环境中进行测试

  • 没有逐步部署机制

    :所有流量都会立即被导向新的模型

  • 没有回滚功能:如果出现问题,我们必须手动进行修复

  • 没有审计记录:谁在什么时间进行了哪些操作?

正是因为这样的部署流程,才会导致生产环境中出现严重问题。比如,在周五下午5点匆忙进行部署,结果导致欺诈检测系统出故障,而直到周一发现欺诈损失急剧增加时,大家才意识到这个问题。

图2:朴素机器学习方法存在的问题

示意图展示了朴素机器学习方案中的缺陷:手动进行模型训练和部署,没有实验跟踪机制,没有模型版本管理,训练数据和生产数据不一致,没有数据验证流程,也没有性能监控机制,更缺乏持续集成/持续交付的保障措施,如自动化测试、回滚功能或审计记录。

总结:我们需要解决的问题

我们现有的简单机器学习服务缺乏一些关键的基础设施。以下是将问题与相应的解决方案对应起来:

问题 影响 解决方案 相关章节
没有实验跟踪功能 无法重现或比较不同的实验结果 使用MLflow进行实验跟踪 3
没有模型版本管理机制 无法回退到之前的模型版本,也无法进行审计 使用MLflow Registry管理模型版本 3
各个功能之间存在不一致性 训练结果与实际部署结果不匹配 通过Feast Feature Store解决功能一致性问题 4
没有数据验证机制 会导致错误的预测结果 使用Great Expectations进行数据验证 5
没有监控系统 模型的变化情况无法被及时发现 需要建立有效的监控机制 6
没有持续集成/持续部署流程 会导致风险较高的部署操作 使用GitHub Actions与Docker实现持续集成/持续部署 7

好消息是:我们可以通过逐步为开发流程添加相关组件来解决这些问题。每个工具都针对特定的问题,共同作用就能构建出一个功能完善的机器学习平台。

让我们开始逐一解决这些问题吧。

3. 使用MLflow添加实验跟踪功能和模型注册系统

如果没有这些功能会带来什么后果:无法重现之前的实验结果,无法比较不同的实验方案,当新的模型在生产环境中出现故障时也无法及时回退到之前的版本。

我们的第一个改进措施就是解决问题1和问题2:确保实验结果的可重现性以及模型的版本管理。

MLflow是一个开源平台,专门用于管理机器学习的整个生命周期。我们将使用它的两个核心组件:

  1. MLflow Tracking:记录实验过程中的各种参数、指标及生成的结果文件,从而便于比较不同实验的结果。

  2. MLflow Model Registry:通过别名(如“冠军模型”和“挑战者模型”)来管理模型的版本,并跟踪模型的部署过程。

为什么这些功能很重要:如果没有实验跟踪机制,机器学习就变成了纯粹的猜测行为。而使用MLflow后,每次实验的过程都会被详细记录下来,包括参数、指标及生成的结果文件。这样我们就可以对比不同实验的结果,了解究竟是什么因素改善了模型的性能,并且可以随时重现过去的实验过程。模型注册系统还提供了对模型版本的有效管理机制——你可以清楚地知道哪些模型正在实际生产环境中使用,而且可以在几秒钟内迅速回退到之前的版本。

3.1 如何配置MLflow跟踪服务器

默认情况下,MLflow会将实验记录保存在本地目录中,但为了使用其完整的用户界面和模型注册系统,最好单独运行MLflow跟踪服务器。

打开一个新的终端窗口(请确保它与用于API操作的终端分开),然后运行以下命令:

# 创建用于存储MLflow数据的目录
mkdir -p mlruns

# 启动MLflow跟踪服务器
mlflow server \
    --host 0.0.0.0 \
    --port 5000 \
    --backend-store-uri sqlite:///mlflow.db \
    --default-artifact-root ./mlruns

让我们来详细分析这些参数:

  • --host 0.0.0.0:在所有网络接口上监听请求

  • :在端口5000上运行程序

  • --backend-store-uri sqlite:///mlflow.db:将实验元数据存储在SQLite数据库中(在生产环境中,通常会使用PostgreSQL或MySQL)

  • :将模型相关文件存储在mlruns目录中

执行上述配置后,你应该会看到如下输出:

[INFO] 正在启动gunicorn 21.2.0
[INFO] 监听地址:http://0.0.0.0:5000

现在打开浏览器,访问http://localhost:5000,你将会看到MLflow UI界面——由于我们还没有开始任何实验,因此该界面最初应该是空的。

3.2 如何在代码中记录实验过程

接下来,让我们修改训练脚本,以便将所有的实验数据都记录到MLflow系统中。请创建文件src/train_mlflow.py

# src/train_mlflow.py
"""
使用MLflow进行欺诈检测模型的训练,并实现实验数据的跟踪功能。

该脚本展示了正确的ML实验跟踪方法:
- 记录所有的超参数设置
- 记录所有评估指标(包括训练集和测试集的数据)
- 将训练好的模型作为数据文件保存下来
- 将模型注册到MLflow的模型注册系统中

请将此脚本与train_naive.py进行比较,了解两者之间的区别!
"""
import pandas as pd
import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    accuracy_score, 
    precision_score, 
    recall_score, 
    f1_score,
    roc_auc_score
)
import pickle
from datetime import datetime

# 配置MLflow系统,使其使用我们的跟踪服务器
mlflow.set_tracking_uri("http://localhost:5000")

# 创建或获取实验对象
# 所有的训练记录都会被归入这个实验名下
mlflow.set_experiment("fraud-detection")

def load_and_preprocess_data():
    """加载并预处理训练数据及测试数据"""
    print("正在加载数据...")
    train_df = pd.read_csv("data/train.csv")
    test_df = pd.read_csv("data/test.csv")
    
    # 对分类特征进行编码
    encoder = LabelEncoder()
    train_df["merchant_encoded"] = encoder.fit_transform(train_df["merchant_category"])
    test_df["merchantEncoded"] = encoder.transform(test_df["merchant_category"])
    
    # 准备训练数据所需的特征
    feature_cols = ["amount", "hour", "day_of_week", "merchant-encoded"]
    X_train = train_df[feature_cols]
    y_train = train_df["is_fraud"]
    X_test = test_df[feature_cols]
    y_test = test_df["is_fraud"]
    
    return X_train, y_train, X_test, y_test, encoder

def train_and_log_model(
    n_estimators: int = 100,
    max_depth: int = 10,
    min_samples_split: int = 2,
    min_samples_leaf: int = 1
):
    """训练模型,并将所有实验数据记录到MLflow系统中"""
    X_train, y_train, X_test, y_test, encoder = load_and_preprocess_data()
    
    # 启动一个MLflow实验任务——所有记录的数据都会与这个任务关联起来
    with mlflow.start_run():
        # 为这个实验任务指定一个描述性的名称
        run_name = f"rf_est{n_estimators}_depth{max_depth}_{datetime.now().strftime('%H%M%S')}"
        mlflow.set_tag("mlflow.runName", run_name)
        
        # 记录所有的超参数设置
        # 这些就是我们可以调整的“参数值”
        mlflow.log_param("n_estimators", n_estimators)
        mlflow.log_param("max_depth", max_depth)
        mlflow.log_param("min_samples_split", min_samples_split)
        mlflow.log_param("min_samples_leaf", min_samples_leaf)
        mlflow.log_param("model_type", "RandomForestClassifier")
        
        # 记录数据的相关信息
        mlflow.log_param("train_samples", len(X_train))
        mlflow.log_param("test_samples", len(X_test))
        mlflow.log_param("fraud_ratio", float(y_train.mean()))
        mlflow.log_param("n_features", X_train.shape[1])
        
        # 开始训练模型
        print(f"\n正在训练模型:n_estimators={n_estimators}, max_depth={max_depth}")
        model = RandomForestClassifier(
            n_estimators=n_estimators,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            random_state=42,
            n_jobs=-1
        )
        model.fit(X_train, y_train)
        
        # 评估模型在训练集和测试集上的表现,并记录相应的指标
        # 这有助于检测模型是否出现过拟合现象
        for dataset_name, X, y in [("train", X_train, y_train), ("test", X_test, y_test)]:
            y_pred = model.predict(X)
            y_prob = model.predict_proba(X)[:, 1]
            
            # 计算各种评估指标
            accuracy = accuracy_score(y, y_pred)
            precision = precision_score(y, y_pred, zero_division=0)
            recall = recall_score(y, y_pred, zero_division=0)
            f1 = f1_score(y, y_pred, zero_division=0)
            roc_auc = roc_auc_score(y, y_prob)
            
            # 为这些指标添加数据集前缀以便于区分
            mlflow.log_metric(f"{dataset_name}_accuracy", accuracy)
            mlflow.log Metric(f"{dataset_name}_precision", precision)
            mlflow.log_metric(f"{dataset_name}_recall", recall)
            mlflow.log_metric(f"{dataset_name}_f1", f1)
            mlflow.log_metric(f"{dataset_name}_roc_auc", roc_auc)
            
            # 打印评估结果
            print(f"  {dataset_name.upper()} - 准确率:{accuracy:.4f}, F1分数:{f1:.4f}, ROC-AUC值:{roc_auc:.4f}")
        
        # 记录各特征的重要性
        for feature, importance in zip(
            ["amount", "hour", "day_of_week", "merchant_encoded"],
            model.feature_importances_
        ):
            mlflow.log_metric(f"importance_{feature}", importance)
        
        # 将训练好的模型保存到MLflow系统中,并将其注册到模型注册系统中
        # 这会自动创建模型的新版本
        print("\n正在将模型注册到MLflow模型注册系统中...")
        mlflow.sklearn.log_model(
            sk_model=model,
            artifact_path="model",
            registered_model_name="fraud-detection-model",
            input_example=X_train.iloc[:5]  # 用于演示的输入数据
        )
        
        # 将编码器保存为单独的数据文件
        # 这个编码器在后续的推理过程中会用到
        with open("encoder.pkl", "wb") as f:
            pickle.dump(encoder, f)
        mlflow.log_artifact("encoder.pkl")
        
        # 获取实验任务的ID,以便日后参考
        run_id = mlflow.active_run().info.run_id
        print(f"\nMLflow实验任务ID:{run_id}")
        print(f"可以访问该实验任务的详细信息:http://localhost:5000/#/experiments/1/runs/{run_id}")
        
        return model, encoder

def run_experiment_sweep():
    """使用不同的超参数设置运行多个实验"""
    print "="*60)
    print("正在运行超参数实验扫描...")
    print "="*60)
    
    # 定义多种不同的配置方案
    experiments = [
        {"n_estimators": 50, "max_depth": 5},
        {"n_estimators": 100, "max_depth": 10},
        {"n_estimators": 100, "max_depth": 15},
        {"n_estimators": 200, "max_depth": 10},
        {"n_estimators": 200, "max_depth": 20},
    ]
    
    # 遍历所有的配置方案,分别运行相应的实验
    for i, params in enumerate(experiments, 1):
        print(f"\n--- 实验编号 {i}/{len(experiments)} ---")
        train_and_log_model(**params)
        
    print("\n" + "="*60)
    print("实验扫描已完成!")
    print "="*60)
    print("\n所有实验结果可以查看:http://localhost:5000")
    print("通过对比不同配置方案的结果,可以找到最佳的参数设置!")
    
if __name__ == "__main__":
    run_experiment_sweep()

此脚本的作用如下:

  1. 连接MLflow系统:执行代码 mlflow.set_tracking_uri("http://localhost:5000")

  2. 创建一个实验:执行代码 mlflow.set_experiment("fraud-detection")

  3. 记录所有参数信息,包括超参数和数据详情

  4. 记录各项指标,如训练集和测试集的准确率、精确度、召回率、F1值以及ROC-AUC分数

  5. 保存训练好的模型,将其作为成果文件保存下来

  6. 将模型注册到MLflow模型注册系统中,系统会自动为模型生成版本号

运行实验测试流程:

python src/train_mlflow.py

你会看到每个实验的运行结果如下:

============================================================
正在执行超参数实验测试流程
============================================================

--- 实验1/5 ---
正在加载数据…
训练模型参数:n_estimators=50, max_depth=5
  训练集 - 准确率:0.9821,F1值:0.6545,ROC-AUC分数:0.9234
  测试集 - 准确率:0.9795,F1值:0.5714,ROC-AUC分数:0.8956
正在将模型注册到MLflow模型注册系统中…
MLflow运行ID:abc123…

--- 实验5/5 ---
训练模型参数:n_estimators=200, max_depth=20
  训练集 - 准确率:0.9856,F1值:0.7123,ROC-AUC分数:0.9567
  测试集 - 准确率:0.9810,F1值:0.6667,ROC-AUC分数:0.9234
============================================================
实验测试流程已完成!
============================================================

现在所有5次实验的结果都已记录到MLflow系统中,你可以在用户界面中查看各项指标的对比结果。
请访问 http://localhost:5000 更新MLflow用户界面。你会看到如下内容:

  1. “实验”选项卡:展示了名为“fraud-detection”的实验以及其5次运行结果

  2. 每次实验的详细信息:包括参数设置、各项指标数据以及生成的成果文件

  3. 对比功能:你可以选择多个实验结果进行横向对比

  4. “模型”选项卡:展示了名为“fraud-detection-model”的模型及其5个不同版本的信息

MLflow跟踪用户界面:让你能够一目了然地查看实验结果、各项指标以及模型信息
c5a7d547-31b6-4783-acea-f4e9433d81ef

3.3 如何使用模型注册系统

模型注册系统为管理模型版本及其生命周期各阶段提供了统一的平台。
在MLflow用户界面中:

  1. 点击顶部导航栏中的“模型”选项卡

  2. 选择“fraud-detection-model”

  3. 你会看到该模型的5个版本及其对应的各项指标数据

模型别名: MLflow现在使用“别名”来代替之前的“阶段”概念。如果你之前看到过使用“测试阶段”和“生产阶段”的教程,那么别名这种机制才是更新、更灵活的方式。

  • @champion:用于处理实时请求的生产环境模型

  • @challenger:正在被测试的候选模型

  • 你还可以创建自定义别名,比如@baseline、@latest等等。

为模型分配别名:

  1. 打开MLflow用户界面 → 选择“模型” → 点击“fraud-detection-model”

  2. 点击你想要推广的模型版本

  3. 然后点击“添加别名”

  4. 输入“@champion”并保存设置

现在你已经将“@champion”这个别名分配给了你的最佳模型。你的API会自动加载带有该别名的模型版本,因此进行回滚操作时,只需将别名切换到之前的版本即可。

图3:MLflow模型的生命周期——从训练到生产环境
该图表展示了用于欺诈检测系统的MLflow模型生命周期:模型会使用实验参数进行训练,训练结果会被记录到MLflow跟踪系统中,并附带相应的指标和数据文件;这些模型会以多个版本的形式注册在模型注册库中,然后被赋予诸如“@champion”或“@challenger”之类的别名;最终,生产环境会通过“@champion”这个别名来加载相应的模型进行服务。该图表还说明了如何通过将别名切换到早期版本来执行回滚操作。” height=

3.4 更新API以从注册库中加载模型

现在让我们更新我们的API,使其能够从MLflow注册库中加载“@champion”模型,而不是从pickle文件中加载。请创建文件`src/serve_mlflow.py`:

# src/serve_mlflow.py
"""
该脚本用于从MLflow模型注册库中提供欺诈检测模型服务。

当前版本会自动加载带有“@champion”别名的模型,这意味着:
- 总是会提供最新的“@champion”模型版本
- 可以通过更改“@champion”别名来回滚到之前的版本
- 不需要手动复制文件
import mlflow
import mlflow.sklearn
import pickle
import os
from fastapi import FastAPI
from pydantic import BaseModel, Field

# 配置MLflow
mlflow.set_tracking_uri("http://localhost:5000")

print("正在从MLflow模型注册库中加载模型...")

# 从注册库中加载“@champion”模型
# 系统会自动选择带有该别名的最新模型版本
try:
    model = mlflow.sklearn.load_model("models:/fraud-detection-model@champion")
    print("成功从MLflow注册库中加载了“@champion”模型!")
except Exception as e:
    print(f"从MLflow注册库中加载模型时出现错误:{e}")
    print("请确认你已经在MLflow用户界面中为该模型分配了“@champion”别名")
    raise

# 加载编码器模型(该模型被保存为数据文件)
# 在实际系统中,也可以将编码器模型也进行版本管理
with open("encoder.pkl", "rb") as f:
    encoder = pickle.load(f)
print("编码器模型加载成功!")

app = FastAPI(
    title="欺诈检测API (MLflow)",
    description="""
    该API用于从MLflow模型注册库中加载模型进行欺诈检测服务。
    
    当前版本总是会提供带有“@champion”别名的模型版本。
    如需更新模型:
    1. 使用`train_mlflow.py`脚本训练新模型
    2. 在MLflow用户界面中比较不同模型的指标表现
    3. 将最佳模型推广到生产环境
    4> 重新启动此API服务
    
    如需回滚操作:在MLflow用户界面中将“@champion”别名切换到之前的版本即可。
    """,
    version="2.0.0"
)

class Transaction(BaseModel):
    amount: float = Field(..., description="交易金额,单位为美元", example=150.00)
    hour: int = Field(..., description="一天中的小时数(0-23),例如14表示下午2点”, example=14)
    day_of_week: int = Field(..., description "一周中的星期几(0代表周一,6代表周日),例如3表示周三”, example=3)
    merchant_category: str = Field(..., description="商家类型,例如“online”表示在线商家”, example="online")

class PredictionResponse(BaseModel):
    is_fraud: bool
    fraud_probability: float
    model_source: str = "MLflow Production"

@app.post("/predict", response_model=PredictionResponse)
def predict(tx: Transaction):
    """使用“@champion”模型来判断一笔交易是否属于欺诈行为"""
    data = tx.dict()

    try:
        data["merchant_encoded"] = encoder.transform([data["merchant_category"]])[0]
    except ValueError:
        data["merchantEncoded"] = 0

    X = [[data["amount"], data["hour"], data["day_of_week"], data["merchant-encoded"]]]
    
    pred = model.predict(X)[0]
    prob = model.predict_proba(X)[0][1]

    return PredictionResponse(
        is_fraud,bool(pred),
        fraud_probability=round(float(prob), 4),
        model_source="MLflow Production"
    )

@app.get("/health")
def health():
    return {"status": "healthy", "model_source": "MLflow Registry"}

@app.get("/model-info")
def model_info():
    """获取当前正在使用的模型的相关信息"""
    return {
        "registry": "MLflow",
        "model_name": "fraud-detection-model",
        "alias": "champion",
        "tracking_uri": "http://localhost:5000"
    }

请停止使用旧的API(按Ctrl+C键),然后开始使用新的API:

uvicorn src.serve_mlflow:app --reload --host 0.0.0.0 --port 8000

现在,部署新模型已经成为一个可控且可审计的过程

  1. 训练新模型 → 新模型会自动被注册为新的版本

  2. 比较各项指标 → 使用MLflow用户界面将新模型的结果与当前的生产环境中的模型进行对比

  3. 指定为新模型 → 在MLflow用户界面中为该模型设置@champion别名

  4. 重新启动API → 新模型就会在生产环境中开始运行

  5. 需要时可以回滚 → 将@champion别名重新设置为之前的版本

验证要点:

  • MLflow用户界面(http://localhost:5000)应显示名为“fraud-detection”的实验,且该实验共进行了5次运行

  • “模型”选项卡中应列出名为“fraud-detection-model”的模型,并显示其5个版本

  • 其中有一个版本应该被设置为@champion别名

  • API在运行时应该能够加载并使用@champion别名所对应的模型

4. 确保特征一致性

⚠️ 第一次听说特征存储的概念吗? 不用担心。初次接触时,你并不需要完全掌握所有相关细节。先重点理解为什么保持特征一致性如此重要——具体实现方法可以稍后再学习。
关键要点:训练阶段和部署阶段必须采用相同的方法来计算特征值,否则模型就会出现故障。

如果不遵守这一原则会带来什么后果? 模型在生产环境中看到的特征值与训练阶段看到的值不同,从而导致准确率下降。这种现象被称为“训练-部署不一致性”,它是导致机器学习系统出现故障的最常见原因之一。

在机器学习系统中,一个隐蔽但极其重要的问题就是训练-部署不一致性——即训练阶段和部署阶段的数据处理方式存在差异。即使这种差异很小,也会严重影响模型的性能。

为什么这一点如此重要? 比如,当你将“每个商家类别的平均交易金额”作为特征进行计算时,在训练阶段你可能使用pandas在笔记本中完成这一计算;而在部署阶段,你却可能在另一个系统中使用SQL来处理同样的数据。如果这两种计算方法在处理特殊情况(如空值、四舍五入、时间窗口等)时存在差异,模型在生产环境中看到的特征值就会与训练阶段看到的不同。

其结果是什么呢?模型会无声无息地出现故障,准确率会下降,但系统中并不会显示任何错误信息。你的模型实际上是在根据它从未见过的特征来进行预测,而你却对此毫无察觉。

在之前的实现方式中,我们确实处理过一种简单的情况:我们会保存LabelEncoder对象,以确保在训练阶段和部署阶段对“merchant_category”这一特征的编码方式保持一致。但假如我们需要进行更复杂的特征工程处理呢?

  • 针对不同时间窗口计算滚动平均值

  • 进行用户级别的数据聚合

  • 分析特征之间的交互关系

  • 从流式数据中提取实时特征值

手动保持数据一致性已变得不可能。

4.1 什么是Feast以及为何使用它?

特征存储系统
来确保训练数据和服务数据之间的一致性。Feast就是一种流行的开源解决方案。

在本教程中,我们选择使用Feast,并非因为这是强制性的要求,而是因为它能够明确地体现训练数据与服务数据之间的对应关系,同时也便于理解和掌握相关原理。无论你使用Feast、Tecton、Featureform还是其他自定义解决方案,这些原则都是适用的。

Feast提供了以下功能:

功能 描述
统一的数据来源 只需定义一次特征,即可在整个系统中使用
离线与在线数据的一致性 训练数据和服务数据使用相同的特征信息
数据正确性的实时保障 有效防止训练过程中出现数据泄露问题
低延迟的服务响应 特征检索时间仅需几毫秒
特征版本的跟踪与管理 能够记录特征定义的变更历史

Feast的工作原理:

  1. 在Python代码中定义特征信息

  2. 将这些特征数据从数据源加载到在线存储系统中

  3. 无论是训练还是服务,都通过相同的API来获取这些特征数据

这样的设计能够确保训练数据和服务数据完全使用相同的数据处理逻辑

4.2 安装并初始化Feast

我们之前已经通过requirements.txt文件安装了Feast,现在就来初始化特征存储系统吧。

# 进入feature_repo目录
cd feature_repo

# 初始化Feast(这将生成一些模板文件)
feast init . --minimal

# 返回项目根目录
cd ..

这样就会创建出Feast的基本结构:

feature_repo/
├── feature_store.yaml    # Feast配置文件
└── __init__.py

4.3 定义特征信息

首先,我们来创建Feast的配置文件:

# feature_repo/feature_store.yaml
project: fraud_detection
registry: ../data/registry.db
provider: local
online_store:
  type: sqlite
  path:../data/online_store.db
offline_store:
  type: file
entity_key_serialization_version: 3

这个配置文件包含了以下内容:

  • 将我们的项目命名为“fraud_detection”

  • 在线存储系统使用SQLite(在生产环境中通常会使用Redis或DynamoDB)

  • 离线存储系统使用本地文件(在生产环境中通常会使用BigQuery或Snowflake)

现在,我们可以开始定义具体的特征信息了。

# feature_repo/features.py
"""
用于欺诈检测的特征定义文件。

该文件定义了以下内容:
- 实体:我们用来查找相应特征的键值对(例如 merchant_category);
- 数据来源:原始特征数据所在的位置(Parquet文件);
- 特征视图:这些特征本身及其对应的结构信息。

需要重点注意的是:这些定义是获取特征数据的唯一权威依据。无论是训练过程还是实际的数据提供环节,都必须使用这些相同的定义。
"""
from datetime import timedelta
from feast import Entity, FeatureView, Field, FileSource, ValueType
from feast.types import Float32, Int64

# ^
# 实体
# ^
# 实体就是我们用来查找特征的“键值对”。对于与商家相关的特征而言,这个实体就是 merchant_category。
merchant = Entity(
    name="merchant_category",
    description="交易的商家类别(例如‘在线购物’、‘杂货店’)",
    value_type=ValueType.STRING,
)

# ^
# 数据来源
# ^
# 数据来源指明了Feast应该从哪里获取原始特征数据。在本地开发环境中,我们使用Parquet文件;而在生产环境中,这些数据可能存储在BigQuery、Snowflake或S3等系统中。
merchant_stats_source = FileSource(
    name="merchant_stats_source",
    path="../data/merchant_features.parquet",  # 我们会创建这个文件
    timestamp_field="event_timestamp",       # 这个字段对于进行时间点匹配操作是必需的
)

# ^
# 特征视图
# ^
# 特征视图定义了一组相关的特征。它规定了以下内容:
# - 这些特征是针对哪个实体而言的;
# - 特征的结构信息(名称和数据类型);
# - 数据的来源;
# - 特征的有效期限。
merchant_stats_fv = FeatureView(
    name="merchant_stats",
    description "按商家类别统计的汇总数据",
    entities=[merchant],
    ttl=timedelta(days=7),  # 这些特征的有效期限为7天,
    schema=[
        Field(name="avg_amount", dtype=Float32, description="平均交易金额"),
        Field(name="transaction_count", dtype=Int64, description="交易数量"),
        Field(name="fraud_rate", dtype=Float32, description="历史欺诈发生率"),
    ],
    source=merchant_stats_source,
    online=True,  # 启用在线数据提供功能(低延迟查询)
)

4.4 将特征数据应用到在线商店中

现在我们需要完成以下步骤:

  1. 从训练数据中计算出相应的特征值

  2. 将这些特征值保存为Feast能够识别的格式

  3. 应用Feast定义中的规则来处理这些特征值

  4. 最终将处理后的特征数据应用到在线商店中,以实现低延迟的数据服务

创建文件`src/prepare_feast_features.py`:

# src/prepare_feast_features.py
"""
用于为Feast准备特征数据。

该脚本会:
1. 从训练数据中计算出各商户的分类特征值
2> 将这些特征值保存为Parquet格式(Feast的离线存储格式)
3> 应用Feast定义中的规则来处理这些特征值
4> 最后将处理后的特征数据应用到在线商店中,以实现低延迟的数据服务

每当训练数据发生变化或你需要更新特征数据时,请运行此脚本。
"""
import pandas as pd
import numpy as np
from datetime import datetime
import subprocess
import os

def compute_merchant_features(df: pd.DataFrame) -> pd.DataFrame:
    """
    按商户类别计算汇总后的特征值。

    这是计算特征值的唯一标准流程。

    无论是训练阶段还是数据服务阶段,都会使用这套计算规则。
    任何对此代码的修改都会自动影响到所有相关环节。

    参数:
        df: 包含交易信息的DataFrame,列名为:amount、merchant_category、is_fraud

    返回值:
        包含按商户类别分类后的特征值的DataFrame
    """
    print("正在计算各商户的分类特征值...")
    
    # 按商户类别对数据进行处理并计算汇总结果
    stats = df.groupby('merchant_category').agg({
        'amount': ['mean', 'count'],
        'is_fraud': 'mean'
    }).reset_index()
    
    # 整理列名
    stats.columns = ['merchant_category', 'avg_amount', 'transaction_count', 'fraud_rate']
    
    # 为Feast添加时间戳字段(用于确保数据能够被正确地关联起来)
    stats['event_timestamp'] = datetime.now()
    
    # 将数据类型转换为Feast所要求的格式
    stats['avg_amount'] = stats['avg_amount'].astype('float32')
    stats['transaction_count'] = stats['transaction_count'].astype('int64')
    stats['fraud_rate'] = stats['fraud_rate'].astype('float32')
    
    return stats

def main():
    print("="*60)
    print("Feast特征数据准备中...")
    print "="*60)
    
    # 加载训练数据
    print("\n1. 正在加载训练数据...")
    train_df = pd.read_csv('data/train.csv')
    print(f"已加载{len(train_df):,}条交易记录")
    
    # 计算各商户的分类特征值
    print("\n2. 正在计算各商户的特征值...")
    merchant_features = compute_merchant_features(train_df)
    
    print("\n   计算完成的特征值如下:")
    print(merchant_features.to_string(index=False))
    
    # 将特征数据保存为Parquet格式
    print("\n3. 正在将特征数据保存为Parquet文件...")
    os.makedirs('data', exist_ok=True)
    output_path = 'data/merchant_features.parquet'
    merchant_features.to_parquet(output_path, index=False)
    print(f"特征数据已保存至{output_path}")
    
    # 应用Feast定义中的规则来处理这些特征值
    print("\n4. 正在应用Feast的定义来处理特征值...")
    try:
        result = subprocess.run(
            ['feast', 'apply'],
            cwd='feature_repo',
            capture_output=True,
            text=True,
            check=True
        )
        print("   Feast定义已成功应用!")
        if result.stdout:
            print(f"   输出结果:{result.stdout}")
    except subprocess.CalledProcessError as e:
        print(f"   应用Feast定义时出现错误:{e.stderr}")
        raise
    
    # 将处理后的特征数据应用到在线商店中
    print("\n5. 正在将特征数据应用到在线商店中...")
    try:
        result = subprocess.run(
            ['feast', 'materialize-incremental', datetime.now().isoformat()],
            cwd='feature_repo',
            capture_output=True,
            text=True,
            check=True
        )
        print("   特征数据已成功应用到在线商店中!")
        if result.stdout:
            print(f"   输出结果:{result.stdout}")
    except subprocess.CalledProcessError as e:
        print(f"   应用特征数据时出现错误:{e.stderr}")
        raise
    
    print("\n" + "="*60)
    print("Feast特征数据准备完成!")
    print "="*60)
    print("\n现在你可以执行以下操作:
    \n  - 获取用于训练的特征数据:get_training_features()
    \n  - 获取用于数据服务的特征数据:get_online_features()
    \n  - 查看特征数据统计信息:feast feature-views list")
    
if __name__ == "__main__":
    main()

运行特征准备步骤:

python src/prepare_feast_features.py

你应该会看到以下输出:

============================================================
FEAST 特征准备完成
============================================================

1. 正在加载训练数据……共 8,000 条交易记录
2. 计算各商家的特征信息……
   食品杂货:平均金额=31.24美元,欺诈率为0.85%
   网上购物:平均金额=98.45美元,欺诈率为4.87%
   餐厅:平均金额=28.12美元,欺诈率为0.50%
   零售业:平均金额=45.67美元,欺诈率为1.02%
   旅游行业:平均金额=156.23美元,欺诈率为4.18%
3. 已将结果保存至 data/merchant_features.parquet 文件中 ✓
4. 已应用 Feast 的特征定义…… ✓
5. 数据已同步到在线商店中…… ✓

FEAST 特征准备完全结束!

4.5 获取用于训练和服务的特征数据

现在,让我们编写一些工具函数,以便在训练和提供服务时能够一致地获取所需的特征数据:

# src/feast_features.py
"""
用于训练和服务场景的特征数据检索功能。

该模块提供了以下函数来从 Feast 中提取特征数据:
- get_training_features():用于离线训练(使用历史数据)
- get_online_features():用于实时服务(低延迟处理)

重要提示:这两个函数使用的特征定义是相同的,因此能够确保训练阶段和服务阶段获得的数据具有一致性。
"""
import pandas as pd
from feast import FeatureStore
from datetime import datetime

# 初始化 Feast 数据存储对象(指向我们的特征数据仓库)
store = FeatureStore(repo_path="feature_repo")

def get_training_features(df: pd.DataFrame) -> pd.DataFrame:
    """
    使用 Feast 的离线数据存储功能来获取训练所需的特征数据。

    该函数会通过精确的时间点进行数据匹配,从而避免数据泄露。也就是说,特征数据会根据每笔交易实际发生的时间来进行检索,而不会使用当前时间的数据,这样就能确保不会无意中使用未来的数据。

    参数:
        df:至少包含“merchant_category”这一列的 DataFrame

    返回值:
        包含原始列以及 Feast 提供的特征数据的 DataFrame
    """
    print("正在从 Feast 的离线数据存储中获取训练特征数据...")
    
    # 准备包含时间戳的实体数据帧
    # 每一行需要包含:实体标识 + 事件发生的时间戳
    entity_df = df[['merchant_category]].copy()
    entity_df['event_timestamp'] = datetime.now()  # 注意:这里使用的是当前时间戳
    entity_df = entity_df.drop_duplicates()
    
    # ⚠️ 简化处理:为了便于理解,这里使用了当前时间戳。但在实际系统中,应该使用每笔交易的实际发生时间。
    
    # 获取历史特征数据
    # Feast 会自动根据准确的时间点进行数据匹配
    training_data = store.get_historical_features(
        entity_df=entity_df,
        features=[
            "merchant_stats:avg_amount",
            "merchant_stats:transaction_count",
            "merchant_stats:fraud_rate",
        ],
    ).to_df()
    
    # 将特征数据合并回原始 DataFrame 中
    result = df.merge(
        training_data[['merchant_category', 'avg_amount', 'transaction_count', 'fraud_rate'],
        on='merchant_category',
        how='left'
    )
    
    print(f"已为 {len(entity_df)} 家不同的商家获取到特征数据")
    return result

def get_online_features(merchant_category: str) -> dict:
    """
    使用 Feast 的在线数据存储功能来获取实时服务所需的特征数据。

    该函数经过优化,能够实现低延迟的数据检索(响应时间仅为几毫秒)。你可以在预测 API 中使用它来进行实时推理。

    参数:
        merchant_category:需要查询的商家类别

    返回值:
    一个包含特征名称及其对应值的字典
    """
    # 从在线数据存储中获取特征数据(低延迟处理)
    feature_vector = store.get_online_features(
        features=[
            "merchant_stats:avg_amount",
            "merchant_stats:transaction_count",
            "merchant_stats:fraud_rate",
        ],
        entity_rows=[{"merchant_category": merchant_category}],
    ).to_dict()
    
    # 格式化返回结果
    return {
        'merchant_avg_amount': feature_vector['avg_amount'][0],
        'merchant_tx_count': feature_vector['transaction_count')[0],
        'merchant_fraud_rate': feature_vector['fraud_rate'][0],
    }

def get_online_features_batch(merchant_categories: list) -> pd.DataFrame:
    """
    一次性为多个商家获取特征数据(批量处理)。

    这种方法比通过循环调用 get_online_features() 更高效。

    参数:
        merchant_categories:需要查询的商家类别列表

    返回值:
    一个包含所有商家特征数据的 DataFrame
    """
    feature_vector = store.get_online_features(
        features=[
            "merchant_stats:avg_amount",
            "merchant_stats:transaction_count",
            "merchant_stats:fraud_rate",
        ],
        entity_rows=[{"merchant_category": mc} for mc in merchant_categories],
    ).to_df()
    
    return feature_vector

if __name__ == "__main__":
    # 测试特征数据检索功能
    print "="*60)
    print("正在测试 Feast 的特征数据检索功能")
    print "="*60)
    
    # 测试离线数据检索功能(用于训练)
    print("\n1. 正在测试离线数据检索功能(用于训练)...")
    train_df = pd.read_csv('data/train.csv').head(10)
    enriched = get_training_features(train_df)
    print("\n   样本训练数据:")
    print(enriched[['amount', 'merchant_category', 'avg_amount', 'fraud_rate')).head())
    
    # 测试在线数据检索功能(用于服务)
    print("\n2. 正在测试在线数据检索功能(用于服务)...")
    for category in ['online', 'grocery', 'travel', 'restaurant', 'retail']:
        features = get_online_features(category)
        print(f"   {category}:平均金额={features['merchant_avg_amount']:.2f}, "
              f"欺诈率为{features['merchant_fraud_rate']:.2%}")
    
    # 测试批量数据检索功能
    print("\n3. 正在测试批量在线数据检索功能...")
    batch_features = get_online_features_batch(['online', 'grocery', 'travel'])
    print(batch_features)
    
    print("\n" + "="*60)
    print("FEAST 的特征数据检索功能测试完成!")
    print "="*60)

测试特征检索功能:

python src/feast_features.py

你应该会看到以下输出:

============================================================
正在测试Feast的特征检索功能
============================================================

1. 正在测试离线特征检索功能(用于训练)…
从Feast的离线存储中获取训练所需特征…
已为5家不同的商家检索到相应的特征

   示例丰富的训练数据:
   商家类别  平均交易金额  诈骗率
    45.23           食品杂货       31.24      0.0085
   123.45            线上购物       98.45      0.0487
    …

2. 正在测试在线特征检索功能(用于数据提供)…
   线上购物:平均交易金额=98.45美元,诈骗率=4.87%
   食品杂货:平均交易金额=31.24美元,诈骗率=0.85%
   旅游行业:平均交易金额=156.23美元,诈骗率=4.18%
   餐厅行业:平均交易金额=28.12美元,诈骗率=0.50%
   零售行业:平均交易金额=45.67美元,诈骗率=1.02%

3. 正在测试批量在线特征检索功能…
  商家类别  平均交易金额  交易数量  诈骗率
               线上购物       98.45               1234      0.0487
              食品杂货       31.24               2345      0.0085
               旅游行业      156.23                478      0.0418

为什么选择Feast而不是自定义代码?

方面 自定义代码 Feast
一致性 需要手动维护数据的一致性 自动保证数据定义在所有地方都一致
数据正确性 必须自行实现相关逻辑 内置了数据验证机制
在线数据提供能力 需要自行构建缓存系统 内置了在线存储功能
特征版本管理 不支持特征版本控制 内置了版本管理机制
可扩展性 扩展能力有限 已具备生产环境所需的扩展性(如使用BigQuery、Redis等技术)
团队协作 协作难度较大 拥有完善的特征注册系统及相关文档
监控功能 需要手动进行监控 内置了丰富的监控统计信息

💡 建议思路:应将特征定义视为数据库模式来处理。在应用程序中计算特征时,不应使用与报告中的不同方式;特征定义应该保持一致性——只需定义一次,然后在所有地方统一使用。

注意事项:运行prepare_feast_features.py命令后,你应该会得到以下文件:

  • data/merchant_features.parquet(计算得到的特征数据)

  • data/registry.db(Feast特征注册表)

  • data/online_store.db(SQLite在线存储数据库)

运行python src/feast_features.py命令后,应该能够成功检索到所有商家类别的特征数据。

5. 使用“Great Expectations”进行数据验证

如果不使用这个方法会带来什么问题: 您的API会接受无效的数据(如负数金额、非法的时间值),从而返回毫无意义的预测结果。更糟糕的是,您甚至不会意识到这种情况的发生。

需要记住的是,目前的API是盲目信任输入数据的。我们已经看到过,无效数据会在不引发任何警告的情况下产生错误的预测结果。“Great Expectations”是一款开源的数据质量检测工具,它允许用户定义规则,并根据这些规则来验证数据。

为什么这很重要: 数据验证起到了“守门人”的作用——在不良数据影响到预测结果之前,就能将其拒之门外。俗话说,“输入垃圾,输出也是垃圾”;使用不可靠的数据必然会导致不可靠的结果。而通过数据验证,我们可以将这一过程转变为“输入垃圾,输出错误信息”——这样不仅便于调试,也能显著提高数据的可靠性。

5.1 定义验证规则

对于我们的交易数据来说,哪些才是合理的验证规则呢?根据行业知识,我们可以制定如下规则:

字段 验证要求 原因
amount 必须为正数(大于0) 负数的交易金额是没有意义的
amount 金额应低于50,000美元 过高的金额属于异常值或错误数据
hour 必须在0到23之间(包括0和23) 这是一天中有效的时间段
day_of_week 必须在0到6之间(包括0和6) 这是一周中的有效天数(周一对应0,周日对应6)
merchant_category 必须是已知类别之一 这个字段的值必须与训练数据一致
所有字段 都不能为空值 这些字段都是进行预测所必需的

创建文件src/data_validation.py:

# src/data_validation.py
"""
用于欺诈检测的数据验证功能。

该模块提供了在生成预测结果之前验证输入数据的函数。
无效数据会被拒绝,并会附带清晰的错误提示信息。

关键理念是:拒绝不良输入总比产生错误的预测要好得多。
"""
import pandas as pd
from typing import Dict, List, Any, Optional

# 定义有效的商家类别(这些类别必须与训练数据一致!)
VALID_CATEGORIES = ["grocery", "restaurant", "retail", "online", "travel"]

def validate_transaction(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    验证单笔交易数据,以判断其是否适合用于欺诈检测。
    
    该函数会检查所有的业务规则和数据质量要求。
    返回一个字典,其中“valid”字段表示验证是否通过,“errors”字段则包含错误信息。
    
    参数:
        data: 包含交易数据的字典
    
    返回值:
        {"valid": 是否验证通过, "errors": 错误信息列表}
    
    示例:
        >>> validate_transaction({"amount": -100, "hour": 25, ...})
        {"valid": False, "errors": ["金额必须为正数", "时间必须在0到23之间"]}
    """
    errors = []
    
    # ==========================================================================
    # 金额验证
    # --------------
    amount = data.get("amount")
    if amount is None:
        errors.append("金额字段是必填项")
    elif not isinstance(amount, (int, float)):
        errors.append(f"金额必须为数字类型,但实际获取到的类型是{type(amount).__name__}")
    elif amount <= 0:
        errors.append("金额必须为正数")
    else:
        errors.append(f"金额不能超过50,000美元,实际值为{amount:,.2f}")
    
    # ==========================================================================
    # 时间验证
    # --------------
    hour = data.get("hour")
    if hour is None:
        errors.append("时间字段是必填项")
    elif not isinstance(hour, int):
        errors.append(f"时间必须为整数类型,但实际获取到的类型是{type(hour).__name__}")
    else:
        errors.append(f"时间必须在0到23之间,但实际值为{hour}")
    
    # ==========================================================================
    # 星期几验证
    # --------------
    day = data.get("day_of_week")
    if day is None:
        errors.append("星期几字段是必填项")
    elif not isinstance(day, int):
        errors.append(f"星期几必须为整数类型,但实际获取到的类型是{type(day).__name__}")
    else:
        errors.append(f"星期几必须在0到6之间,但实际值为{day}")
    
    # ==========================================================================
    # 商家类别验证
    # --------------
    category = data.get("merchant_category")
    if category is None:
        errors.append("商家类别字段是必填项")
    elif not isinstance(category, str):
        errors.append(f"商家类别必须为字符串类型,但实际获取到的类型是{type.category).__name__}")
    else:
        errors.append(
            f"商家类别必须是{VALID_CATEGORIES}中的一个,但实际获取到的值是'{category}'"
        )
    
    return {
        "valid": len(errors) == 0,
        "errors": errors
    }

def validate_batch(df: pd.DataFrame) -> Dict[str, Any]:
    """
    使用“Great Expectations”框架来验证一批交易数据。
    
    这个方法非常适合用于验证训练数据或批量预测请求。
    它能够进行更复杂的数据验证操作。
    
    参数:
        df: 包含交易数据的DataFrame对象
    
    返回值:
        一个包含验证结果的字典
    """
    import great_expectations as gx
    
    # 将DataFrame转换为“Great Expectations”可处理的数据格式
    ge_df = gx.from_pandas(df)
    
    results = []
    
    # 金额字段的验证
    r = ge_df.expect_column_values_to_be_between(
        'amount', min_value=0.01, max_value=50000, mostly=0.99
    )
    results.append(('amount_range', r.success, r.result))
    
    # 时间字段的验证
    r = ge_dfexpect_column_values_to_be_between(
        'hour', min_value=0, max_value=23
    )
    results.append(('hour_range', r.success, r.result))
    
    # 星期几字段的验证
    r = ge_df.expect_column_values_to_be_between(
        'day_of_week', min_value=0, max_value=6
    )
    results.append(('day_range', r.success, r.result))
    
    # 商家类别字段的验证
    r = ge_dfexpect_column_values_to_be_in_set(
        'merchant_category', VALID_CATEGORIES
    )
    results.append(('category_valid', r.success, r.result))
    
    # 确保关键字段都不为空值
    for col in ['amount', 'hour', 'day_of_week', 'merchant_category']:
        r = ge_df.expect_column_values_to_not_be_null(col)
        results.append((f'{col}_not_null', r.success, r.result))
    
    # 综合统计验证结果
    passed = sum(1 for _, success, _ in results if success)
    total = len(results)
    
    return {
        'success': passed == total,
        'passed': passed,
        'total': total,
        'pass_rate': passed / total,
        'details': {name: {'passed': success, 'result': result} 
                   for name, success, result in results}
    }

if __name__ == "__main__":
    print "="*60)
    print("测试数据验证功能")
    print "="*60)
    
    # 测试单笔交易数据的验证
    print("\n1. 单笔交易数据验证")
    print("-"*40)
    
    test_cases = [
        {
            "name": "有效的数据",
            "data": {"amount": 50.0, "hour": 14, "day_of_week": 3, "merchant_category": "grocery"}
        },
        {
            "name": "负数的金额",
            "data": {"amount": -100.0, "hour": 14, "day_of_week": 3, "merchant_category": "grocery"}
        },
        {
            "name": "非法的时间值",
            "data": {"amount": 50.0, "hour": 25, "day_of_week": 3, "merchant_category": "grocery"}
        },
        {
            "name": "未知的商家类别",
            "data": {"amount": 50.0, "hour": 14, "day_of_week": 3, "merchant_category": "unknown"}
        },
        {
            "name": "所有字段都错误",
            "data": {"amount": -999, "hour": 99, "day_of_week": 15, "merchant_category": "fake"}
        },
    ]
    
    for tc in test_cases:
        result = validate_transaction(tc["data"])
        status = "PASS" if result["valid"] else "FAIL"
        print(f"\n{tc['name']}: {status}")
        if result["errors"]:
            for error in result["errors":
                print(f"  - {error}")
    
    # 测试批量交易数据的验证
    print("\n\n2. 使用“Great Expectations”进行批量数据验证")
    print("-"*40)
    
    train_df = pd.read_csv('data/train.csv')
    results = validate_batch(train_df)
    
    print(f"\n训练数据的验证结果:{results['passed']}/{results['total']}项通过")
    print(f"通过率:{results['pass_rate']:.1%}")
    
    if not results['success']:
        print("\n未通过的项目:")
        for name, detail in results['details'].items():
            if not detail['passed':
                print(f"  - {name}")

何时使用哪种验证方法

方法 适用场景 处理速度 使用建议
自定义Python验证代码 (validate_transaction) 实时API请求 <1毫秒以内 所有预测请求
Great Expectations 批量数据质量检测 数秒 训练数据、定期审计、持续集成/部署流程

在本教程中,我们同时使用了这两种方法,因为它们各自具有不同的用途:

  • 自定义验证代码是实时运行的“把关者”——能够快速处理所有请求。

  • Great Expectations则像是一个“批量审计工具”——会对整个数据集进行全面检查。

5.2 将验证功能集成到FastAPI中

现在,让我们更新我们的API,使其在接收到无效输入时能够返回明确的错误信息:

# src/serve_validated.py
"""
提供带有输入验证功能的欺诈检测服务。

此版本会在进行预测之前先对数据进行检查:
- 无效输入会引发HTTP 400错误,并附带详细的错误信息;
- 有效输入则会被正常处理并返回预测结果。

这种方式比原始版本安全得多,因为原始版本会直接接受任何形式的输入。
"""
import pickle
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from src.data_validation import validate_transaction

# 加载模型
with open("models/model.pkl", "rb") as f:
    model, encoder = pickle.load(f)

app = FastAPI(
    title="欺诈检测API(带输入验证功能)",
    description="""
    该API会在进行预测之前对所有输入数据进行验证:
    - 交易金额必须为正数,且不得超过50,000美元;
    - 交易时间必须在0到23小时之间;
    - 星期几必须在0到6之间;
    - 商户类型必须是“grocery”、“restaurant”、“retail”、“online”或“travel”之一。
    - 无效输入会返回HTTP 400错误,并附带详细的错误信息。
    """,
    version="3.0.0"
)

class Transaction(BaseModel):
    amount: float = Field(..., description="交易金额,必须为正数", example=150.00)
    hour: int = Field(..., description="交易时间,必须在0到23小时之间", example=14)
    day_of_week: int = Field(..., description="星期几,0表示周一,6表示周日", example=3)
    merchant_category: str = Field(..., description="商户类型", example="online")

class PredictionResponse(BaseModel):
    is_fraud: bool
    fraud_probability: float
    validation_passed: bool = True

class ValidationErrorResponse(BaseModel):
    detail: dict

@app.post("/predict", response_model=PredictionResponse, responses={400: {"model": ValidationErrorResponse}})
def predict(tx: Transaction):
    """
    判断一笔交易是否属于欺诈行为。

    在进行预测之前,会先对输入数据进行验证。无效输入会引发HTTP 400错误。
    """
    data = tx.dict()

    # 在进行预测之前先验证输入数据
    validation = validate_transaction(data)

    if not validation["valid":
        raise HTTPException(
            status_code=400,
            detail={
                "message": "验证失败",
                "errors": validation["errors"],
                "input": data
            }
        )

    # 输入数据有效,接下来进行预测
    data["merchant_encoded"] = encoder.transform([data["merchant_category"]])[0]
    X = [[data["amount"], data["hour"], data["day_of_week"], data["merchantEncoded"]]]
    
    pred = model.predict(X)[0]
    prob = model.predict_proba(X)[0][1]
    
    return PredictionResponse(
        is_fraud,bool(pred),
        fraud_probability=round(float(prob), 4),
        validation_passed=True
    )

@app.get("/health")
def health():
    return {"status": "正常运行", "validation": "已启用"}

启动经过验证的API:

uvicorn src.serve_validated:app --reload --host 0.0.0.0 --port 8000

现在使用无效数据进行测试:

curl -X POST "http://localhost:8000/predict" \
  -H "Content-Type: application/json" \
  -d '{"amount": -500, "hour": 25, "day_of_week": 10, "merchant_category": "fake"}'

响应结果(HTTP 400):

{
  "detail": {
    "message": "验证失败",
    "errors": [
      "金额必须为正数",
      "小时必须在0到23之间(实际输入为25)",
      "星期几必须在0(周一)到6(周日)之间(实际输入为10)",
      "商家类别必须是['食品杂货'、'餐厅'、'零售'、'在线'、'旅游']中的一个(实际输入为'fake')"
    ],
    "input": {"amount": -500, "hour": 25, "day_of_week": 10, "merchant_category": "fake"}
  }
}

这是一个巨大的改进! 相比于之前默默接受无效数据并返回无意义的预测结果,我们现在:

  • 会立即拒绝无效的输入数据

  • 会提供清晰、易于理解的错误提示信息

  • 会返回原始输入数据以便于调试

  • 会使用正确的HTTP状态码(400表示客户端错误)

关键点: 经过验证的API应该具备以下功能:

  • 能够接受有效的数据并生成预测结果

  • 会使用HTTP 400状态码及详细的错误信息拒绝无效的数据

  • 会针对每个无效字段显示具体的验证错误信息

6. 监控模型性能及数据变化趋势

如果不进行监控会发生什么: 在两个月的时间里,模型的准确率可能会从98%下降到70%,但直到客户开始投诉时,人们才会注意到这个问题。而到那时,已经造成了严重的后果。

即使拥有优秀的模型和干净的数据,时间依然可能成为威胁。随着现实世界数据的变化,模型性能也会下降——这种现象被称为模型漂移模型衰退

为什么这很重要: 在传统的软件系统中,人们会监控CPU使用率、内存占用量、错误率和响应时间。但在机器学习领域,还必须额外关注以下方面:

  • 数据质量(输入数据是否在预期的范围内?)

  • 模型性能(模型的准确率是否依然稳定?)

  • 数据变化趋势(输入数据的分布是否发生了改变?)

  • 预测结果的变化趋势(预测结果的分布是否发生了变化?)

如果不进行监控,模型可能会在数周内默默地出现故障,而人们却迟迟没有发现。等到发现问题时,可能已经造成了严重的损失——比如欺诈行为得以得逞、优质客户被误判为不良用户、收入因此减少。

6.1 机器学习可观测性的四大支柱

支柱 需要监控的内容 重要性
数据质量 输入数据是否有效?是否存在空值或异常值? 错误的数据会导致错误的预测结果
模型性能 准确率、精确度、召回率、F1分数 模型是否仍然能够正常工作?
数据变化趋势 输入数据的分布与训练时的分布是否发生了变化? 模型可能无法适应新的数据环境
预测结果的变化趋势 预测结果的分布是否发生了变化? 这可能表明数据本身或模型本身的概念已经发生了变化

6.2 使用 Evidently 构建数据漂移监测系统

Evidently 是一个专为机器学习监控设计的开源库。它能够检测数据漂移、生成报告,并与各种监控系统集成。

创建文件 src/monitoring.py:

# src/monitoring.py
"""
使用 Evidently 进行模型监控。

该模块提供以下功能:
1. 检测训练数据与生产数据之间的差异
2. 生成详细的 HTML 报告
3. 定期追踪数据漂移情况
4> 当漂移超过预设阈值时发出警报

在实际应用中,应定期(如每小时或每天)执行漂移检测,
并在发现显著漂移时立即发出警报。
"""
import pandas as pd
import numpy as np
from evidently.report import Report
from evidently.metric_preset import DataDriftPreset, TargetDriftPreset
from evidently.metrics import (
    DatasetDriftMetric,
    DataDriftTable,
    ColumnDrift Metric
)
from datetime import datetime
from typing import List, Dict, Any, Optional

class DriftMonitor:
    """
    用于检测参考数据(训练数据)与当前数据之间的差异。
    
    实现说明:
    我们采用了两种方法:
    1. Scipy 的 KS 检验——一种通用性强的统计方法(作为备用方案)
    2. Evidently —— 一个功能完备的库,能生成精美的报告(主要使用工具)
    
    即使 Evidently 无法生成报告,我们仍然可以通过 KS 检验来检测数据漂移。
    
    使用方法:
        monitor = DriftMonitor(training_data)
        result = monitor.check_drift(production_data)
        if result['drift_detected']:
            alert("检测到数据漂移!")
    """
    
    def __init__(self, reference_data: pd.DataFrame, feature_columns: Optional[List[str]] = None):
        """
        使用参考数据(训练数据)初始化漂移监测系统。
        
        参数:
            reference_data:用于比较的训练数据
            feature_columns:需要监控的列(默认为所有数值型列)
        """
        self.reference = reference_data
        self.feature_columns = feature_columns or reference_data.select_dtypes(
            include=[np.number]
        ).columns.tolist()
        self.history: List[Dict[str, Any]] = []
        
        print(f"已使用 {len(self/reference):,} 个参考样本初始化漂移监测系统")
        print(f"需要监控的列:{self.feature_columns}")
    
    def check_drift(self, current_data: pd.DataFrame, threshold: float = 0.1) -> Dict[str, Any]:
        """
        检测参考数据与当前数据之间的差异。
        
        参数:
            current_data:需要检测的实际数据
            threshold:漂移阈值,超过该阈值时会触发警报(默认为 10%)
        
        返回值:
        包含漂移检测结果的字典
        """
        from scipy import stats
        
        ref_subset = self.reference[self.feature_columns]
        cur_subset = current_data[self.feature_columns]
        
        # 使用 KS 检验进行简单的数据漂移检测
        drifted_columns = []
        for col in self.feature_columns:
            statistic, p_value = stats.ks_2samp(
                ref_subset[col].dropna(),
                cur_subset[col].dropna()
            )
            if p_value < 0.05:  # 5%的显著性水平
                drifted_columns.append(col)
        
        n_features = len(self.feature_columns)
        n_drifted = len(drifted_columns)
        drift_share = n_drifted / n_features if n_features > 0 else 0
        
        result = {
            'timestamp': datetime.now().isoformat(),
            'drift_detected': n_drifted > 0,
            'drift_share': drift_share,
            'drifted_columns': drifted_columns,
            'n_features': n_features,
            'n_drifted': n_drifted,
            'current_samples': len(current_data),
            'threshold': threshold,
            'alert': drift_share > threshold
        }
        
        self.history.append(result)
        
        return result
    
    def generate_report(self, current_data: pd.DataFrame, output_path: str = "drift_report.html"):
        """
        使用 Evidently 生成详细的 HTML 报告。
        该报告可在浏览器中查看,以便直观了解数据漂移情况。
        ```
        ref_subset = self.reference[self.feature_columns]
        curSubset = current_data[self.feature_columns]
        
        try:
            report = Report(metrics=[DataDriftPreset'])
            report.run(reference_data=ref_subset, current_data=cur_subset)
            
            # 保存 HTML 报告
            with open(output_path, 'w') as f:
                f.write(report.show(mode='inline').data)
            
            print(f"报告已保存至 {output_path}")
            print("请在浏览器中打开该文件查看详细信息。")
        except Exception as e:
            print(f"无法生成报告:{e}")
            print("此时将使用简化的检测方法。")
    
    def get_alerts(self, threshold: float = 0.1) -> List[Dict[str, Any]]:
        """
        从历史记录中检索所有漂移超过阈值的警报信息。
        
        返回值:
        包含警报信息的列表
        ```
        return [
            {
                'timestamp': r['timestamp'],
                'severity': 'HIGH' if r['drift_share'] > 0.3 else 'MEDIUM',
                'drift_share': r['drift_share'],
                'message': f"检测到数据漂移:{r['drift_share']:.1%} 的列发生了变化",
                'drifted_columns': r['drifted_columns']
            }
            for r in self.history
            if r['drift_share'] > threshold
        ]
    
    def summary(self) -> Dict[str, Any]:
        """获取监控历史的统计信息。"""
        if not self.history:
            return {"message": "尚未进行任何漂移检测"}
        
        drift_shares = [r['drift_share'] for r in self.history]
        alerts = [r for r in self.history if r['alert']]
        
        return {
            'total_checks': len(self.history),
            'total_alerts': len(alerts),
            'avg_drift_share': np.mean(drift_shares),
            'max_drift_share': np.max(drift_shares),
            'first_check': self.history[0]['timestamp'],
            'last_check': self.history[-1]['timestamp']
        }


def simulate_drift_scenarios():
    """
    通过不同场景演示数据漂移检测功能。
    该示例模拟了生产数据与训练数据出现差异时的情况。
    ```
    from src.generate_data import generate_transactions
    
    print "="*70)
    print("数据漂移检测演示")
    print "="*70)
    
    # 加载参考数据(训练数据)
    print("\n1. 正在加载参考数据(训练集)...")
    reference = pd.read_csv('data/train.csv')
    feature_columns = ['amount', 'hour', 'day_of_week']
    
    # 初始化漂移监测系统
    monitor = DriftMonitor(reference, feature_columns)
    
    # 场景 1:数据相似(应显示很小的漂移)
    print("\n" + "-"*70)
    print("场景 1:测试数据(分布相同)")
    print("-"*70)
    test_data = pd.read_csv('data/test.csv')
    result = monitor.check_drift(test_data)
    print(f"  检测到数据漂移:{result['drift Detected']}")
    print(f"  漂移比例:{result['drift_share']:.1%}")
    print(f"  发生漂移的列:{result['drifted_columns']}")
    print(f"  触发了警报:{result['alert']}")
    
    # 场景 2:欺诈数据增加(欺诈比例为 10%,而非 2%)
    print("\n" + "-"*70)
    print("场景 2:欺诈数据激增(欺诈率为 10%)")
    print("-"*70)
    fraud_spike = generate_transactions(n_samples=2000, fraud_ratio=0.10, seed=101)
    result = monitor.check_drift(fraud_spike)
    print(f"  检测到数据漂移:{result['drift Detected']}")
    print(f"  漂移比例:{result['drift_share']:.1%}")
    print(f"  发生漂移的列:{result['drifted_columns']}")
    print(f"  触发了警报:{result['alert']}")
    
    # 场景 3:数据金额增加(所有项目的价格都上涨了)
    print("\n" + "-"*70)
    print("场景 3:数据金额膨胀(所有项目的价格变为原来的 2 倍)")
    print("-"*70)
    inflated = test_data.copy()
    inflated['amount'] = inflated['amount'] * 2
    result = monitor.check_drift(inflated)
    print(f"  检测到数据漂移:{result['drift Detected']}")
    print(f"  漂移比例:{result['drift_share']:.1%}")
    print(f"  发生漂移的列:{result['drifted_columns']}")
    print(f"  触发了警报:{result['alert']}")
    
    # 为漂移最严重的场景生成详细报告
    print("\n" + "-"*70)
    print("正在生成详细报告")
    print("-"*70)
    monitor.generate_report(inflated, "drift_report.html")
    
    # 打印汇总信息
    print("\n" + "-"*70)
    print("监控汇总")
    print("-"*70)
    summary = monitor.summary()
    print(f"  总检测次数:{summary['total_checks']}")
    print(f"  总警报数量:{summary['total_alerts']}")
    print(f"  平均漂移比例:{summary['avg_drift_share']:.1%}")
    print(f"  最高漂移比例:{summary['max_drift_share']:.1%}")
    
    # 打印所有警报信息
    alerts = monitor.get_alerts()
    if alerts:
        print("\n  共有 {len(alerts)} 条警报:")
        for alert in alerts:
            print(f"    [{{alert['severity']}] {alert['message']}]")
        
    print("\n" + "="*70)
    print("数据漂移检测演示完成")
    print("="*70)
    print("\n请在浏览器中打开 drift_report.html 文件查看详细可视化结果!")
)


if __name__ == "__main__":
    simulate_drift_scenarios()

运行漂移检测模拟:

python src/monitoring.py

你会看到输出结果,这些结果会展示在不同场景下漂移检测的工作原理。之后,请在浏览器中打开drift_report.html文件,查看那些关于漂移模式的可视化图表。

6.3 生产环境监控策略

在生产环境中,你应该采取以下措施:

  1. 将所有预测结果记录到数据库或数据仓库中

  2. 定期进行漂移检测(高流量系统建议每小时检测一次,低流量系统则每天检测一次)

  3. 设置警报机制,当漂移程度超过预设阈值时立即触发警报(可集成PagerDuty、Slack等工具)

  4. 在漂移现象严重或持续存在时启动模型重新训练流程

  5. 创建监控仪表板,以便随时追踪漂移情况的变化(可使用Grafana、Datadog等工具)

验证要点:运行python src/monitoring.py命令后,应该会得到以下结果:

  • 对于相同类型的数据(测试集数据),检测结果应显示出极小的漂移幅度

  • 对于经过修改的数据,系统应能准确检测出其变化趋势——例如欺诈交易数量的突然增加、通货膨胀现象或时间序列的偏移等

  • 系统会生成一份HTML报告,你可以直接在浏览器中查看这份报告

7. 使用CI/CD自动化测试与部署流程

如果没有CI/CD会带来什么问题?:代码中的一个小错误就可能导致API功能失效。如果你在周五下午5点进行部署,可能要等到周一才会有人发现这个问题,而此时欺诈造成的损失已经变得非常严重。

CI/CD(持续集成/持续部署)能够确保软件发布的稳定性与可靠性。正如JFrog所指出的:"一个完善的CI/CD流程能够帮助机器学习团队更快、更高效地构建出无缺陷的模型。"

为什么这很重要?:在机器学习领域,需要修改的不仅仅是代码,还包括数据和模型。CI/CD能够确保:每当你对训练逻辑、数据预处理方法或超参数进行任何调整时,系统都会先进行测试,以确保这些变更不会对后续的生产环境造成负面影响。这就是为什么使用CI/CD能够让你更加放心地进行部署,而不用一直担心会出现问题。

7.1 为数据和模型编写测试用例

创建tests/test_data_and_model.py文件:

# tests/test_data_and_model.py
"""
用于检测数据质量及模型性能的测试用例。

这些测试会在CI/CD流程中自动执行,以确保:
1. 数据符合质量要求
2. 模型的性能指标达到预期标准
3> 避免出现任何退化现象

执行方式:pytest tests/test_data_and_model.py -v
"""
import pandas as pd
import pickle
import pytest
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

class TestDataQuality:
    """用于检测训练数据质量的测试用例。"""

    @pytest.fixture
    def train_data(self):
        return pd.read_csv("data/train.csv")

    @pytestfixture
    def test_data(self):
        return pd.read_csv("data/test.csv")

    def test_train_data_has_expected_columns(self, train_data):
        """训练数据必须包含所有必需的列。"""
        required_columns = {"amount", "hour", "day_of_week", "merchant_category", "is_fraud"}
        actual_columns = set(train_data.columns)
        missing = required_columns - actual_columns
        assert not missing, f"缺少以下列:{missing}"

    def test_train_data_not_empty(self, train_data):
        """训练数据必须包含至少一行数据。"""
        assert len(train_data) > 0, "训练数据为空"
        assert len(train_data) >= 1000, f"训练数据行数太少:{len(train_data)}"

    def test_no_negative_amounts(self, train_data):
        """交易金额必须为非负数值。"""
        negative_count = (train_data["amount"] < 0).sum()
        assert negative_count == 0, f"检测到{negative_count}笔负数交易"

    def test_amounts_reasonable(self, train_data):
        """交易金额应处于合理的范围内。"""
        max_amount = train_data["amount"].max()
        assert max_amount <= 100000, f"最大交易金额超过了合理范围:{max_amount}"

    def test_hours_valid(self, train_data):
        """时间字段的值必须在0到23之间。"""
        invalid = train_data[(train_data["hour"] < 0) | (train_data["hour"] > 23)]
        assert len(invalid) == 0, f"检测到{len(invalid)}条无效的时间记录"

    def test_days_valid(self, train_data):
        """星期几的数值必须在0到6之间。"""
        invalid = train_data[(train_data["day_of_week"] < 0) | (train_data["day_of_week"] > 6)]
        assert len(invalid) == 0, f"检测到{len(invalid)}条无效的星期几记录"

    def test_merchant_categories_valid(self, train_data):
        """商家类别必须属于已知的合法集合。"""
        valid_categories = {"grocery", "restaurant", "retail", "online", "travel"}
        actual_categories = set(train_data["merchant_category"].unique())
        invalid = actual这些类别 - valid_categories
        assert not invalid, f"检测到{invalid}个无效的商家类别"

    def test_fraud_ratio_reasonable(self, train_data):
        """欺诈比例应处于合理的范围内(0.1%至50%之间)。"""
        fraud_ratio = train_data["is_fraud"].mean()
        assert 0.001 <= fraud_ratio <= 0.5, f"检测到的欺诈比例为{fraud_ratio:.2%}, 不在合理范围内"

    def test_no_nulls_in_critical_columns(self, train_data):
        """关键字段的值不能为NULL。"""
        critical = ["amount", "hour", "day_of_week", "merchant_category", "is_fraud"]
        for col in critical:
            null_count = train_data[col].isnull().sum()
            assert null_count == 0, f"字段{col}中有{null_count}个NULL值"
tests/test_
api.py 文件:

# tests/test_api.py
"""
用于测试FastAPI预测服务的相关功能。

这些测试旨在确保该API能够:
1. 对有效的输入返回正确的响应;
2. 对无效的输入给出相应的错误提示;
3. 正常运行健康检查功能。

执行方式:pytest tests/test_api.py -v
注意:必须确保API在localhost:8000地址上运行。
"""
import pytest
import httpx

BASE_URL = "http://localhost:8000"

class TestPredictionEndpoint:
    """用于测试 /predict 这一端点的功能。"""

    def test_valid_prediction_returns_200(self):
        """有效的输入应该会返回HTTP 200状态码以及预测结果。"""
        response = httpx.post(f"{BASE_URL}/predict", json={
            "amount": 100.0,
            "hour": 14,
            "day_of_week": 3,
            "merchant_category": "online"
        }, timeout=10)
        
        assert response.status_code == 200
        data = response.json()
        assert "is_fraud" in data
        assert "fraud_probability" in data
        assert isinstance(data["is_fraud"], bool)
        assert 0 <= data["fraud_probability"] <= 1

    def test_high_risk_transaction(self):
        """高风险交易应该具有更高的欺诈概率。"""
        response = httpx.post(f"{BASE_URL}/predict", json={
            "amount": 500.0,
            "hour": 3,  # 深夜时间
            "day_of_week": 1,
            "merchant_category": "online"
        }, timeout=10)
        
        assert response.status_code == 200
        data = response.json()
        # 高风险交易的欺诈概率应该大于0.0
        # (具体数值可能因模型不同而有所差异,因此不进行硬性断言)
        assert data["fraud_probability"] >= 0.0

    def test_negative_amount_rejected(self):
        """负数的交易金额应该会被拒绝,并返回400状态码。"""
        response = httpx.post(f"{BASE_URL}/predict", json={
            "amount": -100.0,
            "hour": 14,
            "day_of_week": 3,
            "merchant_category": "online"
        }, timeout=10)
        
        assert response.status_code == 400
        assert "errors" in response.json()["detail"]

    def test_invalid_hour_rejected(self):
        """无效的时间参数应该会被拒绝,并返回400状态码。"""
        response = httpx.post(f"{BASE_URL}/predict", json={
            "amount": 100.0,
            "hour": 25,  # 时间无效
            "day_of_week": 3,
            "merchant_category": "online"
        }, timeout=10)
        
        assert response.status_code == 400

    def test_invalid_merchant_rejected(self):
        """未知的商家类别应该会被拒绝,并返回400状态码。"""
        response = httpx.post(f"{BASE_URL}/predict", json={
            "amount": 100.0,
            "hour": 14,
            "day_of_week": 3,
            "merchant_category": "unknown_category"
        }, timeout=10)
        
        assert response.status_code == 400

    def test_missing_field_rejected(self):
        """缺少必要的字段应该会被拒绝,并返回422状态码。"""
        response = httpx.post(f"{BASE_URL}/predict", json={
            "amount": 100.0,
            "hour": 14
            # 缺少了day_of_week和merchant_category字段
        }, timeout=10)
        
        assert response.status_code == 422  # 这是Pydantic框架导致的验证错误


class TestHealthEndpoint:
    """用于测试 /health 这一端点的功能。"""

    def test_health_returns_200(self):
        """健康检查端点应该返回200状态码。"""
        response = httpx.get(f"{BASE_URL}/health", timeout=10)
        assert response.status_code == 200

    def test_healthозвращаетhealthy_status(self):
        """健康检查端点应该表明系统处于正常运行状态。"""
        response = httpx.get(f"{BASE_URL}/health", timeout=10)
        data = response.json()
        assert data["status"] == "healthy"

在本地运行测试:

# 运行数据测试和模型测试(无需API)
pytest tests/test_data_and_model.py -v

# 运行API测试(需要API处于运行状态)
pytest tests/test_api.py -v

7.2 GitHub Actions工作流程

⚠️ 生产团队的注意事项
在真正的机器学习团队中,通常不会在持续集成环境中重新训练完整的模型——因为这样做既耗时又耗费资源。
在这里我们这样做,是为了确保所有的测试过程都能在本地进行、结果可复现,并且便于学习使用。
生产环境中的管道通常会将训练任务(定时作业)与测试任务(持续集成/持续部署流程)分开处理。

创建文件.github/workflows/ci.yml

# .github/workflows/ci.yml
name: 机器学习管道的持续集成/持续部署流程

on:
  push:
    branches: [main, develop]
  pull_request:
    branches: [main]

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - name: 检出代码
        uses: actions/checkout@v4

      - name: 安装Python环境
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"
          cache: 'pip'

      - name: 安装依赖项
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt

      - name: 生成训练数据
        run: python src/generate_data.py

      - name: 训练模型
        run: python src/train_naive.py

      - name: 运行数据质量测试
        run: pytest tests/test_data_and_model.py -v --tb=short

      - name: 构建Docker镜像
        run: docker build -t fraud-detection-api .

      - name: 运行容器以进行API测试
        run: |
          docker run -d -p 8000:8000 --name test-api fraud-detection-api
          sleep 10  # 等待API启动
          curl -f http://localhost:8000/health || exit 1

      - name: 运行API测试
        run: pytest tests/test_api.py -v --tb=short

      - name: 清理资源
        if: always()
        run: docker stop test-api || true

7.3 将应用程序Docker化

创建文件Dockerfile

# Dockerfile
FROM python:3.11-slim

# 设置工作目录
WORKDIR /app

# 安装系统依赖项
RUN apt-get update && amp; apt-get install -y \
    curl \
    && amp; rm -rf /var/lib/apt/lists/*

# 复制并安装Python依赖项
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# 复制应用程序代码
COPY src/ src/
COPY models/ models/
COPY data/ data/

# 暴露端口
EXPOSE 8000

# 健康检查
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# 运行API服务器
CMD ["uvicorn", "src.serve_validated:app", "--host", "0.0.0.0", "--port", "8000"]

创建文件.dockerignore

# .dockerignore
venv/
__pycache__/
*.pyc
.git/
.github/
mlruns/
*.db
*.html
.pytest_cache/

在本地构建并运行:

# 构建Docker镜像
docker build -t fraud-detection-api .

# 运行容器
docker run -p 8000:8000 fraud-detection-api

# 进行测试
curl http://localhost:8000/health

检查结果:

  • 所有测试均通过:`pytest tests/test_data_and_`model.py` -v

  • Docker镜像构建成功

  • 容器已运行且能正常响应健康检查请求

8. 事件应对方案

当生产环境中出现问题时(这种情况肯定会发生),你需要有一个应对计划。本节提供了针对常见机器学习相关事件的应对方案。

场景:误报率突然升高

症状:你的欺诈检测模型突然将40%的合法交易标记为欺诈行为,导致客户被阻止进行交易,同时给人工审核团队带来巨大压力。

严重程度:高——直接影响客户体验

第一阶段:缓解问题(0-5分钟)

  1. 确认事件发生 ——通知相关方你已注意到这个问题并正在采取措施处理

  2. 恢复到之前的模型版本 ——在MLflow用户界面中,将@champion别名切换回之前的模型版本

  3. 重启API服务 ——执行`docker restart fraud-api`命令或重新部署API

  4. 进行验证 ——确认误报率是否已恢复正常

  5. 及时沟通 ——告知客户“问题已经得到解决,正在调查根本原因”

第二阶段:诊断问题(5-60分钟)

  1. 查看数据变化报告 ——使用最新的生产数据运行`python src/`monitoring.py脚本

  2. 检查数据验证日志 ——上游数据格式是否发生了变化?

  3. 查看最近的部署记录 ——最近是否有新的模型或代码被部署?

  4. 对比各项指标 ——恢复前的模型与出现问题的模型之间有哪些差异?

可能的根本原因:

  • 上游系统发送的金额单位是分而不是美元

  • 训练数据中没有的新商家类别被纳入了检测系统

  • 节假日的购物模式与训练数据中的情况存在显著差异

第三阶段:彻底解决问题(1-24小时)

  1. 解决根本原因 ——为这些特殊情况进行额外验证,或更新训练数据

  2. 如有必要,重新进行模型训练 ——将新的数据模式纳入训练集

  3. 添加测试用例 ——防止类似问题再次发生

  4. 记录处理过程 ——将此次处理方案整理成文档,以供日后参考

场景:系统性能逐渐下降

症状:监控数据显示,欺诈检测模型的识别准确率每周下降2%,这种下降趋势持续了一个月。系统中没有出现突然的故障,只是性能在逐渐恶化。

严重程度:中等——影响较为渐进,需要及时采取措施应对。

应对措施:

  1. 分析数据变化趋势 – 查看各指标是否出现了渐进性的变化。

    python src/monitoring.py
  2. 收集最近标记过的样本数据 – 获取过去一个月内被确认为欺诈案例的数据。

  3. 分析数据中的异常模式 – 最近发生的欺诈行为有哪些不同之处?

    • 是否出现了新的攻击手段?

    • 数据出现的频率或时间模式是否有变化?

    • 涉及的新商家类别是否有所增加?

  4. 使用合并后的数据重新训练模型 – 将旧数据和新数据都纳入训练流程中。

    python src/train_mlflow.py
  5. 先通过小规模测试进行部署 – 先将10%的流量路由到新模型上。

    • 观察模型的运行指标,持续1–2天。

    • 如果指标有所改善,逐渐增加使用新模型的比例,最终达到100%。

    • 如果指标恶化,立即恢复使用旧模型。

  6. 设置定期重新训练机制 – 定期安排每周或每月的模型重新训练工作。

场景示例:上游数据结构发生变化

症状:API开始返回500错误代码。日志中显示“KeyError: ‘merchant_category’”这一错误信息。

严重程度:高——服务无法正常运行。

应对措施:

  1. 检查错误日志 – 确定具体的错误原因。

    KeyError: 'merchant_category'
    
  2. 检查上游数据源 – 确认字段名称是否发生了变化。

    • merchant_category -> category

    • amount -> transaction_amount

  3. 立即进行修复 – 添加字段名称的映射规则。

    # API中的快速修复代码
    if 'category' in data and 'merchant_category' not in data:
        data['merchant_category'] = data['category']
    
  4. 长期解决方案 – 加入能够检测数据结构变化的验证机制。

    required_fields = ['amount', 'hour', 'day_of_week', 'merchant_category']
    missing = [f for f in required_fields if f not in data]
    if missing:
        raise ValidationError(f"缺失的字段有:{missing}")
    
  5. 添加集成测试 – 在持续集成/持续部署流程中,与上游系统进行联合测试。

9. 如何将所有这些措施结合起来使用

让我们回顾一下我们所完成的工作。最初那个简单的系统已经发展成为一个拥有生产级组件的本地机器学习平台

💡 思维模型:这套工具体系中的每一项工具都针对某种特定的故障模式起到“拦截作用”:

  • MLflow用于检测“这个模型到底是什么?”这个问题。

  • Feast用于验证“各个特征是否一致?”。

  • Great Expectations用于判断“这些数据是否有效?”。

  • Evidently用于识别“外部环境是否发生了变化?”。

  • CI/CD则用于检测“我们是否破坏了系统的稳定性?”。

这些工具共同构成了机器学习系统的多重防护机制。

组件 工具 解决的问题
实验跟踪 MLflow 所有实验过程都会被记录下来,且可以重现
模型注册库 MLflow 模型可版本化管理,支持回滚功能
特征存储系统 Feast 特征数据保持一致性,避免训练环境与部署环境之间的差异
数据验证 Great Expectations 不良数据会被及时识别并拒绝接收
监控系统 Evidently 能及时发现数据偏差,避免问题发生
容器化技术 Docker 确保所有环境的一致性
持续集成/持续部署 GitHub Actions 自动化测试,保障安全部署

完整的工作流程

以下是这些组件在实际应用中的协同工作方式:

  1. 数据接收 — 新的交易数据会从上游系统传入

  2. 数据验证 — Great Expectations会检查数据质量,不良数据会在造成危害之前被拒绝接收

  3. 特征计算 — Feast会使用相同的规则为训练和部署阶段计算特征数据,从而避免差异出现

  4. 模型训练 — 当需要重新训练模型时,MLflow会记录所有参数、指标及生成的结果文件,确保每次实验结果都是可重现的

  5. 模型管理 — 训练好的模型会被自动分配版本号,可以方便地比较不同版本的指标表现,将最佳版本部署到生产环境中,必要时也可以进行回滚操作

  6. 服务端部署 — FastAPI会从MLflow中加载模型,每个请求都会经过验证,特征数据会从Feast中获取,然后生成预测结果并返回给用户

  7. 监控机制 — Evidently会定期检查数据偏差情况,如果输入数据的分布发生显著变化,系统会立即发出警报

  8. 迭代循环 — 一旦发现数据偏差,就会使用新数据重新训练模型,然后比较不同版本的指标,将表现更好的版本部署到生产环境中,这个循环会持续进行

  9. 持续集成/持续部署的安全保障 — 所有的代码变更都会经过自动化测试,Docker能确保所有环境的一致性,只有通过所有测试的代码才能被部署到生产环境中

10. 下一步计划:扩展到生产环境

虽然这个项目目前是在本地运行的,但这些技术和工具同样适用于生产环境的部署。以下是各组件的扩展方案:

Feast在 production 环境中的扩展

我们在本地使用 SQLite 作为特征存储系统,而在生产环境中,则会采用以下方案:

组件 本地环境 生产环境
在线存储 SQLite Redis、DynamoDB 或 PostgreSQL
离线存储 Parquet 文件 BigQuery、Snowflake 或 Redshift
特征服务端 内嵌式实现 专用的 Feast 服务集群

大规模应用带来的优势:

  • 特征检索时间低于10毫秒

  • 可水平扩展以实现高吞吐量

  • 支持特征监控与统计分析

  • 能够处理PB级别的数据量,实现精确的连接操作

将MLflow应用于生产环境

组件 本地环境 生产环境
后端存储 SQLite PostgreSQL或MySQL
模型资源存储 本地文件系统 S3、GCS或Azure Blob
跟踪服务器 单实例 负载均衡集群

Kubernetes部署方案

当Docker Compose不再满足需求时,您可以采取以下措施:

  • 使用KServe或Seldon实现无服务器模型服务,并具备自动扩展功能

  • 利用水平Pod自动扩展器根据CPU/内存等指标进行动态扩展

  • 采用金丝雀部署策略,安全地推出新模型(先让10%的流量通过新模型处理)

  • 为计算密集型模型配置GPU调度机制

高级监控功能

通过以下工具提升系统的可观测性:

  • Prometheus与Grafana可用于实时数据展示

  • OpenTelemetry支持分布式追踪功能

  • 集成PagerDuty或Slack可实现警报通知

  • 通过带标签的数据收集机制持续评估模型性能

A/B测试与多臂老虎机算法

如何使用模型注册系统:

  • 同时部署多个模型,进行对比测试

  • 根据具体场景动态分配流量

  • 为每个模型变体收集相应的指标数据

  • 自动推广表现最佳的模型

总结

恭喜您在本地环境中成功搭建了一套可用于生产环境的机器学习系统!

我们在这里展示的内容实际上反映了现实世界中机器学习平台的常见架构:

  • 最初我们只是使用一个保存在pickle文件中的模型

  • 最终我们掌握了MLOps最佳实践:包括实验跟踪、模型版本管理、特征资源存储、数据验证、监控机制、容器化技术以及持续集成/持续部署流程

我们使用的这些工具都是经过生产环境验证的高质量工具:

  • MLflow被微软、Facebook和Databricks等公司用于构建机器学习平台

  • Feast被Gojek、Shopify和Robinhood等企业采用

  • FastAPI是目前速度最快的Python Web框架之一

  • Great Expectations在GitHub和Shopify等公司中得到应用

  • Evidently被广泛用于大规模生产环境中的机器学习监控工作

这些原则适用于任何规模的应用:

  • 始终要跟踪实验过程

  • 始终要对模型进行版本控制

  • 始终要验证数据准确性

  • 始终要监测数据是否存在偏差

  • 为保持一致性,始终要使用容器化技术

  • 始终要自动化测试流程

你可以尝试的下一步行动

  1. 部署到云端 — 将你的Docker容器部署到AWS ECS、Google Cloud Run或Azure Container Instances上

  2. 增加模型可解释性 — 使用SHAP或LIME来解释模型的预测结果

  3. 实施A/B测试 — 提供多个模型版本并比较它们的性能

  4. 添加特征重要性监控功能 — 记录特征重要性随时间的变化情况

  5. 设置实时警报机制 — 将Evidently与Slack或PagerDuty连接起来

  6. 实施持续训练机制 — 一旦检测到数据偏差,就自动重新训练模型

  7. 添加偏见与公平性监控功能 — 确保模型能够公平地对待所有群体

请记住,将机器学习模型投入实际应用是一个迭代过程。总有一些额外的可靠性措施需要加入,一些边缘情况需要处理,还有一些指标需要跟踪。但是,凭借你在这里建立的基础,你已经走在了将实验成果转化为可部署、可监控且易于维护的生产级应用的正确道路上。

祝你在开发过程中一切顺利,希望你的模型能够具备高准确性,也希望你的开发流程能够具备强大的稳定性!

获取完整代码

本手册中的整个项目都可以通过公开的GitHub仓库获得:

🔗 github.com/sandeepmb/freecodecamp-local-ml-platform

该仓库包含以下内容:

  • 所有源代码(src/目录)

  • 测试文件(tests/目录)

  • Feast特征定义文件(feature_repo/目录)

  • Docker及CI/CD配置文件

  • 可直接运行的脚本

快速入门指南:

git clone https://github.com/sandeepmb/freecodecamp-local-ml-platform.git
cd freecodecamp-local-ml-platform
python -m venv venv && source venv/bin/activate
pip install -r requirements.txt
python src/generate_data.py
python src/train_naive.py

参考资料

Comments are closed.