需求
将任意Java对象RDD转换成DataFrame。
要做到这一点,主要需要如下两步:
Spark版本: 1.6.1
准备
推荐阅读spark源码org.apache.spark.sql.catalyst.ScalaReflection类,其中列举了大部分基础类型与SparkSQL类型的映射。
但我还是重新写了这部分功能,最重要的原因是源码只支持基本类型,对于复杂或嵌套Java类无能为力。
其次,我想支持更多的类型,且我想做到对某些类型的对象进行自定义转换。
比如我遇到的Java类中有个属性为Map<String, Object> parameters; 其中的泛型Object无法映射到任何SparkSQL类型中,
导致StructType无法构建完整,造成不得不放弃一部分数据。
但我的做法是,对泛型未指定或指定为Object的,直接调用toString方法转换为String,可以挽回一部分数据丢失。
还有一些常见的,比如需要将java.util.Date转换为java.sql.Date,将char[]转换为String的。
Type接口有一个子类和四个子接口,一个子类为java.lang.Class(最为大众所知),四个子接口为 GenericArrayType, ParameterizedType,
TypeVariable, WildcardType。
开发
开发时间大概两周,运行较为稳定。下面分享代码,发现问题欢迎指正。
外部调用的主要是两个方法
def getStructType(clazz: Class[_]): Option[StructType]
def getRow(clazz: Class[_], obj: Any): Option[Row]
完整代码
import java.lang.reflect.{ GenericArrayType, Modifier, ParameterizedType, Field } import java.lang.{ Iterable => JIterable } import java.util.{ Map => JMap } import scala.collection.JavaConversions._ import org.apache.spark.sql.Row import org.apache.spark.sql.types.{ DataType, StructField, StructType, DecimalType, DataTypes } import org.apache.spark.sql.types.DataTypes._ /** * @author yizhu.sun 2016年7月21日 */ object DataFrameReflectUtil { /** 成员变量的类型和sparkSQL类型的映射 */ val predefinedDataType: collection.mutable.Map[Class[_], DataType] = collection.mutable.Map( (classOf[Boolean], BooleanType), (classOf[java.lang.Boolean], BooleanType), (classOf[Byte], ByteType), (classOf[java.lang.Byte], ByteType), (classOf[Array[Byte]], BinaryType), (classOf[Array[java.lang.Byte]], BinaryType), (classOf[Short], ShortType), (classOf[java.lang.Short], ShortType), (classOf[Int], IntegerType), (classOf[java.lang.Integer], IntegerType), (classOf[Long], LongType), (classOf[java.lang.Long], LongType), (classOf[Float], FloatType), (classOf[java.lang.Float], FloatType), (classOf[Double], DoubleType), (classOf[java.lang.Double], DoubleType), (classOf[Char], StringType), (classOf[java.lang.Character], StringType), (classOf[Array[Char]], StringType), (classOf[Array[java.lang.Character]], StringType), (classOf[String], StringType), (classOf[java.math.BigDecimal], DecimalType.SYSTEM_DEFAULT), (classOf[java.util.Date], DateType), (classOf[java.sql.Date], DateType), (classOf[java.security.Timestamp], TimestampType), (classOf[java.util.Calendar], CalendarIntervalType), // 成员为Object类型的,都转为String (classOf[Any], StringType)) /** 类之间的转换。比如将java.util.Date转换为java.sql.Date */ private val classConverter: Map[Class[_], (Any) => _ <: Any] = Map( classOf[java.util.Date] -> ((o: Any) => new java.sql.Date(o.asInstanceOf[java.util.Date].getTime)), classOf[Char] -> ((o: Any) => o.asInstanceOf[Char].toString), classOf[java.lang.Character] -> ((o: Any) => o.asInstanceOf[java.lang.Character].toString), classOf[Array[Char]] -> ((o: Any) => new String(o.asInstanceOf[Array[Char]])), classOf[Array[java.lang.Character]] -> ((o: Any) => new String(o.asInstanceOf[Array[java.lang.Character]].map(_.charValue))), classOf[Any] -> ((o: Any) => o.toString)) /** cache of Class -> Option[StructType] */ private val structTypeCache = new org.apache.commons.collections.map.LRUMap(100) /** cache of java.lang.reflect.Type -> Option[DataType] */ private val dataTypeCache = new org.apache.commons.collections.map.LRUMap(1000) /** cache of Class -> Array[Field] */ private val classFieldsCache = collection.mutable.Map[Class[_], Array[Field]]() /** scala.collection.Map 类型的Class的cache */ private val scalaMapClassCache = collection.mutable.Set[Class[_]]() /** scala.collection.Iterable 类型的Class的cache */ private val scalaIterableClassCache = collection.mutable.Set[Class[_]]() /** java.util.Map 类型的Class的cache */ private val javaMapClassCache = collection.mutable.Set[Class[_]]() /** java.lang.Iterable 类型的Class的cache */ private val javaIterableClassCache = collection.mutable.Set[Class[_]]() // 注意在Scala中Map是Iterable的子类 def isScalaMapClass(clazz: Class[_]) = { if (scalaMapClassCache.contains(clazz)) true else if (classOf[Map[_, _]].isAssignableFrom(clazz)) { scalaMapClassCache += clazz true } else false } def isScalaIterableClass(clazz: Class[_]) = { if (scalaIterableClassCache.contains(clazz)) true else if (classOf[Iterable[_]].isAssignableFrom(clazz)) { scalaIterableClassCache += clazz true } else false } def isJavaMapClass(clazz: Class[_]) = { if (javaMapClassCache.contains(clazz)) true else if (classOf[JMap[_, _]].isAssignableFrom(clazz)) { javaMapClassCache += clazz true } else false } def isJavaIterableClass(clazz: Class[_]) = { if (javaIterableClassCache.contains(clazz)) true else if (classOf[JIterable[_]].isAssignableFrom(clazz)) { javaIterableClassCache += clazz true } else false } def getFields(clazz: Class[_]) = classFieldsCache.getOrElseUpdate(clazz, { val fields = clazz.getDeclaredFields .filterNot(f => Modifier.isTransient(f.getModifiers)) .flatMap(f => getDataType(f.getGenericType) match { case Some(_) => Some(f) case None => None }) fields.foreach(_.setAccessible(true)) fields }) /** * 根据Class对象,生成StructType对象。 */ def getStructType(clazz: Class[_]): Option[StructType] = { val cachedStructType = structTypeCache.get(clazz) if (cachedStructType == null) { val fields = getFields(clazz) val newStructType = if (fields.isEmpty) None else { val types = fields.map(f => { val dataType = getDataType(f.getGenericType).get StructField(f.getName, dataType, true) // 默认所有的字段都可能为空 }) if (types.isEmpty) None else Some(StructType(types)) } structTypeCache.put(clazz, newStructType) newStructType } else cachedStructType.asInstanceOf[Option[StructType]] } /** * 根据java.lang.reflect.Type获取org.apache.spark.sql.types.DataType * 递归处理嵌套类型 */ private def getDataType(tp: java.lang.reflect.Type): Option[DataType] = { val cachedDataType = dataTypeCache.get(tp) if (cachedDataType == null) { val newDataType = tp match { case ptp: ParameterizedType => // 带有泛型的数据类型,e.g. List[String] val clazz = ptp.getRawType.asInstanceOf[Class[_]] val rowTypes = ptp.getActualTypeArguments if (isScalaMapClass(clazz) || isJavaMapClass(clazz)) { (getDataType(rowTypes(0)), getDataType(rowTypes(1))) match { case (Some(keyType), Some(valueType)) => Some(DataTypes.createMapType(keyType, valueType, true)) case _ => None } } else if (isScalaIterableClass(clazz) || isJavaIterableClass(clazz)) { getDataType(rowTypes(0)) match { case Some(dataType) => Some(DataTypes.createArrayType(dataType, true)) case None => None } } else { getStructType(clazz) } case gatp: GenericArrayType => // 泛型数据类型的数组,e.g. Array[List[String]] getDataType(gatp.getGenericComponentType) match { case Some(dataType) => Some(DataTypes.createArrayType(dataType, true)) case None => None } case clazz: Class[_] => // 没有泛型的类型(包括没有指定泛型的Map和Collection) predefinedDataType.get(clazz) match { case Some(tp) => Some(tp) case None => if (clazz.isArray) { // 非泛型对象的数组 getDataType(clazz.getComponentType) match { case Some(dataType) => Some(DataTypes.createArrayType(dataType, true)) case None => None } } else if (isScalaMapClass(clazz) || isJavaMapClass(clazz)) { Some(DataTypes.createMapType(StringType, StringType, true)) } else if (isScalaIterableClass(clazz) || isJavaIterableClass(clazz)) { Some(DataTypes.createArrayType(StringType, true)) } else { // 一般Object类型,转换为嵌套类型 getStructType(clazz) } } case _ => throw new IllegalArgumentException("不支持 WildcardType 和 TypeVariable") } dataTypeCache.put(tp, newDataType) newDataType } else cachedDataType.asInstanceOf[Option[DataType]] } /** * 读取一行数据 */ def getRow(clazz: Class[_], obj: Any): Option[Row] = getStructType(clazz) match { case Some(_) => if (obj == null) Some(null) else Some(Row(getFields(clazz).flatMap(f => getCell(f.getGenericType, f.get(obj))): _*)) case None => None } /** * 读取单个数据 */ private def getCell(tp: java.lang.reflect.Type, value: Any): Option[Any] = tp match { case ptp: ParameterizedType => // 带有泛型的数据类型,e.g. List[String] val clazz = ptp.getRawType.asInstanceOf[Class[_]] val rowTypes = ptp.getActualTypeArguments if (isScalaMapClass(clazz)) { (getDataType(rowTypes(0)), getDataType(rowTypes(1))) match { case (Some(keyType), Some(valueType)) => if (value == null) Some(null) else Some(value.asInstanceOf[Map[Any, Any]].filterKeys(_ != null) .map { case (k, v) => getCell(rowTypes(0), k).get -> getCell(rowTypes(1), v).get }) case _ => None } } else if (isScalaIterableClass(clazz)) { getDataType(rowTypes(0)) match { case Some(_) => if (value == null) Some(null) else Some(value.asInstanceOf[Iterable[Any]].filter(_ != null).map(v => getCell(rowTypes(0), v).get).toSeq) case None => None } } else if (isJavaIterableClass(clazz)) { getDataType(rowTypes(0)) match { case Some(_) => if (value == null) Some(null) else Some(value.asInstanceOf[JIterable[Any]].filter(_ != null).map(v => getCell(rowTypes(0), v).get).toSeq) case None => None } } else if (isJavaMapClass(clazz)) { (getDataType(rowTypes(0)), getDataType(rowTypes(1))) match { case (Some(keyType), Some(valueType)) => if (value == null) Some(null) else Some(value.asInstanceOf[JMap[Any, Any]].filterKeys(_ != null) .map { case (k, v) => getCell(rowTypes(0), k).get -> getCell(rowTypes(1), v).get }) case _ => None } } else { getCell(clazz, value) } case gatp: GenericArrayType => // 泛型数据类型的数组,e.g. Array[List[String]] getDataType(gatp.getGenericComponentType) match { case Some(dataType) => Some(value.asInstanceOf[Array[Any]].map(v => getCell(gatp.getGenericComponentType, v).get).toSeq) case None => None } case clazz: Class[_] => // 没有泛型的类型(包括没有指定泛型的Map和Collection) predefinedDataType.get(clazz) match { case Some(_) => classConverter.get(clazz) match { case Some(converter) => Some(if (value == null) null else converter(value)) case None => Some(value) } case None => if (clazz.isArray) { // 非泛型对象的数组 getDataType(clazz.getComponentType) match { case Some(dataType) => if (value == null) Some(null) else Some(value.asInstanceOf[Array[_]].filter(_ != null).flatMap(v => getCell(clazz.getComponentType, v)).toSeq) case None => None } } else if (isScalaMapClass(clazz)) { Some(value.asInstanceOf[Map[Any, Any]].filterKeys(_ != null) .map { case (k, v) => getCell(classOf[Any], k).get -> getCell(classOf[Any], v).get }) } else if (isScalaIterableClass(clazz)) { Some(value.asInstanceOf[Iterable[Any]].filter(_ != null) .map(v => getCell(classOf[Any], v).get).toSeq) } else if (isJavaIterableClass(clazz)) { Some(value.asInstanceOf[JIterable[Any]].filter(_ != null) .map(v => getCell(classOf[Any], v).get).toSeq) } else if (isJavaMapClass(clazz)) { Some(value.asInstanceOf[JMap[Any, Any]].filterKeys(_ != null) .map { case (k, v) => getCell(classOf[Any], k).get -> getCell(classOf[Any], v).get }) } else { // 一般Object类型,转换为嵌套类型 getRow(clazz, value) } } case _ => throw new IllegalArgumentException("不支持 WildcardType 和 TypeVariable") } }
构建两个测试类
class TClass( val list1: List[Array[Char]], val map1: Map[String, Array[Int]], val obj1: TInnerClass) extends Serializable class TInnerClass( val date1: java.util.Date) extends Serializable
// sc: SparkContext // ssc: SQLContext val obj1 = new TClass( List(Array('1', '2', '3'), null), Map("123" -> Array(1, 2, 3), "nil" -> null), new TInnerClass(new java.util.Date)) val obj2 = new TClass( List(Array('1', '2', '3'), null), Map("empty" -> Array(), "90" -> Array(9, 0)), new TInnerClass(null)) val tClazz = classOf[TClass] val rdd = sc.makeRDD(Seq(obj1, obj2)) val rowRDD = rdd.flatMap(DataFrameReflectUtil.getRow(tClazz, _)) DataFrameReflectUtil.getStructType(tClazz) match { case Some(scheme) => val df = ssc.createDataFrame(rowRDD, scheme) df.registerTempTable("df") df.printSchema ssc.sql("select list1, map1, obj1 from df").show(false) ssc.sql("select map1['90'], map1['90'][0], date_add(obj1.date1, 1) from df").show(false) case None => println("getStructType failed") }
root |-- list1: array (nullable = true) | |-- element: string (containsNull = true) |-- map1: map (nullable = true) | |-- key: string | |-- value: array (valueContainsNull = true) | | |-- element: integer (containsNull = true) |-- obj1: struct (nullable = true) | |-- date1: date (nullable = true) +-----+------------------------------------------------------+------------+ |list1|map1 |obj1 | +-----+------------------------------------------------------+------------+ |[123]|Map(123 -> WrappedArray(1, 2, 3), nil -> null) |[2016-09-01]| |[123]|Map(empty -> WrappedArray(), 90 -> WrappedArray(9, 0))|[null] | +-----+------------------------------------------------------+------------+ +------+----+----------+ |_c0 |_c1 |_c2 | +------+----+----------+ |null |null|2016-09-02| |[9, 0]|9 |null | +------+----+----------+