如何使用打开AI编写简单的智能代码搜索
#javascript #python #openai #nlp

在本文中,我们将简要审查一种基于Chatgpt嵌入的技术。另外,我们将在项目的代码库中编写一个简单的智能搜索。

嵌入是将单词或文本转换为一个称为数字向量的数字集的过程。可以将向量相互比较,以确定两个文本或单词的含义。

例如,我们采用“捐赠”和“ give”单词的两个数值向量(嵌入)。单词不同,但含义相似,即它们是相互联系的,两者的结果都会给某人。

要在代码的上下文中获得类似的结果,我们可以将单词转换为嵌入并比较其similarity measures。度量的值范围为0到1,其中1是最大相似性,而0是无相关性的。 Cosine similarity可以作为这种比较函数。想象一下,我们已经用嵌入式完成了所有必要的操作,现在分析结果。我们还比较了两个单词“ give”和``红色''单词的嵌入。对于“给出”和“捐赠”函数的单词,该功能返回了数字0.80,而“给予”和“红色”仅0.20。因此,我们可以得出结论,“给予”和“捐赠”比“给予”和“红色”更近。

在项目的代码库中,可以使用嵌入方式在代码或文档中搜索。例如,您可以从搜索查询中进行嵌入(向量),然后测量同时找到相关功能或类的相似性。

因此,要使它需要一个打开的AI帐户和一个API令牌。如果您还没有帐户,则可以在Open AI官方网站上注册。注册并验证您的帐户后,转到API Keys个人资料部分并生成API令牌。

开始,他们给了18美元。这足以为本文(下图)提供一个示例并对服务进行进一步的测试。

将打字稿项目作为代码库。我建议服用一小部分,以免等待长时间产生嵌入。您也可以使用an example。您还需要python 3个以上版本和开放AI的库。如果您不知道这些语言,就不会担心。下面的代码示例很简单,不需要对两者都有深入的了解。

让我们开始。首先,您需要编写一个代码来从项目中提取各种代码,例如功能。 TypeScript提供了一种方便的compiler API,用于与AST一起工作,从而简化了任务。安装CSV-Stringify库生成CSV:

$ npm install csv-stringify

然后创建文件 code-code-csv.js ,然后从代码中提取信息:

const path = require('path');
const ts = require('typescript');
const csv = require('csv-stringify/sync');

const cwd = process.cwd();
const configJSON = require(path.join(cwd, 'tsconfig.json'));
const config = ts.parseJsonConfigFileContent(configJSON, ts.sys, cwd);
const program = ts.createProgram(
    config.fileNames, 
    config.options, 
    ts.createCompilerHost(config.options)
);
const checker = program.getTypeChecker();

const rows = [];

const addRow = (fileName, name, code, docs = '') => rows.push({
    file_name: path.relative(cwd, fileName),
    name,
    code,
    docs
});

function addFunction(fileName, node) {
    const symbol = checker.getSymbolAtLocation(node.name);
    if (symbol) {
        const name = symbol.getName();
        const docs = getDocs(symbol);
        const code = node.getText();
        addRow(fileName, name, code, docs);
    }
}

function addClass(fileName, node) {
    const symbol = checker.getSymbolAtLocation(node.name);
    if (symbol) {
        const name = symbol.getName();
        const docs = getDocs(symbol);
        const code = `class ${name} {}`;
        addRow(fileName, name, code, docs);
        node.members.forEach(m => addClassMember(fileName, name, m));
    }
}

function addClassMember(fileName, className, node) {
    const symbol = checker.getSymbolAtLocation(node.name);
    if (symbol) {
        const name = className + ':' + symbol.getName();
        const docs = getDocs(symbol);
        const code = node.getText();
        addRow(fileName, name, code, docs);
    }
}

function addInterface(fileName, node) {
    const symbol = checker.getSymbolAtLocation(node.name);
    if (symbol) {
        const name = symbol.getName();
        const docs = getDocs(symbol);
        const code = `interface ${name} {}`;
        addRow(fileName, name, code, docs);
        node.members.forEach(m => addInterfaceMember(fileName, name, m));
    }
}

function addInterfaceMember(fileName, interfaceName, node) {
    if (!ts.isPropertySignature(node) || !ts.isMethodSignature(node)) {
        return;
    }
    const symbol = checker.getSymbolAtLocation(node.name);
    if (symbol) {
        const name = interfaceName + ':' + symbol.getName();
        const docs = getDocs(symbol);
        const code = node.getText();
        addRow(fileName, name, code, docs);
    }
}

function getDocs(symbol) {
    return ts.displayPartsToString(symbol.getDocumentationComment(checker));
}

for (const fileName of config.fileNames) {
    const sourceFile = program.getSourceFile(fileName);
    const visitNode = node => {
        if (ts.isFunctionDeclaration(node)) {
            addFunction(fileName, node);
        } else if (ts.isClassDeclaration(node)) {
            addClass(fileName, node);
        } else if (ts.isInterfaceDeclaration(node)) {
            addInterface(fileName, node);
        }
        ts.forEachChild(node, visitNode);
    };
    ts.forEachChild(sourceFile, visitNode);
}

for (const row of rows) {
    row.combined = '';
    if (row.docs) {
        row.combined += `Code documentation: ${row.docs}; `;
    }
    row.combined += `Code: ${row.code}; Name: ${row.name};`;
}

const output = csv.stringify(rows, {
    header: true
});

console.log(output);

脚本收集了我们需要的所有片段,并将CSV表打印到控制台。 CSV表由列 file_name name 代码 docs 组合

  • file_name 包含项目中文件的路径,
  • 名称是片段的名称,例如“函数的名称”,
  • 代码是实体代码,
  • 文档是从评论到片段的描述,
  • 组合是代码和文档列的内容的添加,我们将使用此列来生成嵌入。

您不需要运行它。

现在到Python。

从Open AI和实用程序中安装库以使用嵌入:

$ pip install openai[embeddings]

创建文件 create_search_db.py 带有以下代码:

from io import StringIO
from subprocess import PIPE, run
from pandas import read_csv
from openai.embeddings_utils import get_embedding as _get_embedding
from tenacity import wait_random_exponential, stop_after_attempt

get_embedding = _get_embedding.retry_with(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(10))

if __name__ == '__main__':
    # 1
    result = run(['node', 'code-to-csv.js'], stdout=PIPE, stderr=PIPE, universal_newlines=True)
    if result.returncode != 0:
        raise RuntimeError(result.stderr)
    # 2
    db = read_csv(StringIO(result.stdout))
    # 3
    db['embedding'] = db['combined'].apply(lambda x: get_embedding(x, engine='text-embedding-ada-002'))
    # 4
    db.to_csv("search_db.csv", index=False)

脚本运行 code-code-csv.js (1),将结果加载到dataframe(2)中,并为 compined (3)列。嵌入将写入嵌入列。最终表具有搜索所需的所有内容。

要使脚本工作,您需要一个API令牌。 openai 库可以自动从环境变量中获取令牌,因此您可以编写一个方便的脚本来制作:

export OPENAI_API_KEY=YourToken

将其保存在某个地方,例如在 env.sh 中,然后运行:

$ source env.sh

一切准备生成搜索数据库。

运行脚本 create_search_db.py ,然后等到出现带有数据库的CSV文件。这可能需要几分钟。之后,您可以开始编写搜索引擎。

创建一个新的 search.py​​ '文件并写下以下内容:

import sys
import numpy as np
from pandas import read_csv
from openai.embeddings_utils import cosine_similarity, get_embedding as _get_embedding
from tenacity import  stop_after_attempt, wait_random_exponential

get_embedding = _get_embedding.retry_with(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(10))

def search(db, query):
    # 4
    query_embedding = get_embedding(query, engine='text-embedding-ada-002')
    # 5
    db['similarities'] = db.embedding.apply(lambda x: cosine_similarity(x, query_embedding))
    # 6
    db.sort_values('similarities', ascending=False, inplace=True)
    result = db.head(3)
    text = ""
    for row in result.itertuples(index=False):
        score=round(row.similarities, 3)
        if type(row.docs) == str:
            text += '/**\n * {docs}\n */\n'.format(docs='\n * '.join(row.docs.split('\n')))
        text += '{code}\n\n'.format(code='\n'.join(row.code.split('\n')[:7]))
        text += '[score={score}] {file_name}:{name}\n'.format(score=score, file_name=row.file_name, name=row.name)
        text += '-' * 70 + '\n\n'
    return text

if __name__ == '__main__':
    # 1
    db = read_csv('search_db.csv')
    # 2
    db['embedding'] = db.embedding.apply(eval).apply(np.array)
    query = sys.argv[1]
    print('')
    # 3
    print(search(db, query))

让我们分析脚本的工作原理。来自search_db.csv的数据加载到数据框中(1),该表的面向对象表示。然后将表中的嵌入方式转换为具有数字(2)的数组,以便可以使用它们。最后,使用搜索查询字符串(3)启动搜索功能。

搜索函数生成了查询(4)的嵌入,测量嵌入与从基数嵌入的嵌入的相似性,并将相似性得分存储在相似性列(5)中。

相似度的程度取决于1到1的数字,其中1表示最大拟合。表中的行通过相似性进行排序(6)。

最后,从数据库中检索了前三行并将其打印到控制台。

搜索引擎已经准备就绪,您可以对其进行测试。

对于测试,请使用请求运行命令:

Console output

现在尝试用另一种语言输入请求:

Console output

您可以看到,搜索是基于查询中单词的含义,而不仅仅是关键字。

该工具不仅限于这种情况和一个项目。您可以一次对所有项目进行更广泛的搜索。如果您每年开发多个类似的应用程序,并且希望快速找到代码片段,或者您有很多文档,并且关键字搜索不是一个好调用,这将很有用。这完全取决于任务和范围。

感谢您的关注!

链接: