abstract class DataType extends AbstractDataType {
// Type properties
def typeName: String
def sql: String
def catalogString: String
def simpleString: String
def json: String
// Type checking
def sameType(other: DataType): Boolean
def acceptsType(other: DataType): Boolean
def existsRecursively(f: DataType => Boolean): Boolean
// Default value handling
def defaultSize: Int
def asNullable: DataType
}
// Numeric types
case object ByteType extends IntegralType {
def typeName: String = "byte"
def defaultSize: Int = 1
}
case object ShortType extends IntegralType {
def typeName: String = "short"
def defaultSize: Int = 2
}
case object IntegerType extends IntegralType {
def typeName: String = "integer"
def defaultSize: Int = 4
}
case object LongType extends IntegralType {
def typeName: String = "long"
def defaultSize: Int = 8
}
case object FloatType extends FractionalType {
def typeName: String = "float"
def defaultSize: Int = 4
}
case object DoubleType extends FractionalType {
def typeName: String = "double"
def defaultSize: Int = 8
}
case class DecimalType(precision: Int = 10, scale: Int = 0) extends FractionalType {
def typeName: String = "decimal"
def defaultSize: Int = if (precision <= 18) 8 else 16
// Decimal-specific operations
def bounded: Boolean = precision <= DecimalType.MAX_PRECISION
def isWiderThan(other: DecimalType): Boolean
def toCatalogString: String = s"decimal($precision,$scale)"
}
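// A small standalone sketch (mine, not part of the listing above): data types are plain
// Scala values, so their string forms and size estimates can be inspected directly.
import org.apache.spark.sql.types._
val salaryType = DecimalType(10, 2)
println(salaryType.typeName)               // decimal
println(salaryType.catalogString)          // decimal(10,2)
println(IntegerType.catalogString)         // int
println(IntegerType.defaultSize)           // 4
println(salaryType == DecimalType(10, 2))  // true: data types compare structurally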
// String and binary types
case object StringType extends AtomicType {
def typeName: String = "string"
def defaultSize: Int = 20
}
case object BinaryType extends AtomicType {
def typeName: String = "binary"
def defaultSize: Int = 100
}
// Boolean type
case object BooleanType extends AtomicType {
def typeName: String = "boolean"
def defaultSize: Int = 1
}
// Temporal types
case object DateType extends AtomicType {
def typeName: String = "date"
def defaultSize: Int = 4
}
case object TimestampType extends AtomicType {
def typeName: String = "timestamp"
def defaultSize: Int = 8
}
case object TimestampNTZType extends AtomicType {
def typeName: String = "timestamp_ntz"
def defaultSize: Int = 8
}
// Interval types
case class DayTimeIntervalType(startField: Byte = 0, endField: Byte = 3) extends DataType {
def typeName: String = "interval day to second"
def defaultSize: Int = 16
}
case class YearMonthIntervalType(startField: Byte = 0, endField: Byte = 1) extends DataType {
def typeName: String = "interval year to month"
def defaultSize: Int = 4
}
// Special types
case object NullType extends DataType {
def typeName: String = "void"
def defaultSize: Int = 1
}
case object VariantType extends DataType {
def typeName: String = "variant"
def defaultSize: Int = 20
}
// Array type
case class ArrayType(elementType: DataType, containsNull: Boolean = true) extends DataType {
def typeName: String = "array"
def catalogString: String = s"array<${elementType.catalogString}>"
def simpleString: String = s"array<${elementType.simpleString}>"
def defaultSize: Int = 100
// Array-specific operations
def buildFormattedString(prefix: String, stringConcat: StringConcat): Unit
def asNullable: ArrayType = ArrayType(elementType.asNullable, containsNull = true)
}
// Map type
case class MapType(
keyType: DataType,
valueType: DataType,
valueContainsNull: Boolean = true) extends DataType {
def typeName: String = "map"
def catalogString: String = s"map<${keyType.catalogString},${valueType.catalogString}>"
def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>"
def defaultSize: Int = 100
// Map-specific operations
def buildFormattedString(prefix: String, stringConcat: StringConcat): Unit
def asNullable: MapType = MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
}
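// A brief construction sketch (mine): complex types nest arbitrarily, and catalogString
// reflects the full element/key/value structure.
import org.apache.spark.sql.types._
val tags = ArrayType(StringType, containsNull = false)
val scoresByDay = MapType(StringType, ArrayType(DoubleType), valueContainsNull = true)
println(tags.catalogString)         // array<string>
println(scoresByDay.catalogString)  // map<string,array<double>>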
// Struct type and field
case class StructField(
name: String,
dataType: DataType,
nullable: Boolean = true,
metadata: Metadata = Metadata.empty) extends Serializable {
// Field operations
def withName(name: String): StructField = copy(name = name)
def withDataType(dataType: DataType): StructField = copy(dataType = dataType)
def withNullability(nullable: Boolean): StructField = copy(nullable = nullable)
def withMetadata(metadata: Metadata): StructField = copy(metadata = metadata)
def withComment(comment: String): StructField = withMetadata(metadata.putString("comment", comment))
def getComment(): Option[String] = if (metadata.contains("comment")) Some(metadata.getString("comment")) else None
// Conversion operations
def toDDL: String
def jsonValue: JValue
}
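// A short sketch (mine) of the fluent helpers listed above: each withX call returns a
// modified copy, so a field definition can be refined step by step.
import org.apache.spark.sql.types._
val annotatedSalary = StructField("salary", DecimalType(10, 2)).withNullability(false).withComment("Annual salary in USD")
println(annotatedSalary.getComment())  // Some(Annual salary in USD)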
case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] {
def typeName: String = "struct"
def defaultSize: Int = fields.map(_.dataType.defaultSize).sum
// Schema operations
def add(field: StructField): StructType = StructType(fields :+ field)
def add(name: String, dataType: DataType): StructType
def add(name: String, dataType: DataType, nullable: Boolean): StructType
def add(name: String, dataType: DataType, nullable: Boolean, metadata: Metadata): StructType
def add(name: String, dataType: DataType, nullable: Boolean, comment: String): StructType
// Field access
def apply(name: String): StructField
def apply(names: Set[String]): StructType
def fieldIndex(name: String): Int
def indexOf(name: String): Int
def getFieldIndex(name: String): Option[Int]
def fieldNames: Array[String] = fields.map(_.name)
def names: Set[String] = fieldNames.toSet
// Schema queries
def exists(f: StructField => Boolean): Boolean = fields.exists(f)
def filter(f: StructField => Boolean): Array[StructField] = fields.filter(f)
def find(f: StructField => Boolean): Option[StructField] = fields.find(f)
def count(f: StructField => Boolean): Int = fields.count(f)
// Schema transformations
def map(f: StructField => StructField): StructType = StructType(fields.map(f))
def flatMap(f: StructField => TraversableOnce[StructField]): Array[StructField] = fields.flatMap(f)
def foreach(f: StructField => Unit): Unit = fields.foreach(f)
// Schema compatibility
def merge(that: StructType): StructType
def intersect(that: StructType): StructType
def subtract(that: StructType): StructType
// Conversion operations
def catalogString: String
def simpleString: String
def toDDL: String
def prettyJson: String
def jsonValue: JValue
def sql: String
// Collection interface
def length: Int = fields.length
def size: Int = fields.length
def iterator: Iterator[StructField] = fields.iterator
def toSeq: Seq[StructField] = fields.toSeq
def toList: List[StructField] = fields.toList
def toArray: Array[StructField] = fields
// Nullability operations
def asNullable: StructType = StructType(fields.map(f => f.copy(dataType = f.dataType.asNullable, nullable = true)))
}
// Base class for user-defined types
abstract class UserDefinedType[UserType >: Null] extends DataType with Serializable {
// Type mapping
def sqlType: DataType
def pyUDT: String = null
// Serialization
def serialize(obj: UserType): Any
def deserialize(datum: Any): UserType
def userClass: Class[UserType]
// Type operations
def typeName: String = userClass.getSimpleName.toLowerCase(Locale.ROOT)
def catalogString: String = sqlType.catalogString
def sql: String = sqlType.sql
// Default implementations
def defaultSize: Int = sqlType.defaultSize
def asNullable: UserDefinedType[UserType] = this
// Equality and hashing
override def equals(other: Any): Boolean
override def hashCode(): Int
}
// Example UDT implementation
class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
override def pyUDT: String = "example.point.PointUDT"
override def serialize(obj: ExamplePoint): Any = {
if (obj == null) null else Array(obj.x, obj.y)
}
override def deserialize(datum: Any): ExamplePoint = {
datum match {
case null => null
case values: Seq[_] =>
val coords = values.asInstanceOf[Seq[Double]]
new ExamplePoint(coords(0), coords(1))
}
}
override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]
}
// Example user class
class ExamplePoint(val x: Double, val y: Double) extends Serializable {
override def toString: String = s"Point($x, $y)"
override def equals(other: Any): Boolean = other match {
case p: ExamplePoint => x == p.x && y == p.y
case _ => false
}
override def hashCode(): Int = (x, y).hashCode()
}
object DataTypes {
// Primitive type constants
val StringType: DataType = org.apache.spark.sql.types.StringType
val BinaryType: DataType = org.apache.spark.sql.types.BinaryType
val BooleanType: DataType = org.apache.spark.sql.types.BooleanType
val DateType: DataType = org.apache.spark.sql.types.DateType
val TimestampType: DataType = org.apache.spark.sql.types.TimestampType
val CalendarIntervalType: DataType = org.apache.spark.sql.types.CalendarIntervalType
val NullType: DataType = org.apache.spark.sql.types.NullType
val ByteType: DataType = org.apache.spark.sql.types.ByteType
val ShortType: DataType = org.apache.spark.sql.types.ShortType
val IntegerType: DataType = org.apache.spark.sql.types.IntegerType
val LongType: DataType = org.apache.spark.sql.types.LongType
val FloatType: DataType = org.apache.spark.sql.types.FloatType
val DoubleType: DataType = org.apache.spark.sql.types.DoubleType
// Factory methods
def createArrayType(elementType: DataType): ArrayType
def createArrayType(elementType: DataType, containsNull: Boolean): ArrayType
def createMapType(keyType: DataType, valueType: DataType): MapType
def createMapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean): MapType
def createStructType(fields: Array[StructField]): StructType
def createStructType(fields: java.util.List[StructField]): StructType
def createStructField(name: String, dataType: DataType, nullable: Boolean): StructField
def createStructField(name: String, dataType: DataType, nullable: Boolean, metadata: Metadata): StructField
def createDecimalType(): DecimalType
def createDecimalType(precision: Int): DecimalType
def createDecimalType(precision: Int, scale: Int): DecimalType
}
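// The DataTypes object is the Java-friendly factory API; a minimal sketch (mine) showing
// the same kind of schema assembled through it rather than the Scala case classes.
import org.apache.spark.sql.types.DataTypes
val javaStyleSchema = DataTypes.createStructType(Array(
DataTypes.createStructField("id", DataTypes.IntegerType, false),
DataTypes.createStructField("tags", DataTypes.createArrayType(DataTypes.StringType), true),
DataTypes.createStructField("price", DataTypes.createDecimalType(12, 2), true)
))
println(javaStyleSchema.toDDL)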
object DataType {
// Type parsing and conversion
def fromJson(json: String): DataType
def fromDDL(ddl: String): DataType
def fromCatalystType(catalystType: org.apache.spark.sql.catalyst.types.DataType): DataType
// Type utilities
def equalsIgnoreCompatibleNullability(left: DataType, right: DataType): Boolean
def equalsStructurally(left: DataType, right: DataType): Boolean
def canUpCast(from: DataType, to: DataType): Boolean
}
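// A round-trip sketch (mine): schemas can be parsed from DDL strings and serialized to
// JSON, which is convenient for persisting a schema alongside the data it describes.
import org.apache.spark.sql.types._
val parsedSchema = DataType.fromDDL("id INT, name STRING, scores ARRAY<DOUBLE>")
println(parsedSchema.catalogString)                            // struct<id:int,name:string,scores:array<double>>
println(DataType.fromJson(parsedSchema.json) == parsedSchema)  // true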
// Metadata for schema fields
case class Metadata(map: Map[String, Any] = Map.empty) extends Serializable {
def contains(key: String): Boolean = map.contains(key)
def getLong(key: String): Long = map(key).asInstanceOf[Long]
def getDouble(key: String): Double = map(key).asInstanceOf[Double]
def getBoolean(key: String): Boolean = map(key).asInstanceOf[Boolean]
def getString(key: String): String = map(key).asInstanceOf[String]
def getMetadata(key: String): Metadata = map(key).asInstanceOf[Metadata]
def getStringArray(key: String): Array[String] = map(key).asInstanceOf[Array[String]]
def getLongArray(key: String): Array[Long] = map(key).asInstanceOf[Array[Long]]
def getDoubleArray(key: String): Array[Double] = map(key).asInstanceOf[Array[Double]]
def getBooleanArray(key: String): Array[Boolean] = map(key).asInstanceOf[Array[Boolean]]
def getMetadataArray(key: String): Array[Metadata] = map(key).asInstanceOf[Array[Metadata]]
def putString(key: String, value: String): Metadata
def putLong(key: String, value: Long): Metadata
def putDouble(key: String, value: Double): Metadata
def putBoolean(key: String, value: Boolean): Metadata
def putMetadata(key: String, value: Metadata): Metadata
def putStringArray(key: String, value: Array[String]): Metadata
def putLongArray(key: String, value: Array[Long]): Metadata
def putDoubleArray(key: String, value: Array[Double]): Metadata
def putBooleanArray(key: String, value: Array[Boolean]): Metadata
def putMetadataArray(key: String, value: Array[Metadata]): Metadata
def remove(key: String): Metadata
def copy(map: Map[String, Any]): Metadata
def json: String
def prettyJson: String
}
object Metadata {
val empty: Metadata = Metadata()
def fromJson(json: String): Metadata
}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{SparkSession, Row}
val spark = SparkSession.builder().appName("DataTypes Demo").getOrCreate()
import spark.implicits._  // enables the $"column" syntax used in the examples below
// Create schema with primitive types
val primitiveSchema = StructType(Array(
StructField("id", IntegerType, nullable = false),
StructField("name", StringType, nullable = false),
StructField("age", ByteType, nullable = true),
StructField("height", FloatType, nullable = true),
StructField("salary", DecimalType(10, 2), nullable = true),
StructField("is_active", BooleanType, nullable = false),
StructField("hire_date", DateType, nullable = true),
StructField("last_login", TimestampType, nullable = true),
StructField("profile_image", BinaryType, nullable = true)
))
// Create DataFrame with typed data
val employeeData = Seq(
Row(1, "Alice Johnson", 25.toByte, 5.6f,
new java.math.BigDecimal("75000.50"), true,
java.sql.Date.valueOf("2023-01-15"),
java.sql.Timestamp.valueOf("2023-12-01 10:30:00"),
"image_data".getBytes()),
Row(2, "Bob Smith", 30.toByte, 6.0f,
new java.math.BigDecimal("85000.00"), false,
java.sql.Date.valueOf("2022-06-20"),
java.sql.Timestamp.valueOf("2023-11-28 14:22:15"),
null)
)
val employeeDF = spark.createDataFrame(
spark.sparkContext.parallelize(employeeData),
primitiveSchema
)
employeeDF.printSchema()
employeeDF.show()
// Array type example
val arraySchema = StructType(Array(
StructField("person_id", IntegerType, nullable = false),
StructField("skills", ArrayType(StringType, containsNull = false), nullable = true),
StructField("ratings", ArrayType(IntegerType, containsNull = false), nullable = true),
StructField("certifications", ArrayType(StringType, containsNull = true), nullable = true)
))
val arrayData = Seq(
Row(1, Array("Java", "Scala", "Python"), Array(9, 8, 7), Array("AWS", "Kubernetes")),
Row(2, Array("JavaScript", "React", "Node.js"), Array(8, 9, 8), null),
Row(3, Array("SQL", "Tableau"), Array(9, 7), Array("Tableau Desktop", null))
)
val skillsDF = spark.createDataFrame(
spark.sparkContext.parallelize(arrayData),
arraySchema
)
// Map type example
val mapSchema = StructType(Array(
StructField("employee_id", IntegerType, nullable = false),
StructField("performance_scores", MapType(StringType, DoubleType, valueContainsNull = false), nullable = true),
StructField("project_hours", MapType(StringType, IntegerType, valueContainsNull = true), nullable = true)
))
val mapData = Seq(
Row(1, Map("Q1" -> 8.5, "Q2" -> 9.0, "Q3" -> 8.7, "Q4" -> 9.2), Map("ProjectA" -> 120, "ProjectB" -> 80)),
Row(2, Map("Q1" -> 7.8, "Q2" -> 8.2, "Q3" -> 8.5), Map("ProjectA" -> 100, "ProjectC" -> null)),
Row(3, null, Map("ProjectB" -> 150))
)
val performanceDF = spark.createDataFrame(
spark.sparkContext.parallelize(mapData),
mapSchema
)
// Complex nested schema
val addressSchema = StructType(Array(
StructField("street", StringType, nullable = false),
StructField("city", StringType, nullable = false),
StructField("state", StringType, nullable = false),
StructField("zipcode", StringType, nullable = false),
StructField("coordinates", StructType(Array(
StructField("latitude", DoubleType, nullable = false),
StructField("longitude", DoubleType, nullable = false)
)), nullable = true)
))
val personSchema = StructType(Array(
StructField("id", IntegerType, nullable = false),
StructField("personal_info", StructType(Array(
StructField("first_name", StringType, nullable = false),
StructField("last_name", StringType, nullable = false),
StructField("birth_date", DateType, nullable = true),
StructField("ssn", StringType, nullable = true)
)), nullable = false),
StructField("addresses", ArrayType(addressSchema, containsNull = false), nullable = true),
StructField("emergency_contacts", MapType(StringType, StructType(Array(
StructField("name", StringType, nullable = false),
StructField("phone", StringType, nullable = false),
StructField("relationship", StringType, nullable = false)
)), valueContainsNull = false), nullable = true)
))
// Create complex nested data
val nestedData = Seq(
Row(
1,
Row("Alice", "Johnson", java.sql.Date.valueOf("1995-03-15"), "123-45-6789"),
Array(
Row("123 Main St", "Seattle", "WA", "98101", Row(47.6062, -122.3321)),
Row("456 Oak Ave", "Portland", "OR", "97201", Row(45.5152, -122.6784))
),
Map(
"primary" -> Row("John Johnson", "555-1234", "spouse"),
"secondary" -> Row("Mary Johnson", "555-5678", "sister")
)
)
)
val complexDF = spark.createDataFrame(
spark.sparkContext.parallelize(nestedData),
personSchema
)
// Access nested fields
import org.apache.spark.sql.functions._
val extractedData = complexDF.select(
$"id",
$"personal_info.first_name".as("first_name"),
$"personal_info.last_name".as("last_name"),
$"addresses"(0).getField("city").as("primary_city"),
$"addresses"(0).getField("coordinates").getField("latitude").as("primary_lat"),
$"emergency_contacts".getItem("primary").getField("name").as("primary_contact")
)
// Define a Point class
case class Point(x: Double, y: Double) {
def distance(other: Point): Double = {
math.sqrt(math.pow(x - other.x, 2) + math.pow(y - other.y, 2))
}
}
// Define UDT for Point
class PointUDT extends UserDefinedType[Point] {
override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
override def serialize(obj: Point): Any = {
if (obj == null) null else Array(obj.x, obj.y)
}
override def deserialize(datum: Any): Point = {
datum match {
case null => null
case coords: Seq[_] =>
val values = coords.asInstanceOf[Seq[Double]]
Point(values(0), values(1))
}
}
override def userClass: Class[Point] = classOf[Point]
}
// Register the UDT
UDTRegistration.register(classOf[Point].getName, classOf[PointUDT].getName)
// Use in schema
val locationSchema = StructType(Array(
StructField("location_id", IntegerType, nullable = false),
StructField("name", StringType, nullable = false),
StructField("coordinates", new PointUDT(), nullable = false)
))
// Rows carry Point instances here; Spark converts them through the registered PointUDT
val locationData = Seq(
Row(1, "Headquarters", Point(47.6062, -122.3321)),
Row(2, "Branch Office", Point(40.7128, -74.0060))
)
val locationsDF = spark.createDataFrame(
spark.sparkContext.parallelize(locationData),
locationSchema
)
// Schema introspection
val schema = employeeDF.schema
println(s"Schema has ${schema.fields.length} fields")
println(s"Field names: ${schema.fieldNames.mkString(", ")}")
// Check if field exists
if (schema.fieldNames.contains("salary")) {
val salaryField = schema("salary")
println(s"Salary field: ${salaryField.name}, Type: ${salaryField.dataType}, Nullable: ${salaryField.nullable}")
}
// Schema transformations
val modifiedSchema = schema
.add(StructField("full_name", StringType, nullable = false))
.add(StructField("years_experience", IntegerType, nullable = true))
// Create new schema programmatically
val newSchema = StructType(
schema.fields.map { field =>
field.name match {
case "salary" => field.copy(dataType = DecimalType(12, 2)) // Increase precision
case "age" => field.copy(dataType = IntegerType) // Change from Byte to Int
case _ => field
}
}
)
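// One way to apply the adjusted schema (a sketch, assuming each column only needs a plain
// cast): select every original column cast to the type declared in newSchema.
val retypedDF = employeeDF.select(newSchema.fields.map { f =>
col(f.name).cast(f.dataType).as(f.name)
}: _*)
retypedDF.printSchema()  // salary is now decimal(12,2) and age an integer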
// Schema compatibility checking
def schemasCompatible(schema1: StructType, schema2: StructType): Boolean = {
schema1.fieldNames.toSet == schema2.fieldNames.toSet &&
schema1.zip(schema2).forall { case (field1, field2) =>
field1.name == field2.name &&
DataType.equalsIgnoreCompatibleNullability(field1.dataType, field2.dataType)
}
}
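// Usage sketch: relaxing nullability alone should still count as compatible, while the
// type changes in newSchema (decimal precision, byte to int) should not.
val relaxedSchema = StructType(schema.fields.map(_.copy(nullable = true)))
println(schemasCompatible(schema, relaxedSchema))  // expected: true
println(schemasCompatible(schema, newSchema))      // expected: false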
// DDL string generation
val ddlString = schema.toDDL
println(s"DDL: $ddlString")
// JSON representation
val jsonSchema = schema.prettyJson
println(s"JSON Schema:\n$jsonSchema")
import org.apache.spark.sql.functions._
// Type casting examples
val typedConversions = employeeDF.select(
$"id",
$"name",
$"age".cast(IntegerType).as("age_as_int"),
$"height".cast(DoubleType).as("height_as_double"),
$"salary".cast(StringType).as("salary_as_string"),
$"is_active".cast(IntegerType).as("is_active_as_int"),
date_format($"hire_date", "yyyy-MM-dd").as("hire_date_formatted"),
$"hire_date".cast(StringType).as("hire_date_as_string")
)
// Safe type conversions with validation
val safeConversions = employeeDF.select(
$"id",
$"name",
when($"age".isNotNull, $"age".cast(IntegerType)).as("safe_age"),
when($"salary".isNotNull && $"salary" > 0, $"salary").as("valid_salary"),
coalesce($"height", lit(0.0)).cast(DoubleType).as("height_with_default")
)
// Complex type operations
val arrayOperations = skillsDF.select(
$"person_id",
size($"skills").as("num_skills"),
array_contains($"skills", "Java").as("knows_java"),
sort_array($"skills").as("sorted_skills"),
$"skills"(0).as("primary_skill"), // First element
slice($"skills", 2, 2).as("middle_skills") // Elements 2-3
)
val mapOperations = performanceDF.select(
$"employee_id",
map_keys($"performance_scores").as("quarters"),
map_values($"performance_scores").as("scores"),
$"performance_scores".getItem("Q1").as("q1_score"),
size($"performance_scores").as("num_quarters")
)
// Create fields with metadata
val metadataExample = Metadata.empty
.putString("description", "Employee annual salary in USD")
.putString("format", "currency")
.putBoolean("sensitive", true)
.putLong("version", 1L)
val schemaWithMetadata = StructType(Array(
StructField("id", IntegerType, nullable = false,
Metadata.empty.putString("description", "Unique employee identifier")),
StructField("name", StringType, nullable = false,
Metadata.empty.putString("description", "Full employee name").putBoolean("pii", true)),
StructField("salary", DecimalType(10, 2), nullable = true, metadataExample),
StructField("department", StringType, nullable = true,
Metadata.empty
.putString("description", "Employee department")
.putStringArray("valid_values", Array("Engineering", "Sales", "Marketing", "HR")))
))
// Access metadata
val salaryField = schemaWithMetadata("salary")
val isSensitive = salaryField.metadata.getBoolean("sensitive")
val description = salaryField.metadata.getString("description")
val version = salaryField.metadata.getLong("version")
println(s"Salary field - Sensitive: $isSensitive, Description: $description, Version: $version")
// Schema documentation generation
def generateSchemaDocumentation(schema: StructType): String = {
schema.fields.map { field =>
val description = if (field.metadata.contains("description")) {
field.metadata.getString("description")
} else {
"No description available"
}
val nullable = if (field.nullable) "nullable" else "required"
s"${field.name} (${field.dataType.simpleString}, $nullable): $description"
}.mkString("\n")
}
val documentation = generateSchemaDocumentation(schemaWithMetadata)
println(s"Schema Documentation:\n$documentation")