本文可能是全网最全面最详细的 JDBC 驱动开发教程,转载请注明出处
要自己实现一个 JDBC 驱动无疑是较为困难的,在此之前我查阅了很多资料,也查到了许多正在提问的帖子,最后都是无疾而终,提问者没有得到想要的回答。而我由于刚好需要做这方面的开发,也只能硬啃 JDBC 驱动的文档了,写这篇算是一个资料的整理,同时我也已经完成了一个通用的 JDBC 生成工具,文末可以下载。
首先还是来读文档,JDBC 的文档可谓是非常详细了,直接翻到 6.3 JDBC 4.0 API Compliance
,它很明确的告诉了你应该做什么:
1. 支持自动加载的,继承自 java.sql.Driver 的类
2. 支持“仅向前”的结果集
3. 支持“只读的”结果集并发类型
4. 支持批量更新
5. 实现以下接口
1) java.sql.Driver
2) java.sql.DatabaseMetaData
3) java.sql.ParameterMetaData
4) java.sql.ResultSetMetaData
5) java.sql.Wrapper
6) javax.sql.DataSource
7) java.sql.Connection
8) java.sql.Statement
9) java.sql.PreparedStatement
10) java.sql.CallableStatement
11) java.sql.ResultSet
好的,看到这里我们会遇到官方文档挖的第一个坑,其实并非这么多东西都要实现的,实际开发中,只实现以下内容也可以:
1) java.sql.Driver
2) java.sql.DatabaseMetaData
3) java.sql.ResultSetMetaData
4) java.sql.Connection
5) java.sql.Statement
6) java.sql.PreparedStatement
7) java.sql.ResultSet
只要有这些东西,就足以支撑起一个完整的 JDBC 驱动。
那么接下来就是实现了,不得不说,整个 JDBC 的协议真的又臭又长,所有的类加起来一共要实现 584 个接口函数,而且里面有一大半是不起任何作用的。
限于篇幅,在这里我不可能把所有的函数予以列出,只看最重要的那些,其他的请大家自行发挥。
那先来看 Driver
:
class MyDriver: Driver {
companion object {
init {
try {
DriverManager.registerDriver(MyDriver())
} catch (e: Exception) {
throw RuntimeException("Can't register $DRIVER_NAME", e)
}
}
}
override fun acceptsURL(url: String) = url.toLowerCase().startsWith(JDBC_URL)
override fun connect(url: String, info: Properties?): Connection {
if (!acceptsURL(url)) throw SQLException("Invalid URL: $url")
val props = MyDriverUtil.parseMergeProperties(url.replace("jdbc:", ""), info)
return MyConnection(props)
}
... ...
}
使用 companion object
的 init
块来完成驱动的自动加载,需要特别注意的是,千万不要不带 companion object
,如果不带的话,init
块实质上是类的构造函数,而不是静态初始化块。
acceptsURL
方法指出了什么样的 URL 可以被驱动接受,比如说我们经常在使用 mysql 驱动时看到其 JDBC URL 为 jdbc:mysql://
,就是在 acceptsURL
里接受了这样的前缀。
connect
方法用于获取一个连接,它是在 DriverManager.getConnection
时自动被调用的,返回一个非空的 Connection
对象。
这里有个函数,MyDriverUtil.parseMergeProperties
用于将 url 的参数和 info: Properties
所携带的参数进行拼接,该函数实现如下:
fun parseMergeProperties(url: String, prop: Properties?) = mutableMapOf<String, String>().apply {
val uri = URI(url)
this[PROP_HOST] = uri.host
this[PROP_PORT] = (if (uri.port == -1) DEFAULT_PORT else uri.port).toString()
this[PROP_PATH] = uri.path.replaceFirst("/", "")
if (uri.query != null) {
this += uri.query.split("&").map { p -> p.split("=").let { i -> Pair(i[0], i[1]) } }.toMap()
}
if (prop != null) {
this += prop.map { e -> Pair(e.key.toString(), e.value.toString()) }.toMap()
}
}.toMap()
在这段代码中出现了 MyConnection
这个对象,我们下面就来看看如何实现它。
实现 Connection
:
class MyConnection(props: Map<String, String>) : Connection {
val io = MyIO(props)
private var isClosed = false
private val autoCommit = true
override fun prepareStatement(sql: String) = MyPreparedStatement(sql, this)
override fun prepareStatement(sql: String, resultSetType: Int, resultSetConcurrency: Int) = MyPreparedStatement(sql, this)
override fun prepareStatement(sql: String, resultSetType: Int, resultSetConcurrency: Int, resultSetHoldability: Int) = MyPreparedStatement(sql, this)
override fun prepareStatement(sql: String, autoGeneratedKeys: Int) = MyPreparedStatement(sql, this)
override fun prepareStatement(sql: String, columnIndexes: IntArray?) = MyPreparedStatement(sql, this)
override fun prepareStatement(sql: String, columnNames: Array<out String>?) = MyPreparedStatement(sql, this)
override fun getAutoCommit() = autoCommit
override fun getWarnings(): SQLWarning? = null
override fun getCatalog(): String? {
checkConnection()
return null
}
override fun isValid(timeout: Int) = isClosed
override fun close() {
isClosed = true
}
override fun isClosed() = isClosed
override fun isReadOnly() = false
override fun createStatement() = MyStatement(this)
override fun createStatement(resultSetType: Int, resultSetConcurrency: Int) = MyStatement(this)
override fun createStatement(resultSetType: Int, resultSetConcurrency: Int, resultSetHoldability: Int) = MyStatement(this)
override fun getMetaData() = MyDatabaseMetaData()
override fun getTransactionIsolation() = Connection.TRANSACTION_NONE
private fun checkConnection() {
if (isClosed()) throw SQLException("Connection is closed")
}
... ...
}
这里需要实现的东西就比较多了,需要实现所有返回 Statement
或 PreparedStatement
的函数,也需要实现 isClosed
和 close
这类改变或检查状态的函数。
这里有一个 MyIO
的对象,是一个自定义的类,用于完成真正的数据获取工作,在一会的代码中就会用到它了。
在实现 Statement 之前,要先做一点准备工作,有一些公用的东西需要被抽象出来。
abstract class MyAbsStatement {
internal var isClosed = false
internal var connection: MyConnection
internal lateinit var resultSet: ResultSet
protected lateinit var sql: String
constructor(sql: String, conn: MyConnection) {
this.sql = sql
this.connection = conn
}
constructor(conn: MyConnection) {
this.connection = conn
}
open fun executeForResultSet(sql: String): Boolean {
if (isClosed) throw SQLException("This statement is closed.")
try {
resultSet = connection.io.internalExecuteQuery(sql)
return true
} catch (th: Throwable) {
throw SQLException(th)
}
}
open fun executeForResult(sql: String): Int {
if (isClosed) throw SQLException("This statement is closed.")
try {
return connection.io.internalExecuteUpdate(sql)
} catch (th: Throwable) {
throw SQLException(th)
}
}
}
有了这个类之后,我们可以继承它,并且实现 Statement
接口:
class MyStatement(conn: MyConnection) : MyAbsStatement(conn), Statement {
private val batchOps = mutableListOf<String>()
override fun execute(sql: String, autoGeneratedKeys: Int) = execute(sql)
override fun execute(sql: String, columnIndexes: IntArray?) = execute(sql)
override fun execute(sql: String, columnNames: Array<out String>?) = execute(sql)
override fun clearBatch() {
batchOps.clear()
}
override fun getResultSetType() = ResultSet.TYPE_FORWARD_ONLY
override fun isCloseOnCompletion() = false
override fun <T : Any> unwrap(iface: Class<T>): T? = null
override fun getMaxRows() = 0
override fun getConnection() = this.connection
override fun getWarnings(): SQLWarning? = null
override fun executeQuery(sql: String): ResultSet {
this.execute(sql)
return this.getResultSet()
}
override fun close() {
isClosed = true
}
override fun isClosed() = this.isClosed
override fun getMaxFieldSize() = 0
override fun isWrapperFor(iface: Class<*>) = false
override fun getUpdateCount() = -1
override fun getFetchSize() = 0
override fun executeBatch() = IntArray(batchOps.size).apply {
this@MyStatement.batchOps.forEachIndexed { index, sql ->
try {
this@MyStatement.execute(sql)
this[index] = SUCCESS_NO_INFO
} catch (th: Throwable) {
throw BatchUpdateException(th)
}
}
}
override fun getQueryTimeout() = 0
override fun isPoolable() = false
override fun addBatch(sql: String) {
batchOps.add(sql)
}
override fun getGeneratedKeys(): ResultSet? = null
override fun getResultSetConcurrency() = ResultSet.CONCUR_READ_ONLY
override fun getResultSet() = this.resultSet
override fun execute(sql: String) = executeForResultSet(sql)
override fun executeUpdate(sql: String) = executeForResult(sql)
override fun executeUpdate(sql: String, autoGeneratedKeys: Int) = executeUpdate(sql)
override fun executeUpdate(sql: String, columnIndexes: IntArray?) = executeUpdate(sql)
override fun executeUpdate(sql: String, columnNames: Array<out String>?) = executeUpdate(sql)
override fun getFetchDirection() = 0
override fun getResultSetHoldability() = ResultSet.CLOSE_CURSORS_AT_COMMIT
override fun getMoreResults() = false
override fun getMoreResults(current: Int) = false
... ...
}
这里主要实现 execute
相关的方法,这个时候定义在抽象类里的 executeForResult
和 executeForResultSet
就有了用武之地,它们可以将所有的请求一并接管起来。
同样需要注意的,是在 JDBC 文档内所述的,必须支持批量更新,在 Statement
内需要予以支持。
与 Statement
类似的,下面来实现 PreparedStatement
,与 Statement
不同的地方在于,PreparedStatement
需要用户自己处理替换问号占位符的操作。
先给出这个操作的代码:
private fun replaceSQL() {
var idx = 1
while (sql.indexOf("?") > 1) {
try {
val p = parameters[idx]
sql = sql.replaceFirst("?", if (p == null) "null" else "'$p'")
} catch (e: IndexOutOfBoundsException) {
throw SQLException("Can't find defined parameter for position: $idx")
}
idx++
}
}
然后来看看实现 PreparedStatement
需要做些什么:
class MyPreparedStatement(sql: String, conn: MyConnection) : MyAbsStatement(sql, conn), PreparedStatement {
private val parameters = mutableMapOf<Int, String?>()
override fun execute(): Boolean {
replaceSQL()
return super.executeForResultSet(sql)
}
override fun execute(sql: String): Boolean {
this.sql = sql
return this.execute()
}
override fun execute(sql: String, autoGeneratedKeys: Int) = execute(sql)
override fun execute(sql: String, columnIndexes: IntArray?) = execute(sql)
override fun execute(sql: String, columnNames: Array<out String>?) = execute(sql)
override fun getResultSetType() = ResultSet.TYPE_FORWARD_ONLY
override fun clearParameters() {
parameters.clear()
}
override fun getConnection() = this.connection
override fun getWarnings(): SQLWarning? = null
override fun getParameterMetaData(): ParameterMetaData? = null
override fun executeQuery(): ResultSet {
this.execute()
return this.resultSet
}
override fun executeQuery(sql: String): ResultSet {
execute(sql)
return this.resultSet
}
override fun executeUpdate(): Int {
replaceSQL()
return executeForResult(sql)
}
override fun executeUpdate(sql: String): Int {
replaceSQL()
return executeForResult(sql)
}
override fun executeUpdate(sql: String, autoGeneratedKeys: Int) = executeUpdate(sql)
override fun executeUpdate(sql: String, columnIndexes: IntArray?) = executeUpdate(sql)
override fun executeUpdate(sql: String, columnNames: Array<out String>?) = executeUpdate(sql)
override fun close() {
isClosed = true
}
override fun isCloseOnCompletion() = false
override fun getMaxRows() = 0
override fun isClosed() = isClosed
override fun getMaxFieldSize() = 0
override fun getUpdateCount() = 0
override fun getFetchSize() = 0
override fun executeBatch(): IntArray? = null
override fun getQueryTimeout() = 0
override fun isPoolable() = false
override fun getGeneratedKeys(): ResultSet? = null
override fun getResultSetConcurrency() = ResultSet.CONCUR_READ_ONLY
override fun getResultSet() = this.resultSet
override fun getMetaData() = MyResultSetMetaData()
override fun getFetchDirection() = ResultSet.FETCH_FORWARD
override fun getResultSetHoldability() = ResultSet.CLOSE_CURSORS_AT_COMMIT
override fun getMoreResults() = false
override fun getMoreResults(current: Int) = false
override fun setFloat(parameterIndex: Int, x: Float) {
pushIntoParameters(parameterIndex, x.toString())
}
override fun setLong(parameterIndex: Int, x: Long) {
pushIntoParameters(parameterIndex, x.toString())
}
override fun setDouble(parameterIndex: Int, x: Double) {
pushIntoParameters(parameterIndex, x.toString())
}
override fun setInt(parameterIndex: Int, x: Int) {
pushIntoParameters(parameterIndex, x.toString())
}
override fun setString(parameterIndex: Int, x: String?) {
pushIntoParameters(parameterIndex, x)
}
override fun setTimestamp(parameterIndex: Int, x: Timestamp?) {
pushIntoParameters(parameterIndex, x.toString())
}
private fun pushIntoParameters(index: Int, value: String?) {
if (index <= 0) throw SQLException("Invalid position for parameter ($index)")
this.parameters[index] = value
}
... ...
}
可以清楚的看到,在这里主要是用 Map 来保存需要替换的值,然后在执行的时候将真实的参数替换进问号中。然后对于执行 SQL 的地方,依然调用了抽象类里的 executeForResult
和 executeForResultSet
方法。
好了,现在我们已经完成了 Statement
和 PreparedStatement
,你可能要问了,能不能跑起代码看看效果呀?答案是否定的,因为还没有做好完整的准备,我们还需要一些东西,下面这个也很关键,是 ResultSet
。
其实这也是 JDBC 坑的一个地方,通常情况下我们可能会希望写一点代码就运行起来看看效果,但是写 JDBC 驱动时偏偏不能,还是要先完整实现才可以。
一个标准的 ResultSet
实现如下:
class MyResultSet : ResultSet {
private var isClosed = false
private var position = -1
private lateinit var fields: List<String>
private lateinit var result: List<List<String>>
constructor(jsonString: String) {
MyResultSetUtil.jsonToResultData(jsonString) { f, l ->
fields = f
result = l
}
}
constructor(fields: List<String>, list: List<List<String>>) {
this.fields = fields
this.result = list
}
override fun findColumn(columnLabel: String) = fields.indexOf(columnLabel)
override fun getStatement(): Statement? = null
override fun getWarnings(): SQLWarning? = null
override fun beforeFirst() {
checkIfClosed()
position = -1
}
override fun close() {
isClosed = true
}
override fun isFirst(): Boolean {
checkIfClosed()
return position == 0
}
override fun isLast(): Boolean {
checkIfClosed()
return position == result.size - 1
}
override fun last(): Boolean {
position = result.size - 1
return result.isNotEmpty()
}
override fun isAfterLast(): Boolean {
checkIfClosed()
return position >= result.size
}
override fun relative(rows: Int): Boolean {
checkIfClosed()
return if (rows + position in 1 until result.size) {
position += rows
true
} else {
false
}
}
override fun absolute(row: Int): Boolean {
checkIfClosed()
return if (row in 1 until result.size) {
position = row
true
} else {
false
}
}
override fun next(): Boolean {
checkIfClosed()
return if (position < result.size - 1) {
position++
true
} else {
false
}
}
override fun first(): Boolean {
checkIfClosed()
position = 0
return result.isNotEmpty()
}
override fun afterLast() {
checkIfClosed()
position = result.size
}
override fun previous(): Boolean {
checkIfClosed()
return if (position > 1) {
position--
true
} else {
false
}
}
override fun isBeforeFirst(): Boolean {
checkIfClosed()
return position < 0
}
override fun getFloat(columnIndex: Int) = result[position][columnIndex].toFloat()
override fun getFloat(columnLabel: String) = result[position][findColumn(columnLabel)].toFloat()
override fun wasNull() = false
override fun getRow() = position + 1
override fun getType() = ResultSet.TYPE_SCROLL_INSENSITIVE
override fun getString(columnIndex: Int) = result[position][columnIndex]
override fun getString(columnLabel: String) = result[position][findColumn(columnLabel)]
override fun getLong(columnIndex: Int) = result[position][columnIndex].toLong()
override fun getLong(columnLabel: String) = result[position][findColumn(columnLabel)].toLong()
override fun getTimestamp(columnIndex: Int): Timestamp = Timestamp.valueOf(result[position][columnIndex])
override fun getTimestamp(columnLabel: String): Timestamp = Timestamp.valueOf(result[position][findColumn(columnLabel)])
override fun getDouble(columnIndex: Int) = result[position][columnIndex].toDouble()
override fun getDouble(columnLabel: String) = result[position][findColumn(columnLabel)].toDouble()
override fun getInt(columnIndex: Int) = result[position][columnIndex].toInt()
override fun getInt(columnLabel: String) = result[position][findColumn(columnLabel)].toInt()
override fun isClosed() = isClosed
override fun getFetchSize() = result.size
override fun getConcurrency() = ResultSet.CONCUR_READ_ONLY
override fun clearWarnings() {
checkIfClosed()
}
override fun getFetchDirection() = ResultSet.TYPE_SCROLL_INSENSITIVE
private fun checkIfClosed() {
if (isClosed()) throw SQLException()
}
... ...
}
这个看起来就有点复杂了,主要是对游标的移动和获取值的操作,同样的,这里也有一个自定义的函数 MyResultSetUtil.jsonToResultData
,用于将 json 字符串转换为二维数组。这也就意味着我们在这里已经决定了数据的传递方式,以是 json 作为媒介的。
转换函数的实现如下:
fun jsonToResultData(jsonString: String, callback:(fields: List<String>, data: List<List<String>>) -> Unit) {
val fields = getFields(jsonString)
val data = mutableListOf<List<String>>()
JSONArray(jsonString).forEach { _, obj -> data.add(fields.map { obj.get(it).toString() }) }
callback(fields, data)
}
private fun getFields(jsonString: String) = try {
JSONArray(jsonString).run { if (length() > 0) getJSONObject(0).keySet().toList() else listOf() }
} catch (th: Throwable) {
throw SQLException("Cannot get result fields.")
}
最后是补全驱动所需的另两个文件,分别是 DatabaseMetaData
和 ResultSetMetaData
其实这两个 MetaData 都可以什么都不填,因为基本上用不到,只是 JDBC 标准说必须实现,所以才予以实现,通常的处理方法是将其中所有的方法都标记为“不支持”:
throw SQLFeatureNotSupportedException()
像这样就可以了。
好了,是不是现在就想跑起代码来看看效果?我们还有最后一步,还记得上面提到的 IO
对象不,现在来实现这个对象,以完成对数据的请求。当然了,在这里我们使用的是写死的假数据:
object MyTestRequset {
var LOCAL_TEST = false
private val SAMPLEDATA = """[{"id":1, "name":"test1", "age":10},{"id":2, "name":"test2", "age":20},{"id":3, "name":"test3", "age":30},{"id":4, "name":"test4", "age":40},{"id":5, "name":"test5", "age":50}]"""
@TestOnly
fun localTestInternalRequest(sql: String) = if (sql.contains("select ")) SAMPLEDATA else "1"
}
class MyIO(private val prop: Map<String, String>) {
fun internalExecuteQuery(sql: String) = try {
MyResultSet(internalRequest(sql))
} catch (th: Throwable) {
println("internalExecuteQuery error: $th")
null
} ?: throw SQLException("cannot parse ResultSet")
fun internalExecuteUpdate(sql: String) = try {
internalRequest(sql).toInt()
} catch (th: Throwable) {
println("internalExecuteUpdate error: $th")
-1
}
private fun internalRequest(sql: String): String {
if (MyTestRequset.LOCAL_TEST) return MyTestRequset.localTestInternalRequest(sql)
TODO("获取数据的真实代码写在此处")
}
}
好了,现在我们的代码已经完整了,可以运行看看效果,在此写一个 Testcase 来跑一下:
class Test {
@Test
fun doTest() {
MyTestRequset.LOCAL_TEST = true
Class.forName("com.sample.MyDriver")
DriverManager.getConnection("jdbc:myurl://0.0.0.0/sampledb", Properties().apply { setProperty(PROP_SCHEMA, "http") }).use { conn ->
conn.prepareStatement("select * from Data").use { stmt ->
stmt.executeQuery().use { result ->
while (result.next()) {
println(result.getString("name"))
}
}
}
conn.prepareStatement("insert into Data(name) values (?)").use { stmt ->
stmt.setString(1, "23333")
println(stmt.executeUpdate())
}
}
}
}
能顺利跑通就说明我们的驱动已经正常工作了。同样的,符合 JDBC 标准的驱动也可以被 myBatis 等框架加载并使用。
好了,下面是大招,还记得上面的 MyIO
里有一个 TODO
吗?我们完全可以把对数据库的请求代理掉,让它成为一个远程的数据请求,代码如下:
private fun internalRequest(sql: String): String {
var ret: String? = null
http {
url = "${if (prop.containsKey(PROP_SCHEMA)) prop[PROP_SCHEMA] else "http"}://${prop[PROP_HOST]}:${prop[PROP_PORT]}/${prop[PROP_PATH]}"
method = HttpMethod.POST
if (prop.containsKey(PROP_USER)) authenticatorUser = prop[PROP_USER]
if (prop.containsKey(PROP_PASSWORD)) authenticatorPassword = prop[PROP_PASSWORD]
postParam = mutableMapOf("sql" to sql)
onSuccess { code, text, _ ->
if (code != 200) throw SQLException("Remote execute SQL failed: $code")
ret = text
}
}
return ret ?: throw SQLException("Remote SQL result is null.")
}
同时,只需要使用 Ktor 写几行代码,跑起服务器,这一切都顺理成章了(还不会 Ktor 的小伙伴可以看我的 Ktor 从入门到放弃 系列)。
服务端代码:
fun Routing.ISCRouting() {
post("/sampledb") {
val sql = call.requestParameters()["sql"] ?: ""
call.respondText { doRequestDb(sql) }
}
}
在 doRequestDb
的过程中,就可以做各种骚操作了,如分库分表,权限控制等,在此就不赘述了,大家可以发挥自己的想象力。
最后,最上面提到的那个生成 JDBC 驱动代码的工具,可以从我的 Github 下载 EasyJDBC 并编译,然后愉快的开发吧。