以前对萨吉式培训工作的表格BERT模型进行了预培训,我开始使用预训练结果中的model.tar.gz
文件来创建用于微调的工作。
尽管这项工作从根本上类似于培训前过程,但在以下方面需要进行某些调整,我已将其编译为备忘录:
- 提取焦油文件
- 使用环境变量在本地环境和萨格人之间切换参数
提取焦油文件
model.tar.gz
文件由于预训练作业而存储在S3上,包含以下文件:型号pytorch_model.bin
,configuration file config.json
,dictionary file file vocab.nb
和令牌到ID转换文件vocab_token2id.bin
。要在微调过程中加载这些文件,必须设计一种在执行工作时提取焦油文件的方法。
最初,在作业文件的输入_model部分中设置model.tar.gz
文件的S3路径。因此,在执行作业时,model.tar.gz
文件将放置在/opt/ml/input/data/input_model/
(model_path)目录中。
import sagemaker
from sagemaker.estimator import Estimator
session = sagemaker.Session()
role = sagemaker.get_execution_role()
estimator = Estimator(
image_uri=<image-url>,
role=role,
instance_type="ml.g4dn.2xlarge",
instance_count=1,
base_job_name="tabformer-opt-fine-tuning",
output_path="s3://<bucket-name>/sagemaker/output_data/fine_tuning",
code_location="s3://<bucket-name>/sagemaker/output_data/fine_tuning",
sagemaker_session=session,
entry_point="fine-tuning.sh",
dependencies=["tabformer-opt"],
hyperparameters={
"data_root": "/opt/ml/input/data/input_data/",
"data_fname": "summary",
"output_dir": "/opt/ml/model/",
"model_path": "/opt/ml/input/data/input_model/",
}
)
estimator.fit({
"input_data": "s3://<bucket-name>/sagemaker/input_data/summary.csv",
"input_model": "s3://<bucket-name>/sagemaker/output_data/pre_training/tabformer-opt-2022-12-16-07-00-45-931/output/model.tar.gz"
})
接下来,在微调执行文件tabformer_bert_fine_tuning.py
中包括以下内容:
with tarfile.open(name=path.join(args.model_path, f'model.tar.gz'), mode="r:gz") as mytar:
mytar.extractall(path.join(args.model_path, f'model'))
token2id_file = path.join(args.model_path, f"model/vocab_token2id.bin")
vocab_file = path.join(args.model_path, f"model/vocab.nb")
pretrained_model = path.join(args.model_path, f"model/checkpoint-500/pytorch_model.bin")
pretrained_config = path.join(args.model_path, f"model/checkpoint-500/config.json")
tarfile.open()
函数读取model.tar.gz
文件,而mytar.extractall(path.join(args.model_path, f'model'))
在/opt/ml/input/data/input_model/model/
目录下提取内容。
这使您可以加载提取的文件,例如使用token2id_file = path.join(args.model_path, f"model/vocab_token2id.bin")
。
使用环境变量在本地环境和萨吉人之间切换参数
使用此设置,您现在可以从S3加载model.tar.gz
文件。但是,在某些情况下,您可能想在本地执行微调时更改文件源。
要处理这种情况,您可以使用os.getenv('SM_MODEL_DIR')
获取sagemaker环境变量SM_MODEL_DIR
(容器终止后将上传到S3的目录),并在本地和sagemaker(job)环境之间切换文件源。<<<<<<<<
key = os.getenv('SM_MODEL_DIR')
if key :
with tarfile.open(name=path.join(args.model_path, f'model.tar.gz'), mode="r:gz") as mytar:
mytar.extractall(path.join(args.model_path, f'model'))
token2id_file = path.join(args.model_path, f"model/vocab_token2id.bin")
vocab_file = path.join(args.model_path, f"model/vocab.nb")
pretrained_model = path.join(args.model_path, f"model/checkpoint-500/pytorch_model.bin")
pretrained_config = path.join(args.model_path, f"model/checkpoint-500/config.json")
else :
vocab_file = path.join(args.model_path, f"vocab.nb")
token2id_file = path.join(args.model_path, f"vocab_token2id.bin")
pretrained_model = path.join(args.model_path, f"checkpoint-500/pytorch_model.bin")
pretrained_config = path.join(args.model_path, f"checkpoint-500/config.json")