DataProvider类
public class MarketingDataProvider {
/**
* 使用方法 _testNg的方法上加如下注解
*
* @DataProviderParams(csvPrefix="userDimension/userDimensionRule")
* @Test(dataProvider=MarketingDataProvider.DEFAULT_PROVIDER,dataProviderClass=MarketingDataProvider.class)
* public void failedCutoff(String field,String dimension,String isLimit,String dayOffSet,String description){}
*
* 对应的CSV文件为:./src/test/resources/data/input/userDimension/userDimensionRule_failedCutoff.csv
*
* CSV文件中的表头列需要与注解的方法传入参数一一对应
*
*/
public final static String DEFAULT_PROVIDER = "marketingDataProvider";
private final static String DEFAULT_FILE_PREFIX = "data/input/"; //ClassLoader.getSystemResource()这里路径前面不能加 /
private final static String DEFAULT_FILE_SUFFIX = ".csv";
private final static String DEFAULT_ENCODING = "UTF-8";
private final static charcsvSeprator = ',';
@SuppressWarnings("resource")
@DataProvider(name = DEFAULT_PROVIDER)
public static Object[][] prepareData(Method method) {
CSVReader reader = null;
DataProviderParams dataProviderParams = method.getAnnotation(DataProviderParams.class);
if (null == dataProviderParams) {
throw new IllegalArgumentException(DEFAULT_PROVIDER + " must use together with @DataProviderParams.");
}
if ("".equals(dataProviderParams.csvPrefix())) {
throw new IllegalArgumentException("csvPrefix can not be null.");
}
String csvDirectory = DEFAULT_FILE_PREFIX + dataProviderParams.csvPrefix() +
"_" + method.getName().toString() + DEFAULT_FILE_SUFFIX;
String systemEncoding = System.getProperty("file.encoding");
try {
URL url = ClassLoader.getSystemResource(csvDirectory);
File file = new File(url.getFile());
reader = new CSVReader(new InputStreamReader(new FileInputStream(file), systemEncoding), csvSeprator);
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (UnsupportedEncodingException e) {
e.printStackTrace();
}
try {
reader.readNext();
} catch (IOException e) {
e.printStackTrace();
}
String[] csvRow = null;
List<Object[]> csvList = new ArrayList<Object[]>();
try {
while ((csvRow = reader.readNext()) != null) {
csvList.add(charsetConvert(csvRow,systemEncoding,DEFAULT_ENCODING));
}
} catch (IOException e) {
e.printStackTrace();
}
Object[][] results;
int[] indices = dataProviderParams.indices();
if (indices.length > 0) {
indices = arrayUnique(indices);
results = new Object[indices.length][];
for (int i = 0; i < indices.length; i++) {
if (indices[i] <= 0)
continue;
results[i] = csvList.get(indices[i] - 1);
}
} else {
results = new Object[csvList.size()][];
for (int i = 0; i < csvList.size(); i++) {
results[i] = csvList.get(i);
}
}
return results;
}
/**
*
* @param source
* @param sourceCharset
* @param targetCharset
* @return
*/
private static String[] charsetConvert(String[] source,String sourceCharset,String targetCharset){
String[] target = new String[source.length];
for(int i=0;i<source.length;i++){
try {
target[i] = new String(source[i].getBytes(sourceCharset),targetCharset);
} catch (UnsupportedEncodingException e) {
e.printStackTrace();
}
}
return target;
}
private staticint[] arrayUnique(int[] indices) {
Set<Integer> set = new HashSet<Integer>();
for (int i = 0; i < indices.length; i++) {
set.add(indices[i]);
}
int[] arrays = newint[set.size()];
int count = 0;
for (Integer integer : set) {
arrays[count] = integer;
count++;
}
return arrays;
}
}
DataProviderParams 注解类
/**
 * Configures MarketingDataProvider for one TestNG test method.
 * Retained at runtime because the provider reads it via Method.getAnnotation().
 */
@Retention(java.lang.annotation.RetentionPolicy.RUNTIME)
@Target({ METHOD })
public @interface DataProviderParams {
    /**
     * Classpath path prefix of the CSV file, relative to the provider's data directory and
     * without a leading '/'. The provider appends "_" + test method name + ".csv" to it.
     */
    String csvPrefix() default "";

    /**
     * Whether debug mode is enabled.
     * NOTE(review): not read by MarketingDataProvider as shown; confirm whether anything uses it.
     */
    boolean debugModel() default false;

    /**
     * 1-based positions of the data rows (header row excluded) to run. Default: all rows.
     */
    int[] indices() default {};
}
使用方法
/**
 * Example test wired to MarketingDataProvider.
 *
 * <p>The provider loads data/input/userDimension/userDimensionRule_failedCutoff.csv from the
 * classpath (csvPrefix + "_" + method name + ".csv"), skips its header row, and invokes this
 * method once per data row. The CSV columns map one-to-one onto the five String parameters.
 * The body is an empty placeholder for the actual test logic.
 */
@DataProviderParams(csvPrefix="userDimension/userDimensionRule")
@Test(dataProvider=MarketingDataProvider.DEFAULT_PROVIDER,dataProviderClass=MarketingDataProvider.class)
public void failedCutoff(String field,String dimension,String isLimit,String dayOffSet,String description) {
}