Spark 实现MySQL update操作

背景

目前 spark 对 MySQL 的操作只有 Append,Overwrite,ErrorIfExists,Ignore几种表级别的模式,有时我们需要对表进行行级别的操作,比如update。即我们需要构造这样的语句出来:insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name =? ,age=?;

需求:我们的目的是既不影响以前写的代码,又不引入新的API,只需新加一个配置如:savemode=update这样的形式来实现。

实践

要满足以上需求,肯定是要改源码的,首先创建自己的saveMode,只是新加了一个Update而已:

public enum I4SaveMode {
    Append,
    Overwrite,
    ErrorIfExists,
    Ignore,
    Update
}

JDBC数据源的相关实现主要在JdbcRelationProvider里,我们需要关注的是createRelation方法,我们可以在此方法里,把SaveMode改成我们自己的mode,并把mode带到saveTable方法里,所以改造后的方法如下(改了的地方都有注释):

   override def createRelation(
                                   sqlContext: SQLContext,
                                   mode: SaveMode,
                                   parameters: Map[String, String],
                                   df: DataFrame): BaseRelation = {
        val options = new JDBCOptions(parameters)
        val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis
        // 替换成自己的saveMode
        var saveMode = mode match {
                case SaveMode.Overwrite => I4SaveMode.Overwrite
                case SaveMode.Append => I4SaveMode.Append
                case SaveMode.ErrorIfExists => I4SaveMode.ErrorIfExists
                case SaveMode.Ignore => I4SaveMode.Ignore
            }
        //重点在这里,检查是否有saveMode=update的参数,并设为对应的模式
        val parameterLower = parameters.map(kv => (kv._1.toLowerCase,kv._2))
        if(parameterLower.keySet.contains("savemode")){
            saveMode = if(parameterLower.get("savemode").get.equals("update")) I4SaveMode.Update else saveMode
        }
        val conn = JdbcUtils.createConnectionFactory(options)()
        try {
            val tableExists = JdbcUtils.tableExists(conn, options)
            if (tableExists) {
                saveMode match {
                    case I4SaveMode.Overwrite =>
                        if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
                            // In this case, we should truncate table and then load.
                            truncateTable(conn, options.table)
                            val tableSchema = JdbcUtils.getSchemaOption(conn, options)
                            saveTable(df, tableSchema, isCaseSensitive, options, saveMode)
                        } else {
                        ......
    }

接下来就是saveTable方法:

def saveTable(
      df: DataFrame,
      tableSchema: Option[StructType],
      isCaseSensitive: Boolean,
      options: JDBCOptions,
      mode: I4SaveMode): Unit = { 
    ......
    val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
    .....
    repartitionedDF.foreachPartition(iterator => savePartition(
      getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel)
    )
  }

这里通过getInsertStatement方法构造sql语句,接着遍历每个分区进行对应的save操作,我们先看是构造语句是怎么改的(改了的地方都有注释):

def getInsertStatement(
      table: String,
      rddSchema: StructType,
      tableSchema: Option[StructType],
      isCaseSensitive: Boolean,
      dialect: JdbcDialect,
      mode: I4SaveMode): String = {
    val columns = if (tableSchema.isEmpty) {
      rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
    } else {
      val columnNameEquality = if (isCaseSensitive) {
        org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
      } else {
        org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
      } 
      val tableColumnNames = tableSchema.get.fieldNames
      rddSchema.fields.map { col =>
        val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
          throw new AnalysisException(s"""Column "${col.name}" not found in schema $tableSchema""")
        }
        dialect.quoteIdentifier(normalizedName)
      }.mkString(",")
    } 
    val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
    // s"INSERT INTO $table ($columns) VALUES ($placeholders)"
   //若为update模式需要单独构造
    mode match {
            case I4SaveMode.Update ⇒
                val duplicateSetting = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).map(name ⇒ s"$name=?").mkString(",")
                s"INSERT INTO $table ($columns) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $duplicateSetting"
            case _ ⇒ s"INSERT INTO $table ($columns) VALUES ($placeholders)"
        }
  }

只需判断是否是update模式来构造对应的 sql语句,接着主要是看 savePartition 方法,看看具体是怎么保存的:

 def savePartition(
      getConnection: () => Connection,
      table: String,
      iterator: Iterator[Row],
      rddSchema: StructType,
      insertStmt: String,
      batchSize: Int,
      dialect: JdbcDialect,
      isolationLevel: Int): Iterator[Byte] = {
    val conn = getConnection()
    var committed = false

    var finalIsolationLevel = Connection.TRANSACTION_NONE
    if (isolationLevel != Connection.TRANSACTION_NONE) {
      try {
        val metadata = conn.getMetaData
        if (metadata.supportsTransactions()) {
          // Update to at least use the default isolation, if any transaction level
          // has been chosen and transactions are supported
          val defaultIsolation = metadata.getDefaultTransactionIsolation
          finalIsolationLevel = defaultIsolation
          if (metadata.supportsTransactionIsolationLevel(isolationLevel))  {
            // Finally update to actually requested level if possible
            finalIsolationLevel = isolationLevel
          } else {
            logWarning(s"Requested isolation level $isolationLevel is not supported; " +
                s"falling back to default isolation level $defaultIsolation")
          }
        } else {
          logWarning(s"Requested isolation level $isolationLevel, but transactions are unsupported")
        }
      } catch {
        case NonFatal(e) => logWarning("Exception while detecting transaction support", e)
      }
    }
    val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE

    try {
      if (supportsTransactions) {
        conn.setAutoCommit(false) // Everything in the same db transaction.
        conn.setTransactionIsolation(finalIsolationLevel)
      }
      val stmt = conn.prepareStatement(insertStmt)
      val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
      val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
      val numFields = rddSchema.fields.length

      try {
        var rowCount = 0
        while (iterator.hasNext) {
          val row = iterator.next()
          var i = 0
          while (i < numFields) {
            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              setters(i).apply(stmt, row, i)
            }
            i = i + 1
          }
          stmt.addBatch()
          rowCount += 1
          if (rowCount % batchSize == 0) {
            stmt.executeBatch()
            rowCount = 0
          }
        }
        if (rowCount > 0) {
          stmt.executeBatch()
        }
      } finally {
        stmt.close()
      }
      if (supportsTransactions) {
        conn.commit()
      }
      committed = true
      Iterator.empty
    } catch {
      case e: SQLException =>
        val cause = e.getNextException
        if (cause != null && e.getCause != cause) {
          // If there is no cause already, set 'next exception' as cause. If cause is null,
          // it *may* be because no cause was set yet
          if (e.getCause == null) {
            try {
              e.initCause(cause)
            } catch {
              // Or it may be null because the cause *was* explicitly initialized, to *null*,
              // in which case this fails. There is no other way to detect it.
              // addSuppressed in this case as well.
              case _: IllegalStateException => e.addSuppressed(cause)
            }
          } else {
            e.addSuppressed(cause)
          }
        }
        throw e
    } finally {
      if (!committed) {
        // The stage must fail.  We got here through an exception path, so
        // let the exception through unless rollback() or close() want to
        // tell the user about another problem.
        if (supportsTransactions) {
          conn.rollback()
        }
        conn.close()
      } else {
        // The stage must succeed.  We cannot propagate any exception close() might throw.
        try {
          conn.close()
        } catch {
          case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
        }
      }
    }
  }

大体思想就是在迭代该分区数据进行插入之前就先根据数据的schema设置好了插入模板setters,迭代的时候只需将此模板应用到每一行数据上就行了,避免了每一行都需要去判断数据类型。
在非update的情况下:insert into tb (id,name,age) values (?,?,?)
在update情况下:insert into tb (id,name,age) values (?,?,?) on duplicate key update id=?,name =? ,age=?;
即占位符多了一倍,在update模式下进行写入的时候需要向PreparedStatement多喂一遍数据。原本的makeSetter方法如下:

private def makeSetter(
      conn: Connection,
      dialect: JdbcDialect,
      dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos))
    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos))
    ...
  }

我们只需要再加一个相对位置参数offset来控制,即改造成:

private def makeSetter(
       conn: Connection,
       dialect: JdbcDialect,
       dataType: DataType): JDBCValueSetter = dataType match {
     case IntegerType ⇒
        (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
             stmt.setInt(pos + 1, row.getInt(pos - offset))
     case LongType ⇒
        (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
             stmt.setLong(pos + 1, row.getLong(pos - offset))
    ...

在非update模式下offset就为0,在update模式下在没有超过numFields时offset为0,超过numFileds时offset为numFields。改造后的savePartition方法为:

def savePartition(
                 getConnection: () => Connection,
                 table: String,
                 iterator: Iterator[Row],
                 rddSchema: StructType,
                 insertStmt: String,
                 batchSize: Int,
                 dialect: JdbcDialect,
                 isolationLevel: Int,
                 mode: I4SaveMode): Iterator[Byte] = {
    ...
    //判断是否为update
    val isUpdateMode = mode == I4SaveMode.Update
    val stmt = conn.prepareStatement(insertStmt)
    val setters: Array[JDBCValueSetter] = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
    val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
    val length = rddSchema.fields.length
    // update模式下占位符是2倍
    val numFields = if (isUpdateMode) length * 2 else length
    val midField = numFields / 2
    try {
        var rowCount = 0
        while (iterator.hasNext) {
            val row = iterator.next()
            var i = 0
            while (i < numFields) {
                if (isUpdateMode) {
                    // update模式下未超过字段长度,offset为0
                    i < midField match {
                        case true ?
                            if (row.isNullAt(i)) {
                                stmt.setNull(i + 1, nullTypes(i))
                            } else {
                                setters(i).apply(stmt, row, i, 0)
                            }
                        // update模式下超过字段长度,offset为midField,即字段长度
                        case false ?
                            if (row.isNullAt(i - midField)) {
                                stmt.setNull(i + 1, nullTypes(i - midField))
                            } else {
                                setters(i - midField).apply(stmt, row, i, midField)
                            }
                    }
                
                } else {
                    if (row.isNullAt(i)) {
                        stmt.setNull(i + 1, nullTypes(i))
                    } else {
                        setters(i).apply(stmt, row, i, 0)
                    }
                }
                i = i + 1
            }
          ...

改造好源码后,需要重新编译打包,替换掉线上对应的jar即可。其实这里有个捷径,自己创建相同的包名,改好源码后打成jar包,把该jar里面的class文件替换掉线上jar里面对应的那些class文件就可以了。

如何使用

若需要使用到update模式:

df.write.option("saveMode","update").jdbc(...)

参考

https://blog.csdn.net/cjuexuan/article/details/52333970

我的GitHub

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,922评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,591评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,546评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,467评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,553评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,580评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,588评论 3 414
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,334评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,780评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,092评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,270评论 1 344
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,925评论 5 338
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,573评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,194评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,437评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,154评论 2 366
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,127评论 2 352

推荐阅读更多精彩内容