pyspark读取Oracle数据库并根据字段进行分区

前一篇文章pyspark连接oracle中详细讲述了初步连接Oracle的方法,这种连接方式每次只使用一个RDD分区,即numPartitions默认为1.这种方式当表特别大的时候,很可能出现OOM.

pyspark提供两种对数据库进行分区读取的方式

方法一:指定数据库字段的范围

之前的方式是:

empDF = spark.read \
    .format("jdbc") \
    .option("url", "jdbc:oracle:thin:@//hostname:portnumber/SID") \
    .option("dbtable", "hr.emp") \
    .option("user", "db_user_name") \
    .option("password", "password") \
    .option("driver", "oracle.jdbc.driver.OracleDriver") \
    .load()

 现在需要增加partitionColumn, lowerBound, upperBound, numPartitions这几个属性值

empDF = spark.read \
    .format("jdbc") \
    .option("url", "jdbc:oracle:thin:@//hostname:portnumber/SID") \
    .option("dbtable", "hr.emp") \
    .option("user", "db_user_name") \
    .option("password", "password") \
    .option("driver", "oracle.jdbc.driver.OracleDriver") \
    .option("partitionColumn", partitionColumn)
    .option("lowerBound", lowerBound)
    .option("upperBound", upperBound)
    .option("numPartitions", numPartitions)
    .load()

这些属性仅适用于读数据,而且必须同时被指定。partitionName就是需要分区的字段,这个字段在数据库中的类型必须是数字;lowerBound就是分区的下界;upperBound就是分区的上界;numPartitions是分区的个数。

这个方法可以将数据库中表的数据分布到RDD的几个分区中,分区的数量由numPartitions参数决定,在理想情况下,每个分区处理相同数量的数据,我们在使用的时候不建议将这个值设置的比较大,因为这可能导致数据库挂掉!但是根据前面介绍,这个函数的缺点就是只能使用整形数据字段作为分区关键字。

方法二:根据任意字段进行分区

因为数据库中有很多时候需要对日期进行分段,所以第一种方法就不适用了。还好,spark根据不同需求提供了一个函数:

jdbc(url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None, predicates=None, properties=None)
'''
构建一个DataFrame表示通过JDBC URL url命名的table和连接属性连接的数据库表。
column参数可用于对表进行分区,然后根据传递给此函数的参数并行检索它。
predicates参数给出了一个适合包含在WHERE子句中的列表表达式; 每一个都定义了DataFrame的一个分区。
注:不要在大型集群上并行创建太多分区; 否则Spark可能会使外部数据库系统崩溃。

参数:url – 一个JDBC URL
     table – 表名称
    column – 用于分区的列
    lowerBound – 分区列的下限
    upperBound – 分区列的上限
    numPartitions – 分区的数量
    predicates – 包含在WHERE子句中的表达式列表; 每一个都定义了DataFrame的一个分区
    properties – JDBC数据库连接参数,任意字符串的标签/值的列表。For example { 'user' : 'SYSTEM', 'password' : 'mypassword' }
返回 : 一个DataFrame
'''

 在这个函数里需要设置属性predicates、properties 的值

predicates = []
datelist = {"2014-11-01": "2015-01-01",
            "2014-09-01": "2014-11-01",
            "2014-07-01": "2014-09-01",
            "2014-05-01": "2014-07-01",
            "2014-03-01": "2014-05-01",
            "2014-01-01": "2014-03-01"}
for startdate, enddate in datelist.items():
    predicates.append("STARTDATE >= to_date('" + startdate + "', 'yyyy-MM-dd'" \
        + "and STARTDATE < to_date('" + enddate + "', 'yyyy-MM-dd')")
properties = {"user": db_user_name,
              "password" : password,
              "driver": driver}
df = spark.read.jdbc(url=url, table=dbtable, predicates=predicates, properties=properties)

 最后rdd的分区数量就等于predicates.length。

有一点要注意的是,驱动是放在properties里,网上一般都是连接MySQL数据库,不像oracle数据库一样需要额外的驱动。

还有数据中STARTDATE是date类型的数据,所以需要利用to_date()做数据类型转换。

参考文献:

Spark读取数据库(Mysql)的四种方式讲解

spark jdbc(mysql) 读取并发度优化

https://github.com/UrbanInstitute/sparkr-tutorials/blob/master/08_databases-with-jdbc.md

【SparkSQL】partitionColumn, lowerBound, upperBound, numPartitions的理解

Spark JDBC To MySQL

你可能感兴趣的:(spark)