SimilarChunks Strategy
Implement class: SimilarChunks Strategy
Core Logic
function calculate(text: String, canonicalName: String) -> SimilarChunkContext:
Split `text` into `lines` by line breaks
Take the last `snippetLength` lines of `lines` and join them into the `beforeCursor` string
Get a list of `canonicalNames` from the keys of `fileTree`
Calculate the similarity between `canonicalName` and each name in `canonicalNames`
Sort the results by similarity in descending order and take the paths of the top `maxRelevantFiles` matches as `relatedCodePath`
Log the value of `relatedCodePath`
Split `beforeCursor` into `chunks`
Get all related code chunks from `fileTree` using the paths in `relatedCodePath`
Gather these chunks into `allRelatedChunks`, keeping each chunk separate so it can be scored individually
Log the value of `allRelatedChunks`
Calculate a similarity score between each chunk in `allRelatedChunks` and `chunks`
Sort the results by score in descending order and take the top `maxRelevantFiles` chunks as `similarChunks`
If `similarChunks` contains more than 3 chunks, keep only the first 3; otherwise, keep all of them
Create a `SimilarChunkContext` object with language set to "java",
`relatedCodePath` set to `relatedCodePath`, and `similarChunks` set to `similarChunks`
Return the `SimilarChunkContext` object
Template data
{
"language": "java",
"beforeCursor": "package com.example.config;\n\nimport org.dbunit.DatabaseUnitException;\nimport org.dbunit.database.DatabaseConfig;\nimport org.dbunit.database.DatabaseConnection;\nimport org.dbunit.database.IDatabaseConnection;\nimport org.dbunit.database.QueryDataSet;\nimport org.dbunit.dataset.IDataSet;\nimport org.dbunit.dataset.xml.FlatXmlDataSet;\nimport org.dbunit.dataset.xml.FlatXmlDataSetBuilder;\nimport org.dbunit.operation.DatabaseOperation;\nimport org.springframework.beans.factory.annotation.Autowired;\nimport org.springframework.jdbc.datasource.DataSourceUtils;\nimport org.springframework.stereotype.Service;\n\nimport javax.sql.DataSource;\nimport java.io.File;\nimport java.io.FileInputStream;\nimport java.io.FileNotFoundException;\nimport java.io.FileWriter;\nimport java.sql.Connection;\nimport java.sql.DatabaseMetaData;\nimport java.sql.ResultSet;\nimport java.sql.SQLException;\nimport java.util.ArrayList;\nimport java.util.List;\n\n/**\n * 这个类的目的是通过每次备份系统初始化的数据(这些初始化数据可能来自 flyway),\n * 来实现每次测试的数据一致的目的。\n */\n\n@Service\npublic class ResetDbService {\n\n public static final String ROOT_URL = \"build/resources/test/\";\n private static IDatabaseConnection conn;\n\n @Autowired\n private DataSource dataSource;\n private File tempFile;\n\n public void backUp() throws Exception {\n this.getConnection();\n this.backupCustom(tables());\n }\n\n public void rollback() throws Exception {\n this.reset();\n this.closeConnection();\n }\n\n List<String> tables() throws SQLException {\n Connection connection = dataSource.getConnection();\n DatabaseMetaData metaData = connection.getMetaData();\n ResultSet tables = metaData.getTables(null, null, null, new String[]{\"TABLE\"});\n ArrayList<String> tableNames = new ArrayList<>();\n while (tables.next()) {\n String tableName = tables.getString(\"TABLE_NAME\");\n tableNames.add(tableName);\n }\n\n connection.close();\n return tableNames;\n }\n\n protected void backupCustom(List<String> tableName) {",
"afterCursor": " try {\n QueryDataSet qds = getQueryDataSet();\n for (String str : tableName) {\n qds.addTable(str);\n }\n\n conn.getConfig().setProperty(DatabaseConfig.PROPERTY_ESCAPE_PATTERN , \"`?`\");\n\n tempFile = new File(ROOT_URL + \"temp.xml\");\n FlatXmlDataSet.write(qds, new FileWriter(tempFile), \"UTF-8\");\n } catch (Exception e) {\n e.printStackTrace();\n }\n }",
"similarChunks": [
"\n\npublic class ResetDbListener extends AbstractTestExecutionListener {\n\n @Override\n public int getOrder() {\n return 4500;\n }\n\n @Override\n public void beforeTestMethod(TestContext testContext) throws Exception {\n ResetDbService resetDbService =\n testContext.getApplicationContext().getBean(ResetDbService.class);\n resetDbService.backUp();\n }\n\n @Override\n public void afterTestMethod(TestContext testContext) throws Exception {\n ResetDbService resetDbService =\n testContext.getApplicationContext().getBean(ResetDbService.class);\n resetDbService.rollback();\n }\n}\n"
],
"output": " try {\n QueryDataSet qds = getQueryDataSet();\n for (String str : tableName) {\n qds.addTable(str);\n }\n\n conn.getConfig().setProperty(DatabaseConfig.PROPERTY_ESCAPE_PATTERN , \"`?`\");\n\n tempFile = new File(ROOT_URL + \"temp.xml\");\n FlatXmlDataSet.write(qds, new FileWriter(tempFile), \"UTF-8\");\n } catch (Exception e) {\n e.printStackTrace();\n }\n }"
}
BeforeCursor code
package com.example.config;
import org.dbunit.DatabaseUnitException;
import org.dbunit.database.DatabaseConfig;
import org.dbunit.database.DatabaseConnection;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.database.QueryDataSet;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.xml.FlatXmlDataSet;
import org.dbunit.dataset.xml.FlatXmlDataSetBuilder;
import org.dbunit.operation.DatabaseOperation;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.datasource.DataSourceUtils;
import org.springframework.stereotype.Service;
import javax.sql.DataSource;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
/**
* This class keeps the data consistent for every test by backing up the system's
* initial data (which may come from Flyway) each time.
*/
@Service
public class ResetDbService {
// Base directory for files produced by this service (presumably the backup file "temp.xml"
// is created under it by backupCustom — confirm against that method).
public static final String ROOT_URL = "build/resources/test/";
// DbUnit connection; declared static, so it is shared by every instance of this class.
private static IDatabaseConnection conn;
// Injected JDBC data source; tables() obtains its connection from here.
@Autowired
private DataSource dataSource;
// Backup file handle; NOTE(review): presumably assigned during backupCustom — confirm.
private File tempFile;
/**
 * Prepares the connection and then backs up every table reported by {@link #tables()}.
 *
 * <p>{@code getConnection()} is invoked first (presumably opening the DbUnit connection that
 * the backup uses — it is defined elsewhere). The table list is then read and handed to
 * {@code backupCustom(List)}.
 *
 * @throws Exception whatever {@code getConnection()}, {@link #tables()} or
 *     {@code backupCustom(List)} propagates
 */
public void backUp() throws Exception {
    getConnection();
    List<String> allTables = tables();
    backupCustom(allTables);
}
/**
 * Resets the database via {@code reset()} and then releases the connection via
 * {@code closeConnection()}.
 *
 * <p>NOTE(review): {@code reset()} presumably restores the data saved by {@link #backUp()};
 * both helpers are defined elsewhere, so confirm.
 *
 * <p>{@code closeConnection()} runs in a {@code finally} block. Previously, an exception
 * thrown by {@code reset()} skipped the close and left the connection open for the next test.
 *
 * @throws Exception if resetting or closing fails; if both fail, the exception thrown by
 *     {@code closeConnection()} replaces the one thrown by {@code reset()}
 */
public void rollback() throws Exception {
    try {
        reset();
    } finally {
        closeConnection();
    }
}
/**
 * Lists the names of the tables visible through {@code dataSource}.
 *
 * <p>Only metadata rows whose table type is {@code "TABLE"} are returned. Both the JDBC
 * connection and the metadata result set are closed before this method returns, including
 * when reading the metadata fails. The original code closed the connection only on success
 * and never closed the result set.
 *
 * @return the table names, in the order the JDBC driver reports them
 * @throws SQLException if a connection cannot be obtained or the metadata cannot be read
 */
List<String> tables() throws SQLException {
    try (Connection connection = dataSource.getConnection()) {
        DatabaseMetaData metaData = connection.getMetaData();
        // Null catalog/schema/table-name arguments drop those criteria from the search,
        // so every schema the driver exposes is scanned.
        try (ResultSet tables = metaData.getTables(null, null, null, new String[]{"TABLE"})) {
            List<String> tableNames = new ArrayList<>();
            while (tables.next()) {
                tableNames.add(tables.getString("TABLE_NAME"));
            }
            return tableNames;
        }
    }
}
protected void backupCustom(List<String> tableName) {
similarChunks
/**
 * Test execution listener that calls {@link ResetDbService#backUp()} before each test method
 * and {@link ResetDbService#rollback()} after it.
 *
 * <p>The service is looked up as a bean from the test's application context each time.
 */
public class ResetDbListener extends AbstractTestExecutionListener {

    /** Order value reported through {@link #getOrder()}. */
    private static final int ORDER = 4500;

    @Override
    public int getOrder() {
        return ORDER;
    }

    /** Backs up the database before the test method runs. */
    @Override
    public void beforeTestMethod(TestContext testContext) throws Exception {
        resetDbServiceOf(testContext).backUp();
    }

    /** Rolls the database back after the test method has run. */
    @Override
    public void afterTestMethod(TestContext testContext) throws Exception {
        resetDbServiceOf(testContext).rollback();
    }

    /** Fetches the {@link ResetDbService} bean from the test's application context. */
    private static ResetDbService resetDbServiceOf(TestContext testContext) {
        return testContext.getApplicationContext().getBean(ResetDbService.class);
    }
}