降维训练脚本

环境依赖

训练模型

本章节涉及的脚本的默认存放路径为:“tools/train/reduction”

  1. 训练模型。

    python3 call_train.py --dataset_dir=Dataset_Dir --val_dataset_dir=./valid --generate_val=True --save_path=./modelsDr --dim=512 --npu=0 --ratio=4 --metric=L2 --mode=train --train_size=100000 --epochs=20 --train_batch_size=8192 --infer_batch_size=128 --learning_rate=0.0005 --log_stride=500 --construct_neighbors=100 --queries_validation=1000

    参数

    说明

    dataset_dir

    数据集路径,类型为string。目前实现默认读取base.npy,query.npy和gt.npy。

    若数据集为其他名称,可以自行实现数据集读取,并对该脚本“get_train_data”所在行做对应修改。

    例如。原代码为:

     # load dataset demo before training, modify here if you want to load your own dataset
            #####################################################################
            learn, base = get_train_data(args.dataset_dir, args.train_size)
            #####################################################################

    可修改为:

     # load dataset demo before training, modify here if you want to load your own dataset
            #####################################################################
            # learn, base = get_train_data(args.dataset_dir, args.train_size)
            learn = np.fromfile(YOUR_LEARN_DATASET_DIR, dtype=np.float32).reshape((-1, YOUR_DATA_DIM))
            base = np.fromfile(YOUR_BASE_DATASET_DIR, dtype=np.float32).reshape((-1, YOUR_DATA_DIM))
            #####################################################################

    val_dataset_dir

    “generate_val”“True”时有效,生成验证集的存放路径,类型为string。

    generate_val

    是否生成验证集,默认为“False”。首次训练请设置为“True”。类型为bool。

    save_path

    模型存放路径。类型为string。

    dim

    可选,数据集维度。取值范围:[96, 128, 200, 256, 512, 2048]。类型为int。

    npu

    训练所用的DeviceId,即设备号。类型为int。

    仅支持单卡训练,默认为CPU训练。

    ratio

    可选,降维比例。取值范围:[2, 4, 8, 16]。类型为int。

    metric

    训练模型时的距离度量标准,可选L2或IP。类型为string。

    mode

    默认为“train”。类型为string。

    建议不要修改。

    train_size

    训练集大小,取值范围小于整个数据集样本个数。用于读取数据集时随机采样部分数据进行训练。类型为int。

    若自行实现数据集读取,请根据train_size进行采样以防止训练速度过慢。

    epochs

    训练迭代轮数,默认为“30”。类型为int。迭代次数设置过大,会显著增加训练时长。

    train_batch_size

    训练时的batch大小,默认为“8192”,类型为int。

    infer_batch_size

    推理时的batch大小,默认为“128”。类型为int。

    learning_rate

    学习率大小,默认为“0.0005”。类型为float。

    log_stride

    训练日志打印间隔(step),默认为“500”。类型为int。

    construct_neighbors

    构造训练集时所取的近邻的范围,用于构造降维所需的特殊训练集结构,默认为“100”。应根据数据集中每个人所对应的人脸数修改。类型为int。

    queries_validation

    构造验证集时所需查询向量的数量,默认为“1000”。类型为int。

  2. 生成OM模型。

    1. 生成精度为32的om模型。
      bash atc.sh {save_path}/best.onnx 2Cex4_fp32
    2. 生成精度为16的om模型
      bash atc_16.sh {save_path}/best.onnx 2Cex4_fp16

    {save_path}表示模型存储的路径。