Q: How many ways are there to extend a Spark data source?
A: Three. Two are based on DataSource V1, and the third is an implementation of DataSource V2. The first two are not recommended because they are fairly complex to implement; the third is the one to use. Later in this article, DataSource V2 is used to implement an Excel reader; if you want to extend other data sources, the article can also serve as a reference.
Note: this article uses Java. Scala would actually be a better fit, especially for DataSource V1, which has plenty of pitfalls when written in Java.
Below is pseudocode for the data source extension interfaces:
Method 1, the first DataSource V1 approach: implement the FileFormat interface
public class ExcelDataSource implements FileFormat, Serializable {
......
}
Method 2, the second DataSource V1 approach: implement the RelationProvider interface together with the SchemaRelationProvider interface
public class ExcelRelationProvider implements RelationProvider, SchemaRelationProvider {
}
Method 3: implement the DataSourceV2 interface together with the related ReadSupport interface (implement WriteSupport as well if you need to write data)
public class ExcelDataSourceV2 implements DataSourceV2, ReadSupport, Serializable {
......
}
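Whichever method you choose, Spark locates the implementation class by the name passed to spark.read().format(...). If you would rather use a short alias than the fully qualified class name, the class can additionally implement the DataSourceRegister interface and be listed in a META-INF/services/org.apache.spark.sql.sources.DataSourceRegister file on the classpath. A small optional sketch extending method 3 (the full example later in this article skips this and simply uses the class name):
public class ExcelDataSourceV2 implements DataSourceV2, ReadSupport, DataSourceRegister, Serializable {
    @Override
    public String shortName() {
        // lets callers write spark.read().format("excel") instead of the full class name
        return "excel";
    }
    ......
}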
Q: How does this work under the hood?
A: Here is the code, see for yourself.
- The call site for methods 1 and 2 in the Spark source is the resolveRelation method of the DataSource class. The pseudocode below shows that each interface is handled differently:
def resolveRelation(checkFilesExist: Boolean = true): BaseRelation = {
val relation = (providingClass.newInstance(), userSpecifiedSchema) match {
// TODO: Throw when too much is given.
case (dataSource: SchemaRelationProvider, Some(schema)) =>
.......
case (dataSource: RelationProvider, None) =>
.......
case (_: SchemaRelationProvider, None) =>
throw new AnalysisException(s"A schema needs to be specified when using $className.")
case (dataSource: RelationProvider, Some(schema)) =>
.......
case (format: FileFormat, _) =>
.......
}
......
}
- The call site for method 3 in the Spark source is the load(paths: String*) method of the DataFrameReader class. Pseudocode:
def load(paths: String*): DataFrame = {
if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
throw new AnalysisException("Hive data source can only be used with tables, you can not " +
"read files of Hive data source directly.")
}
val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
if (classOf[DataSourceV2].isAssignableFrom(cls)) {
//handling for DataSource V2
......
} else {
//handling for DataSource V1
loadV1Source(paths: _*)
}
}
Q: Can you give a complete working example?
A: Of course. Below is the Excel implementation against DataSource V2.
- ExcelDataSourceV2.java: the bulk of the code lives in this class; split it up as needed.
package self.robin.examples.spark.sources.v2.excel;
import com.google.gson.Gson;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.util.StringUtils;
import org.apache.poi.ss.usermodel.Workbook;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser;
import org.apache.spark.sql.internal.SQLConf;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.ReadSupport;
import org.apache.spark.sql.sources.v2.reader.DataReader;
import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.util.SerializableConfiguration;
import self.robin.examples.spark.sources.SheetIterator;
import java.io.IOException;
import java.io.Serializable;
import java.util.*;
import java.util.stream.Collectors;
import static self.robin.examples.spark.sources.SparkWorkbookHelper.*;
/**
* @Description: ...
* @Author: Li Yalei - Robin
* @Date: 2021/2/8 10:20
*/
public class ExcelDataSourceV2 implements DataSourceV2, ReadSupport, Serializable {
@Override
public DataSourceReader createReader(DataSourceOptions options) {
return new ExcelDataSourceV2Reader(SerializableOptions.of(options));
}
class ExcelDataSourceV2Reader implements DataSourceReader, Serializable {
private SerializableOptions options;
private volatile StructType schema;
private Collection<String> paths;
ExcelDataSourceV2Reader(SerializableOptions options) {
this.options = options;
init();
}
private void init(){
Optional<String> pathOpt = options.get("path");
if (!pathOpt.isPresent()) {
throw new RuntimeException("path 不能为空");
}
paths = StringUtils.getStringCollection(pathOpt.get(), ",");
}
/**
* Parse the schema information passed in via the options.
*/
public void buildStructType(Map<String, String> map) {
List<StructField> fieldList = new ArrayList<>();
for (Map.Entry<String, String> entry : map.entrySet()) {
StructField structField = new StructField(entry.getKey(),
new CatalystSqlParser(new SQLConf()).parseDataType(entry.getValue()),
true, Metadata.empty());
fieldList.add(structField);
}
this.schema = new StructType(fieldList.toArray(new StructField[0]));
}
@Override
public StructType readSchema() {
if (this.schema != null && !this.schema.isEmpty()) {
return this.schema;
}
Optional<String> schemaOpt = options.get("schema");
if(schemaOpt.isPresent()){
Map<String, String> map = new Gson().fromJson(schemaOpt.get(), LinkedHashMap.class);
buildStructType(map);
}else {
tryParseColsFromFiles();
}
return this.schema;
}
private void tryParseColsFromFiles(){
boolean header = options.getBoolean("header", false);
//try to parse the header from the excel files and build the StructType
//defaults to the first sheet
//all sheets across the excel files must have the same number of columns
List<String> colNames = new ArrayList<>();
int size = paths.stream().map(path -> {
try {
Workbook wb = createWorkbook(path, getConfiguration());
//defaults to the first sheet
List<String> cols = getColumnNames(wb.getSheetAt(0), 1, 1, header);
//keep the first set of column names that gets parsed
if(colNames.isEmpty()){
colNames.addAll(cols);
}
//all sheets across the excel files must have the same number of columns
return cols.size();
} catch (IOException e) {
//fail fast instead of silently returning a bogus column count
throw new RuntimeException("Failed to read excel file: " + path, e);
}
}).collect(Collectors.toSet()).size();
if(size!=1){
//the excel files do not all have the same columns
throw new RuntimeException("The sheets in the provided excel files have inconsistent columns, please check");
}
Map<String, String> map = new LinkedHashMap<>();
for (String col : colNames) {
map.put(col, "String");
}
buildStructType(map);
}
@Override
public List<DataReaderFactory<Row>> createDataReaderFactories() {
SerializableConfiguration serConfig = new SerializableConfiguration(getConfiguration());
boolean header = options.getBoolean("header", false);
return paths.parallelStream().map(path -> new DataReaderFactory<Row>() {
@Override
public DataReader<Row> createDataReader() {
return new WorkbookReader(header, path, serConfig);
}
}).collect(Collectors.toList());
}
/**
* Get the hadoop configuration from the active SparkSession (driver side only).
* @return hadoop configuration
*/
private Configuration getConfiguration(){
SparkSession spark = SparkSession.getActiveSession().get();
Configuration config = spark.sparkContext().hadoopConfiguration();
return config;
}
}
class WorkbookReader implements DataReader<Row>, Serializable {
/**
* whether the first row is a header row
*/
private boolean header;
/**
* file path
*/
private String path;
/**
* excel workbook
*/
private Workbook workbook;
private SheetIterator sheetIterator;
/**
* Reads the excel file at the given path.
* @param header whether the first row is a header row
* @param path file path
* @param configuration serializable hadoop configuration
*/
public WorkbookReader(boolean header, String path,
SerializableConfiguration configuration) {
this.header = header;
this.path = path;
if (path == null || path.equals("")) {
throw new RuntimeException("path is null");
}
init(configuration.value());
}
/**
* This code does not run on the driver, so SparkSession.getActiveSession() cannot be used here.
*/
private void init(Configuration conf) {
try {
this.workbook = createWorkbook(path, conf);
this.sheetIterator = new SheetIterator(header, this.workbook.iterator());
} catch (IOException e) {
e.printStackTrace();
}
}
@Override
public boolean next() throws IOException {
return sheetIterator.hasNext();
}
@Override
public Row get() {
Object[] values = cellValuesInRow(this.sheetIterator.next());
return new GenericRow(values);
}
@Override
public void close() throws IOException {
if (this.workbook != null) {
this.workbook.close();
}
}
}
}
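One design point worth noting in the reader above: createDataReaderFactories returns exactly one DataReaderFactory per path, and in DataSource V2 each factory produces one partition of the resulting DataFrame. Parallelism therefore scales with the number of Excel files, and a single very large file will not be split across tasks.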
- SerializableOptions.java: think of this class as identical to DataSourceOptions, except that it implements Serializable. The options need to be passed further down into the readers, so the class has to be serializable; it is only a parameter wrapper and could just as well be replaced by a Map.
package self.robin.examples.spark.sources.v2.excel;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
/**
* @Description: ...
* @Author: Robin-Li
* @DateTime: 2021-02-08 22:02
*/
public class SerializableOptions implements Serializable {
private final Map<String, String> keyLowerCasedMap;
private String toLowerCase(String key) {
return key.toLowerCase(Locale.ROOT);
}
public static SerializableOptions of(DataSourceOptions options){
return new SerializableOptions(options.asMap());
}
public SerializableOptions(Map<String, String> originalMap) {
keyLowerCasedMap = new HashMap<>(originalMap.size());
for (Map.Entry<String, String> entry : originalMap.entrySet()) {
keyLowerCasedMap.put(toLowerCase(entry.getKey()), entry.getValue());
}
}
public Map<String, String> asMap() {
return new HashMap<>(keyLowerCasedMap);
}
/**
* Returns the option value to which the specified key is mapped, case-insensitively.
*/
public Optional<String> get(String key) {
return Optional.ofNullable(keyLowerCasedMap.get(toLowerCase(key)));
}
/**
* Returns the boolean value to which the specified key is mapped,
* or defaultValue if there is no mapping for the key. The key match is case-insensitive
*/
public boolean getBoolean(String key, boolean defaultValue) {
String lcaseKey = toLowerCase(key);
return keyLowerCasedMap.containsKey(lcaseKey) ?
Boolean.parseBoolean(keyLowerCasedMap.get(lcaseKey)) : defaultValue;
}
/**
* Returns the integer value to which the specified key is mapped,
* or defaultValue if there is no mapping for the key. The key match is case-insensitive
*/
public int getInt(String key, int defaultValue) {
String lcaseKey = toLowerCase(key);
return keyLowerCasedMap.containsKey(lcaseKey) ?
Integer.parseInt(keyLowerCasedMap.get(lcaseKey)) : defaultValue;
}
/**
* Returns the long value to which the specified key is mapped,
* or defaultValue if there is no mapping for the key. The key match is case-insensitive
*/
public long getLong(String key, long defaultValue) {
String lcaseKey = toLowerCase(key);
return keyLowerCasedMap.containsKey(lcaseKey) ?
Long.parseLong(keyLowerCasedMap.get(lcaseKey)) : defaultValue;
}
/**
* Returns the double value to which the specified key is mapped,
* or defaultValue if there is no mapping for the key. The key match is case-insensitive
*/
public double getDouble(String key, double defaultValue) {
String lcaseKey = toLowerCase(key);
return keyLowerCasedMap.containsKey(lcaseKey) ?
Double.parseDouble(keyLowerCasedMap.get(lcaseKey)) : defaultValue;
}
}
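A quick illustration of how the wrapper behaves (the option values here are hypothetical and not tied to the reader above): keys are lower-cased on the way in, so lookups are case-insensitive, and the typed getters fall back to their default when a key is absent.
import java.util.HashMap;
import java.util.Map;

public class SerializableOptionsDemo {
    public static void main(String[] args) {
        Map<String, String> raw = new HashMap<>();
        raw.put("Header", "true");
        raw.put("path", "file:/tmp/a.xlsx,file:/tmp/b.xlsx");

        SerializableOptions opts = new SerializableOptions(raw);
        System.out.println(opts.getBoolean("header", false)); // true: the key is matched case-insensitively
        System.out.println(opts.getInt("batchSize", 128));    // 128: missing key falls back to the default
        System.out.println(opts.get("path").get());           // the comma-separated path list
    }
}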
- SheetIterator.java: an iterator over the rows of a workbook's sheets. If an Excel file has several sheets, it automatically moves on to the next sheet once the current one is exhausted, until every row of every sheet has been visited. Convenient, right?
package self.robin.examples.spark.sources;
import org.apache.poi.ss.usermodel.Row;
import org.apache.poi.ss.usermodel.Sheet;
import java.io.Serializable;
import java.util.Iterator;
/**
* @Description: ...
* @Author: Li Yalei - Robin
* @Date: 2021/2/8 19:16
*/
public class SheetIterator implements Iterator<Row>, Serializable {
/** whether the first row is a header */
private boolean header;
private Iterator<Sheet> sheetIterator;
private Iterator<Row> rowIterator;
public SheetIterator(boolean header, Iterator<Sheet> sheetIterator){
this.header = header;
this.sheetIterator = sheetIterator;
}
@Override
public boolean hasNext() {
if(this.rowIterator==null || !this.rowIterator.hasNext()){
if(this.sheetIterator==null || !this.sheetIterator.hasNext()){
//sheetIterator is null OR sheetIterator is empty
return false;
}
this.rowIterator = this.sheetIterator.next().rowIterator();
if(header){
//the first row is the header, skip it
this.rowIterator.next();
}
}
return this.rowIterator.hasNext();
}
@Override
public Row next() {
return rowIterator.next();
}
}
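A small usage sketch outside of Spark (it assumes a local test.xlsx whose sheets each start with a header row; the file name is only an illustration):
import org.apache.poi.ss.usermodel.Row;
import org.apache.poi.ss.usermodel.Workbook;
import org.apache.poi.xssf.usermodel.XSSFWorkbook;
import self.robin.examples.spark.sources.SheetIterator;

import java.io.FileInputStream;

public class SheetIteratorDemo {
    public static void main(String[] args) throws Exception {
        try (Workbook wb = new XSSFWorkbook(new FileInputStream("test.xlsx"))) {
            // header = true: the first row of every sheet is skipped
            SheetIterator it = new SheetIterator(true, wb.iterator());
            while (it.hasNext()) {
                Row row = it.next();
                System.out.println(row.getSheet().getSheetName() + " -> row " + row.getRowNum());
            }
        }
    }
}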
- SparkWorkbookHelper.java: finally, the utility class used above.
package self.robin.examples.spark.sources;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.poi.hssf.usermodel.HSSFWorkbook;
import org.apache.poi.ss.usermodel.Cell;
import org.apache.poi.ss.usermodel.Row;
import org.apache.poi.ss.usermodel.Sheet;
import org.apache.poi.ss.usermodel.Workbook;
import org.apache.poi.xssf.usermodel.XSSFWorkbook;
import org.apache.spark.sql.execution.datasources.CodecStreams;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
/**
* @Description: ...
* @Author: Li Yalei - Robin
* @Date: 2021/2/20 17:59
*/
public interface SparkWorkbookHelper {
/**
* Get the column headers of a sheet.
* @param sheet sheet object
* @param startRow row to start from (1-based, >= 1)
* @param startCol column to start from (1-based, >= 1)
* @param header whether the start row is a header row
* @return list of column headers
*/
static List<String> getColumnNames(Sheet sheet, int startRow, int startCol, boolean header){
if(sheet==null){
return new ArrayList<>();
}
Iterator<Row> rowIte = sheet.rowIterator();
int rowIndex = startRow;
while (--rowIndex>0 && rowIte.hasNext()) {
rowIte.next();
}
if(!rowIte.hasNext()){
return new ArrayList<>();
}
Row row = rowIte.next();
int colIndex = startCol;
Iterator<Cell> colIte = row.iterator();
while (--colIndex>0 && colIte.hasNext()){
colIte.next();
}
List<String> cols = new ArrayList<>();
while (colIte.hasNext()){
if(header){
cols.add(colIte.next().getStringCellValue());
} else {
cols.add("col_"+(colIte.next().getColumnIndex()+1));
}
}
return cols;
}
/**
* Create a Workbook from the given path and hadoop Configuration.
* @param path file path
* @param conf hadoop configuration
* @return workbook
*/
static Workbook createWorkbook(String path, Configuration conf) throws IOException {
InputStream inputStream = CodecStreams.createInputStreamWithCloseResource(conf, new Path(path));
try(InputStream is = inputStream){
if (path.endsWith(".xls")) {
return new HSSFWorkbook(is);
} else if (path.endsWith(".xlsx")) {
return new XSSFWorkbook(is);
} else {
throw new IOException("File format is not supported");
}
}
}
/**
* Extract the cell values from a row.
* @param row excel row
* @return array of cell values
*/
static Object[] cellValuesInRow(org.apache.poi.ss.usermodel.Row row) {
Iterator<Cell> cellIte = row.cellIterator();
List<Object> cellBuffer = new ArrayList<>();
Cell cell;
while (cellIte.hasNext()) {
cell = cellIte.next();
switch (cell.getCellTypeEnum()) {
case NUMERIC:
cellBuffer.add(cell.getNumericCellValue());
break;
case BOOLEAN:
cellBuffer.add(cell.getBooleanCellValue());
break;
case STRING:
cellBuffer.add(cell.getStringCellValue());
break;
case BLANK:
cellBuffer.add(null);
break;
default:
throw new RuntimeException("unSupport cell type " + cell.getCellTypeEnum());
}
}
return cellBuffer.toArray();
}
}
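For completeness, the helper can also be exercised on its own; the path below is hypothetical, and a plain new Configuration() is enough for local file: URIs.
import org.apache.hadoop.conf.Configuration;
import org.apache.poi.ss.usermodel.Workbook;
import self.robin.examples.spark.sources.SparkWorkbookHelper;

import java.util.List;

public class WorkbookHelperDemo {
    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        try (Workbook wb = SparkWorkbookHelper.createWorkbook("file:/tmp/test.xlsx", conf)) {
            // read the header row of the first sheet, starting at row 1 / column 1
            List<String> cols = SparkWorkbookHelper.getColumnNames(wb.getSheetAt(0), 1, 1, true);
            System.out.println(cols);
        }
    }
}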
- Running it
@Test
public void test(){
String path = "file:/C:/Users/liyalei/Downloads/test.xlsx";
SparkSession spark = SparkSession.builder().master("local[2]").appName("local test").getOrCreate();
String dataSource = ExcelDataSourceV2.class.getName();
Map<String, String> schemaMap = new HashMap<>();
schemaMap.put("a", "String");
schemaMap.put("b", "String");
schemaMap.put("c", "String");
schemaMap.put("d", "String");
Dataset<Row> rows = spark.read().format(dataSource)
//optional: specify the schema
// .option("schema", new Gson().toJson(schemaMap))
//whether the first row is a header (defaults to false)
.option("header", true)
//file path; separate multiple paths with commas
.load(path);
rows.show();
}
One last note: the implementation above reads Excel with the POI library, but POI runs out of memory on even moderately large Excel files, so I recommend replacing it with Alibaba's EasyExcel library. I will post an EasyExcel version when I get a chance; for now, readers can adapt the code themselves.
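As a starting point, here is a rough, untested sketch of what an EasyExcel-based replacement for the POI workbook reading could look like. It uses EasyExcel's documented no-model read, where each row arrives in the listener as a Map of column index to cell string; buffering the rows into a list is a simplification chosen for brevity here, not the final design:
import com.alibaba.excel.EasyExcel;
import com.alibaba.excel.context.AnalysisContext;
import com.alibaba.excel.event.AnalysisEventListener;

import java.io.InputStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

public class EasyExcelRowReader {

    /**
     * Stream the first sheet with EasyExcel's SAX-style reader instead of building a full
     * POI Workbook in memory. A real DataReader would hand rows over through a bounded
     * queue rather than collecting them all into a list.
     */
    public static Iterator<Object[]> readFirstSheet(InputStream in, boolean header) {
        List<Object[]> buffer = new ArrayList<>();
        EasyExcel.read(in)
                .registerReadListener(new AnalysisEventListener<Map<Integer, String>>() {
                    @Override
                    public void invoke(Map<Integer, String> row, AnalysisContext context) {
                        buffer.add(row.values().toArray());
                    }

                    @Override
                    public void doAfterAllAnalysed(AnalysisContext context) {
                        // nothing to clean up in this sketch
                    }
                })
                // when the first row is a header, tell EasyExcel to treat it as one and skip it
                .headRowNumber(header ? 1 : 0)
                .sheet()
                .doRead();
        return buffer.iterator();
    }
}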
To wrap up, here is a rough DataSource V1 implementation example; fill in the details yourself as needed. Take it if you need it.
package self.robin.examples.spark.sources.excel;
import lombok.val;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapreduce.Job;
import org.apache.poi.ss.usermodel.Cell;
import org.apache.poi.ss.usermodel.Row;
import org.apache.poi.ss.usermodel.Sheet;
import org.apache.poi.ss.usermodel.Workbook;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.datasources.FileFormat;
import org.apache.spark.sql.execution.datasources.FileFormat$class;
import org.apache.spark.sql.execution.datasources.OutputWriterFactory;
import org.apache.spark.sql.execution.datasources.PartitionedFile;
import org.apache.spark.sql.internal.SQLConf;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.util.SerializableConfiguration;
import scala.Function1;
import scala.Option;
import scala.collection.Iterator;
import scala.collection.JavaConversions;
import scala.collection.Seq;
import scala.collection.immutable.Map;
import scala.collection.mutable.ListBuffer;
import scala.runtime.AbstractFunction1;
import self.robin.examples.spark.sources.SparkWorkbookHelper;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
/**
* @Description: ...
* @Author: Li Yalei - Robin
* @Date: 2020/12/22 21:14
*/
public class ExcelDataSource implements FileFormat, Serializable {
@Override
public Option<StructType> inferSchema(SparkSession sparkSession, Map<String, String> options, Seq<FileStatus> files) {
ExcelOptions xlsxOptions = new ExcelOptions(options);
//TODO: schema inference is not implemented in any detail here
StructType structType = new StructType().add("aa", DataTypes.StringType.typeName())
.add("bb", DataTypes.StringType.typeName());
return Option.apply(structType);
}
@Override
public boolean supportBatch(SparkSession sparkSession, StructType dataSchema) {
return false;
}
@Override
public OutputWriterFactory prepareWrite(SparkSession sparkSession, Job job, Map<String, String> options, StructType dataSchema) {
throw new RuntimeException("unImplement OutputWriterFactory");
}
@Override
public Option<Seq<String>> vectorTypes(StructType requiredSchema, StructType partitionSchema, SQLConf sqlConf) {
throw new RuntimeException("unImplement vectorTypes");
}
@Override
public Function1<PartitionedFile, Iterator<InternalRow>> buildReaderWithPartitionValues(SparkSession sparkSession, StructType dataSchema,
StructType partitionSchema,
StructType requiredSchema, Seq<Filter> filters,
Map<String, String> options,
Configuration hadoopConf) {
return FileFormat$class.buildReaderWithPartitionValues(this, sparkSession, dataSchema, partitionSchema, requiredSchema, filters, options, hadoopConf);
}
@Override
public boolean isSplitable(SparkSession sparkSession, Map<String, String> options, Path path) {
return false;
}
@Override
public Function1<PartitionedFile, Iterator<InternalRow>> buildReader(SparkSession sparkSession,
StructType dataSchema,
StructType partitionSchema,
StructType requiredSchema,
Seq<Filter> filters,
Map<String, String> options,
Configuration hadoopConf) {
//TODO verify schema
val xlsxOptions = new ExcelOptions(options);
Broadcast<SerializableConfiguration> broadcastedHadoopConf = JavaSparkContext.fromSparkContext(sparkSession.sparkContext())
.broadcast(new SerializableConfiguration(hadoopConf));
return new InternalFunction1(requiredSchema, broadcastedHadoopConf, xlsxOptions);
}
/**
* This inner class exists only so that the function can be serialized.
*/
class InternalFunction1 extends AbstractFunction1<PartitionedFile, Iterator<InternalRow>>
implements Serializable {
private StructType requiredSchema;
private Broadcast<SerializableConfiguration> hadoopConf;
private ExcelOptions xlsxOptions;
public InternalFunction1(StructType requiredSchema, Broadcast<SerializableConfiguration> hadoopConf, ExcelOptions xlsxOptions) {
this.requiredSchema = requiredSchema;
this.hadoopConf = hadoopConf;
this.xlsxOptions = xlsxOptions;
}
@Override
public Iterator<InternalRow> apply(PartitionedFile file) {
Configuration config = hadoopConf.getValue().value();
try(Workbook wb = SparkWorkbookHelper.createWorkbook(file.filePath(), config)) {
return readFile(xlsxOptions, config, wb, requiredSchema);
} catch (IOException e) {
e.printStackTrace();
return null;
}
}
}
/**
* Read every sheet of the workbook into InternalRows.
*
* @param options excel options
* @param hadoopConf hadoop configuration
* @param workbook excel workbook
* @param requiredSchema required schema
* @return iterator over the rows read
*/
public Iterator<InternalRow> readFile(ExcelOptions options, Configuration hadoopConf, Workbook workbook, StructType requiredSchema) {
ListBuffer<InternalRow> rowListBuffer = new ListBuffer();
int sheetNbr = workbook.getNumberOfSheets();
for (int i = 0; i < sheetNbr; i++) {
Sheet sheet = workbook.getSheetAt(i);
java.util.Iterator<Row> rowIte = sheet.rowIterator();
Row row;
while (rowIte.hasNext()) {
row = rowIte.next();
java.util.Iterator<Cell> cellIte = row.cellIterator();
List<Object> cellBuffer = new ArrayList<>();
Cell cell;
while (cellIte.hasNext()) {
cell = cellIte.next();
switch (cell.getCellTypeEnum()) {
case NUMERIC:
cellBuffer.add(cell.getNumericCellValue());
break;
case BOOLEAN:
cellBuffer.add(cell.getBooleanCellValue());
break;
case STRING:
cellBuffer.add(UTF8String.fromString(cell.getStringCellValue()));
break;
case BLANK:
cellBuffer.add(null);
break;
default:
throw new RuntimeException("unSupport cell type");
}
}
InternalRow internalRow = InternalRow.fromSeq(JavaConversions.asScalaBuffer(cellBuffer).toSeq());
rowListBuffer.$plus$eq(internalRow);
}
}
return rowListBuffer.iterator();
}
}
For the complete project, see: https://github.com/Lahar-bigdata/fast-examples/tree/main/spark/src/main/java/self/robin/examples/spark/sources/v2/excel