为什么不用现成的方案?大家都在用 ChatGPT。但我没有 credits(额度);)
我当然想学习一些东西!
如何生成git提交消息?
git 允许你创建钩子(hook)。我们用一个全局钩子:全局钩子无需逐个修改每个 git 仓库就能生效。
为钩子创建一个目录:
$ mkdir -p ~/.config/git/hooks/
让git知道钩子在哪里:
$ git config --global core.hooksPath ~/.config/git/hooks/
长话短说,我们需要的是 prepare-commit-msg 钩子。需要更新的提交消息文件会作为第一个参数传给它。
创建一个简单的脚本 ~/.config/git/hooks/prepare-commit-msg:
#!/bin/sh
# $1 is the path of the commit message file; ">" replaces its contents.
echo "Fancy commit message" > "$1"
使其可执行:
$ chmod +x ~/.config/git/hooks/prepare-commit-msg
能用吗?提交(commit)点东西试试……成功了,提交消息里出现了我们写入的那条消息。
让我们生成一些东西:
生成提交消息
我们来做一个能离线运行的方案。AI?没错,就用 AI!
我们需要模型吗?
让我们看一下huggingface!
有一个:https://huggingface.co/mamiksik/T5-commit-message-generation,但没有文档 :(
不过再深入找一找,就会发现 https://huggingface.co/spaces/mamiksik/commit-message-generator
稍作修改,我们就可以复用 https://huggingface.co/spaces/mamiksik/commit-message-generator/blob/main/app.py。
钩子可以是任何可执行脚本,不一定是 shell 脚本,所以我们用 Python 来写。
让我们看看那里有什么:
import re
import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, RobertaTokenizer
# Load the tokenizer and the T5 model once at import time. `revision` pins the
# Hugging Face Hub checkpoint to one specific commit of the model repository.
MODEL_ID = "mamiksik/CommitPredictorT5PL"
MODEL_REVISION = "fb08d01"

tokenizer = RobertaTokenizer.from_pretrained(MODEL_ID, revision=MODEL_REVISION)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID, revision=MODEL_REVISION)
def parse_files(patch):
    """Convert a ``git diff`` patch into the tagged text the model consumes.

    * A ``---``/``+++`` file-header pair becomes ``<ide><path>PATH`` when both
      paths are equal, otherwise ``<add><path>NEW`` followed by
      ``<del><path>OLD``. The first character of each path (the ``a``/``b``
      of git's ``a/``/``b/`` prefix) is dropped, exactly as before.
    * ``diff ...`` and ``index ...`` lines are dropped.
    * The ``@@ -a,b +c,d @@`` range of a hunk header is removed; any trailing
      section heading is kept as an ``<ide>`` line.
    * Added/removed lines get ``<add>``/``<del>`` instead of ``+``/``-``;
      every other non-empty line gets an ``<ide>`` prefix.

    Args:
        patch: Patch text as produced by ``git diff``.

    Returns:
        The converted lines joined with newlines ("" for an empty patch).
    """
    accumulator = []
    filename_before = None
    # "---"/"+++" name files only in a file header (before the first hunk).
    # Inside a hunk they are ordinary removed/added lines whose content starts
    # with "--"/"++" (e.g. a deleted YAML "---" separator shows up as "----").
    in_header = True
    for line in patch.splitlines():
        if line.startswith("diff"):
            in_header = True
            continue
        if line.startswith("index"):
            continue
        if in_header and line.startswith("---"):
            # partition() tolerates a header without a space (no IndexError).
            filename_before = line.partition(" ")[2][1:]
            continue
        if in_header and line.startswith("+++"):
            filename_after = line.partition(" ")[2][1:]
            if filename_before == filename_after:
                accumulator.append(f"<ide><path>{filename_before}")
            else:
                accumulator.append(f"<add><path>{filename_after}")
                accumulator.append(f"<del><path>{filename_before}")
            continue
        if line.startswith("@@"):
            in_header = False
            # Strip only the hunk range; "@@" inside code lines is content.
            line = re.sub(r"^@@[^@]*@@", "", line, count=1)
        if not line:
            continue
        if line[0] == "+":
            line = "<add>" + line[1:]
        elif line[0] == "-":
            line = "<del>" + line[1:]
        else:
            line = f"<ide>{line}"
        accumulator.append(line)
    return "\n".join(accumulator)
def predict(patch, max_length, min_length, num_beams, prediction_count):
    """Generate commit-message candidates for *patch* (Gradio callback).

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        patch: Patch text as produced by ``git diff``.
        max_length: Maximum generated length, passed to ``model.generate``.
        min_length: Minimum generated length, passed to ``model.generate``.
        num_beams: Beam-search width. It is raised to ``prediction_count`` if
            smaller, because beam search cannot return more sequences than
            it has beams.
        prediction_count: Number of sequences to return.

    Returns:
        ``(token_count, input_text, predictions)``:
        - ``token_count`` is the length of the un-truncated tokenized input.
        - ``input_text`` is the parsed patch.
        - ``predictions`` maps each decoded message to a placeholder score of 0.
    """
    input_text = parse_files(patch)
    # Slider values may arrive as numbers of either type; generate() wants ints.
    max_length = int(max_length)
    min_length = int(min_length)
    prediction_count = int(prediction_count)
    # The UI allows 15 predictions but only 10 beams; generate() raises a
    # ValueError when num_return_sequences > num_beams.
    num_beams = max(int(num_beams), prediction_count)
    with torch.no_grad():
        # Count tokens without truncation so the UI shows the real input size.
        token_count = tokenizer(input_text, return_tensors="pt").input_ids.shape[1]
        input_ids = tokenizer(
            input_text,
            truncation=True,
            padding=True,
            return_tensors="pt",
        ).input_ids
        outputs = model.generate(
            input_ids,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            num_return_sequences=prediction_count,
        )
    result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return token_count, input_text, {k: 0 for k in result}
# Gradio UI. Inputs: the patch plus generation settings (in predict()'s
# argument order). Outputs: the token count, the parsed model input and the
# predicted messages.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Patch (as generated by git diff)"),
        gr.Slider(1, 128, value=40, label="Max message length"),
        gr.Slider(1, 128, value=5, label="Min message length"),
        gr.Slider(1, 10, value=7, label="Number of beams"),
        gr.Slider(1, 15, value=5, label="Number of predictions"),
    ],
    outputs=[
        gr.Textbox(label="Token count"),
        gr.Textbox(label="Parsed patch"),
        gr.Label(label="Predictions"),
    ],
    # One value per input, in order. Context lines of the sample hunk keep
    # their leading " " so the body matches its "-20,3 +20,6" header
    # (3 context lines + 3 added lines) instead of reading as deletions.
    examples=[
        [
            """
diff --git a/.github/workflows/pylint.yml b/.github/workflows/codestyle_checks.yml
similarity index 86%
rename from .github/workflows/pylint.yml
rename to .github/workflows/codestyle_checks.yml
index a5d5c4d9..8cbf9713 100644
--- a/.github/workflows/pylint.yml
+++ b/.github/workflows/codestyle_checks.yml
@@ -20,3 +20,6 @@ jobs:
 - name: Analysing the code with pylint
 run: |
 pylint --rcfile=.pylintrc webapp core
+ - name: Analysing the code with flake8
+ run: |
+ flake8
""",
            40,
            5,
            7,
            5,
        ]
    ],
)

if __name__ == "__main__":
    iface.launch()
我们需要的一切都在这里!我们需要:
- 获取需要更新的提交消息文件
- 获取 git diff(已暂存的改动)
- 用现有脚本进行预测
- 把生成的消息插入到提交消息文件的开头
需要更新的文件通过第一个参数传入,因此:
import sys

# git passes the path of the commit message file as the first argument.
commit_msg_file = sys.argv[1]
呵呵,这很容易。
获取 git diff
import subprocess

# Staged changes only (--cached): that is what the commit will contain.
# errors="replace" keeps the hook working when the diff has non-UTF-8 bytes.
diff = subprocess.run(
    ["git", "diff", "--cached"], capture_output=True
).stdout.decode("utf-8", errors="replace")
小菜一碟(easy peasy)!
使用当前脚本进行预测
max_message = 40
min_message = 5
num_beams = 10
num_predictions = 1
# The Space's predict() returns (token_count, parsed_patch, {prediction: 0});
# the commit message is the first prediction, not the whole tuple.
_, _, predictions = predict(diff, max_message, min_message, num_beams, num_predictions)
msg = next(iter(predictions))
把我们的消息插入到提交消息文件的开头
# Prepend the generated message and keep everything git already put in the
# file below it. The new text is longer than the old, so no stale tail is
# left. Assumes git's default UTF-8 commit encoding.
with open(sys.argv[1], "r+", encoding="utf-8") as f:
    content = f.read()
    f.seek(0)
    f.write(msg + "\n" + content)
# No explicit close(): leaving the with-block closes the file.
就是这样。稍作整理之后,这就是我们的最终脚本:
#!/usr/bin/env python3
"""prepare-commit-msg hook: prepend a model-generated line to the commit message."""
# Print before the heavy imports below: importing torch/transformers can take
# a while, and this shows the user that the hook is running.
print("Generating commit message", end="", flush=True)
import sys
import re
import subprocess
import torch
from transformers import T5ForConditionalGeneration, RobertaTokenizer
def parse_files(patch):
    """Convert a ``git diff`` patch into the tagged text the model consumes.

    Prints one "." per input line as a progress indicator.

    * A ``---``/``+++`` file-header pair becomes ``<ide><path>PATH`` when both
      paths are equal, otherwise ``<add><path>NEW`` followed by
      ``<del><path>OLD``. The first character of each path (the ``a``/``b``
      of git's ``a/``/``b/`` prefix) is dropped, exactly as before.
    * ``diff ...`` and ``index ...`` lines are dropped.
    * The ``@@ -a,b +c,d @@`` range of a hunk header is removed; any trailing
      section heading is kept as an ``<ide>`` line.
    * Added/removed lines get ``<add>``/``<del>`` instead of ``+``/``-``;
      every other non-empty line gets an ``<ide>`` prefix.

    Args:
        patch: Patch text as produced by ``git diff``.

    Returns:
        The converted lines joined with newlines ("" for an empty patch).
    """
    accumulator = []
    filename_before = None
    # "---"/"+++" name files only in a file header (before the first hunk).
    # Inside a hunk they are ordinary removed/added lines whose content starts
    # with "--"/"++" (e.g. a deleted YAML "---" separator shows up as "----").
    in_header = True
    for line in patch.splitlines():
        print(".", end="", flush=True)
        if line.startswith("diff"):
            in_header = True
            continue
        if line.startswith("index"):
            continue
        if in_header and line.startswith("---"):
            # partition() tolerates a header without a space (no IndexError).
            filename_before = line.partition(" ")[2][1:]
            continue
        if in_header and line.startswith("+++"):
            filename_after = line.partition(" ")[2][1:]
            if filename_before == filename_after:
                accumulator.append(f"<ide><path>{filename_before}")
            else:
                accumulator.append(f"<add><path>{filename_after}")
                accumulator.append(f"<del><path>{filename_before}")
            continue
        if line.startswith("@@"):
            in_header = False
            # Strip only the hunk range; "@@" inside code lines is content.
            line = re.sub(r"^@@[^@]*@@", "", line, count=1)
        if not line:
            continue
        if line[0] == "+":
            line = "<add>" + line[1:]
        elif line[0] == "-":
            line = "<del>" + line[1:]
        else:
            line = f"<ide>{line}"
        accumulator.append(line)
    return "\n".join(accumulator)
def predict(patch, max_length, min_length, num_beams, prediction_count):
    """Generate a commit message for *patch* with the CommitPredictorT5PL model.

    The tokenizer and model are loaded on every call; the hook calls this
    once. A "." is printed after each stage as a progress indicator.

    Args:
        patch: Patch text as produced by ``git diff``.
        max_length: Maximum generated length, passed to ``model.generate``.
        min_length: Minimum generated length, passed to ``model.generate``.
        num_beams: Beam-search width. It is raised to ``prediction_count`` if
            smaller, because beam search cannot return more sequences than
            it has beams.
        prediction_count: Number of sequences to generate.

    Returns:
        The first decoded prediction.
    """
    print(".", end="", flush=True)
    input_text = parse_files(patch)
    # low_cpu_mem_usage is a model-loading option; the tokenizer does not use it.
    tokenizer = RobertaTokenizer.from_pretrained(
        "mamiksik/CommitPredictorT5PL", revision="fb08d01"
    )
    print(".", end="", flush=True)
    model = T5ForConditionalGeneration.from_pretrained(
        "mamiksik/CommitPredictorT5PL", revision="fb08d01", low_cpu_mem_usage=True
    )
    print(".", end="", flush=True)
    # generate() raises a ValueError when num_return_sequences > num_beams.
    num_beams = max(num_beams, prediction_count)
    with torch.no_grad():
        input_ids = tokenizer(
            input_text,
            truncation=True,
            padding=True,
            return_tensors="pt",
        ).input_ids
        print(".", end="", flush=True)
        outputs = model.generate(
            input_ids,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            num_return_sequences=prediction_count,
        )
    print(".", end="", flush=True)
    result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return result[0]
if __name__ == "__main__":
    # git passes the path of the commit message file as the first argument.
    commit_msg_file = sys.argv[1]
    # The optional second argument names the message source. If the user
    # already supplied a message (-m/-F), git prepared one (merge/squash), or
    # an existing commit's message is reused (-c/-C/--amend), leave it alone.
    source = sys.argv[2] if len(sys.argv) > 2 else ""
    if source in ("message", "merge", "squash", "commit"):
        print(" skipped")
        sys.exit(0)

    # errors="replace" keeps the hook working when the diff has non-UTF-8 bytes.
    diff = subprocess.run(
        ["git", "diff", "--cached"], capture_output=True
    ).stdout.decode("utf-8", errors="replace")
    if not diff.strip():
        # Nothing staged: there is nothing to describe, so don't load the model.
        print(" nothing staged")
        sys.exit(0)

    max_message = 40
    min_message = 5
    num_beams = 10
    num_predictions = 1
    msg = predict(diff, max_message, min_message, num_beams, num_predictions)

    # Prepend the message and keep git's existing content (comments/template)
    # below it. Assumes git's default UTF-8 commit encoding.
    with open(commit_msg_file, "r+", encoding="utf-8") as f:
        content = f.read()
        f.seek(0)
        f.write(msg + "\n" + content)
    print("Done!\n")
在 CPU 上推理很快,但加载模型要花不少时间。总共大约 3 秒,还可以接受。
就这样。它能用,至少对我来说是这样。