Custom Spark UDFs with GeoBrix
Learn how to build custom Spark User-Defined Functions (UDFs) using GeoBrix's execute methods for specialized geospatial operations.
Understanding Execute vs Eval
GeoBrix provides two interfaces for accessing functionality:
Eval Methods (Standard - Spark Expressions)
Used by GeoBrix's registered Spark functions:
from databricks.labs.gbx.rasterx import functions as rx
rx.register(spark)
rasters = spark.read.format("gdal").load(_RASTER_PATH)
df = rasters.select(rx.rst_boundingbox("tile").alias("bbox"))
df.limit(2).show(truncate=40)
Example output
+--------------------+
|bbox |
+--------------------+
|POLYGON ((...)) |
|... |
+--------------------+
Characteristics:
- Operates on Spark Columns
- Integrated with Catalyst optimizer
- Automatic serialization/deserialization
- Best for standard operations
Execute Methods (Advanced - Direct GDAL)
Used for building custom UDFs:
Execute Methods Example
import com.databricks.labs.gbx.rasterx.expressions.accessors.RST_BoundingBox
import org.gdal.gdal.Dataset
// Direct GDAL dataset access
val bbox = RST_BoundingBox.execute(dataset)
Characteristics:
- Operates on GDAL Dataset/Band objects directly
- Full control over GDAL operations
- Building blocks for custom logic
- Best for specialized operations
Why Use Execute Methods?
Use Cases
1. Custom Business Logic
- Apply domain-specific rules
- Combine multiple operations
- Implement proprietary algorithms
2. Specialized GDAL Operations
- Access GDAL features not exposed in the standard API
- Fine-tune GDAL parameters
- Implement complex workflows
3. Performance Optimization
- Batch multiple operations
- Minimize dataset loading
- Custom caching strategies
4. Integration Requirements
- Connect with external systems
- Custom format handling
- Specialized metadata extraction
Basic UDF Pattern
Python UDF Example
from pyspark.sql.functions import udf
from pyspark.sql.types import MapType, StringType

# Import GeoBrix execute methods (via Py4J bridge)
# from databricks.labs.gbx.rasterx.expressions import accessors

@udf(MapType(StringType(), StringType()))
def extract_custom_metadata(tile_binary):
    """
    Extract custom metadata from a raster tile.
    """
    try:
        # Load GDAL dataset from binary.
        # This is simplified - the actual implementation needs proper deserialization.
        # from databricks.labs.gbx.rasterx.gdal import GDALManager
        # dataset = load_dataset_from_tile(tile_binary)

        # Use execute methods
        metadata = {}
        # metadata["format"] = accessors.RST_Format.execute(dataset)
        # metadata["width"] = str(accessors.RST_Width.execute(dataset))
        # metadata["height"] = str(accessors.RST_Height.execute(dataset))

        # Add custom logic
        metadata["aspect_ratio"] = "1.0"  # Placeholder

        # Clean up
        # dataset.delete()
        return metadata
    except Exception as e:
        return {"error": str(e)}
rasters = spark.read.format("gdal").load(_RASTER_PATH)
enriched = rasters.withColumn("custom_metadata", extract_custom_metadata("tile"))
enriched.select("path", "custom_metadata").limit(2).show(truncate=30)
Example output
+--------------------+------------------+
|path |custom_metadata |
+--------------------+------------------+
|.../nyc_sentinel2...|{aspect_ratio=1.0}|
|... |... |
+--------------------+------------------+
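The commented-out placeholder above is easier to see as a plain helper. The sketch below shows what the custom-logic step might compute once the execute methods supply the raw properties; `compute_custom_metadata` is a hypothetical name for illustration, not part of GeoBrix:

```python
def compute_custom_metadata(fmt, width, height):
    """Combine standard raster properties into custom metadata fields.

    In the UDF above, fmt/width/height would come from the RST_Format,
    RST_Width, and RST_Height execute calls on the loaded dataset.
    """
    metadata = {
        "format": fmt,
        "width": str(width),
        "height": str(height),
    }
    # Custom logic: replace the "1.0" placeholder with a real aspect ratio
    metadata["aspect_ratio"] = f"{width / height:.2f}" if height else "undefined"
    return metadata
```

For example, a 1024x512 tile yields an aspect ratio of "2.00".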
Scala UDF Example
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import com.databricks.labs.gbx.rasterx.expressions.accessors._
import com.databricks.labs.gbx.rasterx.gdal.RasterDriver
import org.gdal.gdal.Dataset
object CustomRasterUDFs {

  /**
   * Extract custom statistics from raster
   */
  def customRasterStats: UserDefinedFunction = udf((tileBytes: Array[Byte]) => {
    try {
      // Read Dataset from binary raster data
      val ds: Dataset = RasterDriver.readFromBytes(tileBytes, Map.empty[String, String])

      // Use execute methods to get statistics
      val width = RST_Width.execute(ds)
      val height = RST_Height.execute(ds)
      val numBands = RST_NumBands.execute(ds)

      // Custom calculations
      val totalPixels = width * height * numBands
      val pixelWidth = RST_PixelWidth.execute(ds)
      val pixelHeight = RST_PixelHeight.execute(ds)
      val coverage = width * pixelWidth * height * pixelHeight

      // Clean up
      RasterDriver.releaseDataset(ds)

      // Return custom result
      Map(
        "total_pixels" -> totalPixels.toString,
        "coverage_sqm" -> coverage.toString,
        "pixel_density" -> (totalPixels / coverage).toString
      )
    } catch {
      case e: Exception => Map("error" -> e.getMessage)
    }
  })

  /**
   * Custom bounding box with buffer
   */
  def boundingBoxWithBuffer(bufferMeters: Double): UserDefinedFunction =
    udf((tileBytes: Array[Byte]) => {
      try {
        // Read Dataset from binary raster data
        val ds: Dataset = RasterDriver.readFromBytes(tileBytes, Map.empty[String, String])

        // Get bounding box using execute
        val bbox = RST_BoundingBox.execute(ds)

        // Apply buffer (custom logic)
        val buffered = bbox.buffer(bufferMeters)

        // Clean up
        RasterDriver.releaseDataset(ds)

        // Convert to WKB
        buffered.getBinary
      } catch {
        case e: Exception => Array.empty[Byte]
      }
    })

  /**
   * Filter rasters by custom criteria
   */
  def meetsQualityCriteria: UserDefinedFunction = udf((tileBytes: Array[Byte]) => {
    try {
      // Read Dataset from binary raster data
      val ds: Dataset = RasterDriver.readFromBytes(tileBytes, Map.empty[String, String])

      // Multiple execute calls for criteria
      val width = RST_Width.execute(ds)
      val height = RST_Height.execute(ds)
      val numBands = RST_NumBands.execute(ds)
      val band = ds.GetRasterBand(1)
      val noData = RST_GetNoData.execute(band)

      // Custom quality logic
      val validSize = width >= 512 && height >= 512
      val hasMultipleBands = numBands >= 3
      val hasNoDataValue = noData.isDefined
      validSize && hasMultipleBands && hasNoDataValue
    } catch {
      case _: Exception => false
    }
  })
}
Example output
+--------------------+------------------------------------------+
|path |custom_stats |
+--------------------+------------------------------------------+
|.../raster.tif |{total_pixels=..., coverage_sqm=..., ...} |
+--------------------+------------------------------------------+
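The quality filter in `meetsQualityCriteria` is ordinary predicate logic, so it can be sketched in plain Python as well. This is a hypothetical helper: the property values are assumed to come from the execute methods, as in the Scala version above:

```python
def meets_quality_criteria(width, height, num_bands, nodata):
    """A tile passes when it is at least 512x512, carries 3+ bands,
    and declares a NoData value (None means no NoData is set)."""
    valid_size = width >= 512 and height >= 512
    has_multiple_bands = num_bands >= 3
    has_nodata_value = nodata is not None
    return valid_size and has_multiple_bands and has_nodata_value
```

Failing any one criterion rejects the tile, which keeps the filter cheap to evaluate and easy to extend.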
Real-World Examples
Example 1: Custom Cloud Mask
Custom Cloud Mask Example
import com.databricks.labs.gbx.rasterx.expressions.accessors._
import com.databricks.labs.gbx.rasterx.gdal.RasterDriver
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import org.gdal.gdal.Dataset

/**
 * Apply custom cloud masking logic based on multiple bands
 */
def applyCloudMask: UserDefinedFunction = udf((tileBytes: Array[Byte]) => {
  // Read Dataset from binary raster data
  val ds: Dataset = RasterDriver.readFromBytes(tileBytes, Map.empty[String, String])
  try {
    // Get band handles
    val band1 = ds.GetRasterBand(1)
    val band2 = ds.GetRasterBand(2)
    val band3 = ds.GetRasterBand(3)
    val width = RST_Width.execute(ds)
    val height = RST_Height.execute(ds)

    // Read pixel data
    val pixels1 = band1.ReadRaster(0, 0, width, height)
    val pixels2 = band2.ReadRaster(0, 0, width, height)
    val pixels3 = band3.ReadRaster(0, 0, width, height)

    // Apply custom cloud detection algorithm
    // (simplified example - detectClouds and applyMask are user-supplied
    // helpers, not shown; a real algorithm would be more complex)
    val cloudMask = detectClouds(pixels1, pixels2, pixels3)

    // Create new raster with cloud mask applied
    val maskedRaster = applyMask(ds, cloudMask)

    // Write to bytes and return
    val result = RasterDriver.writeToBytes(maskedRaster, Map.empty[String, String])
    RasterDriver.releaseDataset(maskedRaster)
    result
  } finally {
    RasterDriver.releaseDataset(ds)
  }
})
Example output
+--------------------+-----------+
|path |masked_tile|
+--------------------+-----------+
|.../multiband.tif |[BINARY] |
+--------------------+-----------+
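`detectClouds` and `applyMask` are left undefined above. Purely as an illustration, here is a toy brightness-threshold version of those helpers in Python; real cloud masking is sensor-specific and far more involved than this:

```python
def detect_clouds(band1, band2, band3, threshold=200):
    """Toy detector: flag a pixel as cloud when it is bright in all
    three bands (stand-in for the detectClouds helper above)."""
    return [
        b1 > threshold and b2 > threshold and b3 > threshold
        for b1, b2, b3 in zip(band1, band2, band3)
    ]

def apply_mask(pixels, cloud_mask, nodata=0):
    """Replace cloudy pixels with a NoData value (stand-in for applyMask)."""
    return [nodata if is_cloud else p for p, is_cloud in zip(pixels, cloud_mask)]
```

The threshold of 200 is an arbitrary example value; production masks typically combine spectral indices, thermal bands, and per-sensor calibration.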
Example 2: Multi-Temporal Analysis
Multi-Temporal Analysis Example
/**
 * Compare rasters from different time periods
 */
def calculateNDVIChange: UserDefinedFunction =
  udf((before: Array[Byte], after: Array[Byte]) => {
    // Read Datasets from binary raster data
    val dsBefore: Dataset = RasterDriver.readFromBytes(before, Map.empty[String, String])
    val dsAfter: Dataset = RasterDriver.readFromBytes(after, Map.empty[String, String])
    try {
      // Extract NIR and Red bands
      val nirBefore = dsBefore.GetRasterBand(4)
      val redBefore = dsBefore.GetRasterBand(3)
      val nirAfter = dsAfter.GetRasterBand(4)
      val redAfter = dsAfter.GetRasterBand(3)

      // Calculate NDVI for both periods
      // (calculateNDVI is a user-supplied helper, not shown here)
      val ndviBefore = calculateNDVI(nirBefore, redBefore)
      val ndviAfter = calculateNDVI(nirAfter, redAfter)

      // Calculate per-pixel change
      val change = ndviAfter - ndviBefore

      // Return statistics
      Map(
        "mean_change" -> change.mean.toString,
        "max_gain" -> change.max.toString,
        "max_loss" -> change.min.toString,
        "percent_improved" -> (change.count(_ > 0.1).toDouble / change.size * 100).toString
      )
    } finally {
      RasterDriver.releaseDataset(dsBefore)
      RasterDriver.releaseDataset(dsAfter)
    }
  })
Example output
+--------------------+--------------------+----------------------------------+
|before |after |change_stats |
+--------------------+--------------------+----------------------------------+
|.../t1.tif |.../t2.tif |{mean_change=..., percent_impr...}|
+--------------------+--------------------+----------------------------------+
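The `calculateNDVI` helper and the change statistics are straightforward arithmetic. Here is a Python sketch of both, operating on plain pixel lists, with the same output fields the UDF above returns; the function names are illustrative, not GeoBrix APIs:

```python
def ndvi(nir, red):
    """Per-pixel NDVI: (NIR - Red) / (NIR + Red), guarding zero sums."""
    return [(n - r) / (n + r) if (n + r) else 0.0 for n, r in zip(nir, red)]

def ndvi_change_stats(ndvi_before, ndvi_after, improvement_threshold=0.1):
    """Summarize a per-pixel NDVI difference with the same fields as
    the calculateNDVIChange UDF above."""
    change = [a - b for a, b in zip(ndvi_after, ndvi_before)]
    improved = sum(1 for c in change if c > improvement_threshold)
    return {
        "mean_change": sum(change) / len(change),
        "max_gain": max(change),
        "max_loss": min(change),
        "percent_improved": improved / len(change) * 100,
    }
```

A pixel counts as "improved" only when its NDVI gain exceeds the 0.1 threshold, which filters out noise-level fluctuations.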
Example 3: Custom Format Handler
Custom Format Handler Example
import org.gdal.gdal.gdal

/**
 * Handle a proprietary raster format
 */
def processProprietaryFormat: UserDefinedFunction =
  udf((filePath: String) => {
    try {
      // Use GDAL's flexible driver system
      val ds = gdal.Open(filePath)

      // Extract metadata using execute methods
      val metadata = RST_MetaData.execute(ds)

      // Apply domain-specific interpretation
      val calibrationFactor = metadata.getOrElse("CAL_FACTOR", "1.0").toDouble
      val sensorType = metadata.getOrElse("SENSOR", "unknown")

      // Get band data
      val band = ds.GetRasterBand(1)
      val width = RST_Width.execute(ds)
      val height = RST_Height.execute(ds)

      // Read and calibrate
      // (applyCalibration and createCalibratedDataset are user-supplied helpers)
      val pixels = band.ReadRaster(0, 0, width, height)
      val calibrated = applyCalibration(pixels, calibrationFactor, sensorType)

      // Create calibrated raster
      val output = createCalibratedDataset(calibrated, width, height, ds)

      // Write to bytes
      val result = RasterDriver.writeToBytes(output, Map.empty[String, String])
      RasterDriver.releaseDataset(ds)
      RasterDriver.releaseDataset(output)
      result
    } catch {
      case e: Exception =>
        log.error(s"Failed to process $filePath: ${e.getMessage}")
        Array.empty[Byte]
    }
  })
Example output
+------------------+---------------+
|file_path |calibrated_tile|
+------------------+---------------+
|/data/custom.xyz |[BINARY] |
+------------------+---------------+
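The undefined `applyCalibration` helper above boils down to scaling raw digital numbers by a metadata-supplied factor. A minimal Python sketch, with a hypothetical per-sensor offset table (the sensor names and values are made up for this example):

```python
def apply_calibration(pixels, cal_factor, sensor_type):
    """Stand-in for the applyCalibration helper: scale raw digital
    numbers by a factor read from the file's metadata, plus a
    hypothetical sensor-specific offset."""
    sensor_offsets = {"sensor_a": 10.0}  # illustrative values only
    offset = sensor_offsets.get(sensor_type, 0.0)
    return [p * cal_factor + offset for p in pixels]
```

Unknown sensors fall back to a zero offset, so the handler degrades gracefully when the `SENSOR` metadata key is missing.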
Best Practices
1. Resource Management
Always clean up GDAL resources:
Resource Management Pattern
import scala.reflect.runtime.universe.TypeTag

def safeExecute[T: TypeTag](f: Dataset => T): UserDefinedFunction = udf((bytes: Array[Byte]) => {
  // Read Dataset from binary raster data
  val ds: Dataset = RasterDriver.readFromBytes(bytes, Map.empty[String, String])
  try {
    f(ds)
  } finally {
    RasterDriver.releaseDataset(ds) // Always clean up!
  }
})
Usage
// safeExecute ensures ds is always released, even if f throws
// Example: df.withColumn("width", safeExecute(ds => RST_Width.execute(ds))(col("tile")))
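The same guarantee can be expressed in Python as a context manager. This is a sketch, not a GeoBrix API: `read_fn` and `release_fn` stand in for `RasterDriver.readFromBytes` and `RasterDriver.releaseDataset`:

```python
from contextlib import contextmanager

@contextmanager
def managed_dataset(read_fn, tile_bytes, release_fn):
    """Python analogue of safeExecute: open a dataset and guarantee
    release, even when the body raises."""
    ds = read_fn(tile_bytes)
    try:
        yield ds
    finally:
        release_fn(ds)  # Always clean up!
```

Inside a UDF body you would then write `with managed_dataset(read, tile, release) as ds:` and do all execute calls within the block.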
2. Error Handling
Wrap execute calls in try-catch:
Error Handling Pattern
val robustUDF: UserDefinedFunction = udf { bytes: Array[Byte] =>
  try {
    val ds = loadDataset(bytes)
    val result = RST_Format.execute(ds)
    ds.delete()
    Option(result)
  } catch {
    case e: Exception =>
      log.warn(s"UDF failed: ${e.getMessage}")
      None
  }
}
Example output (Spark unwraps Option, so None surfaces as null)
+--------------------+------------------+
|tile |result |
+--------------------+------------------+
|[BINARY] |GTiff |
|[BINARY] |null |
+--------------------+------------------+
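The same "never fail the task on a bad tile" idea can be packaged as a Python decorator. Everything here is illustrative: the `robust` name is made up, and the dict lookup stands in for an execute call:

```python
import logging

log = logging.getLogger("custom_udfs")

def robust(fn):
    """Decorator form of the pattern above: a failing tile yields None
    (null in Spark) instead of aborting the whole task."""
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            log.warning("UDF failed: %s", e)
            return None
    return wrapper

@robust
def extract_format(ds):
    # ds["format"] stands in for RST_Format.execute(ds)
    return ds["format"]
```

Wrapping each extraction function once keeps the try/except boilerplate out of every UDF body.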
3. Performance Considerations
Batch operations when possible:
Performance Optimization Pattern
def efficientBatchUDF: UserDefinedFunction = udf((bytes: Array[Byte]) => {
  val ds = loadDataset(bytes)
  try {
    // Single dataset load, multiple operations
    Map(
      "format" -> RST_Format.execute(ds),
      "width" -> RST_Width.execute(ds).toString,
      "height" -> RST_Height.execute(ds).toString,
      "bands" -> RST_NumBands.execute(ds).toString,
      "srid" -> RST_SRID.execute(ds).toString
    )
  } finally {
    ds.delete()
  }
})
Example output
+--------------------+----------------------------------------+
|tile |results |
+--------------------+----------------------------------------+
|[BINARY] |{format=GTiff, width=1024, height=1...} |
+--------------------+----------------------------------------+
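The batching idea generalizes: open once, extract everything. A small Python sketch of that shape, where `getters` would map output keys to execute methods (the helper name and the dict-based fake dataset are illustrative):

```python
def batch_properties(ds, getters):
    """Extract many properties from one opened dataset instead of
    paying the load cost once per property. getters maps output keys
    to functions of the dataset (e.g. the execute methods)."""
    return {key: str(get(ds)) for key, get in getters.items()}
```

With a stand-in dataset this looks like `batch_properties(ds, {"width": lambda d: d["width"], ...})`; in a real UDF the lambdas would call the execute methods on a GDAL dataset.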
4. Type Safety
Use proper Spark SQL types:
Type Safety Pattern
import org.apache.spark.sql.types._
// Documents the expected output shape; Spark infers an equivalent
// struct from the tuple returned by the function below.
val schema = StructType(Seq(
  StructField("width", IntegerType, nullable = false),
  StructField("height", IntegerType, nullable = false),
  StructField("aspect_ratio", DoubleType, nullable = false)
))

spark.udf.register("raster_dims", (bytes: Array[Byte]) => {
  val ds = loadDataset(bytes)
  val w = RST_Width.execute(ds)
  val h = RST_Height.execute(ds)
  ds.delete()
  (w, h, w.toDouble / h.toDouble)
})
Example output
+--------------------+-----+------+------------+
|tile |width|height|aspect_ratio|
+--------------------+-----+------+------------+
|[BINARY] |1024 |1024 |1.0 |
+--------------------+-----+------+------------+
Common Patterns
These examples are maintained as a single copy in docs/tests/python/advanced/custom_udfs.py and verified by test_custom_udfs.py.
Pattern: Conditional Processing
Branch logic based on raster properties (e.g. band count):
from databricks.labs.gbx.rasterx import functions as rx
rx.register(spark)
rasters = spark.read.format("gdal").load(_RASTER_PATH)
# Branch by band count: different path for multiband vs single band
with_band_count = rasters.select(
rx.rst_numbands("tile").alias("num_bands"),
rx.rst_width("tile").alias("width"),
rx.rst_height("tile").alias("height"),
)
with_band_count.limit(2).show()
Example output
+---------+-----+------+
|num_bands|width|height|
+---------+-----+------+
|1 |10980|10980 |
|... |... |... |
+---------+-----+------+
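Once the band count is in a column, the branch itself is simple dispatch logic. A hypothetical sketch (the pipeline names are made up for this example):

```python
def choose_pipeline(num_bands):
    """Route a tile to a processing path by band count; the pipeline
    names here are illustrative only."""
    if num_bands >= 4:
        return "multispectral_pipeline"
    if num_bands == 3:
        return "rgb_pipeline"
    return "grayscale_pipeline"
```

In Spark this dispatch would typically live in a `when`/`otherwise` expression or inside the UDF body itself.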
Pattern: Chained Processing
Chain multiple operations (metadata, validate, process) with a single load:
from databricks.labs.gbx.rasterx import functions as rx
rx.register(spark)
rasters = spark.read.format("gdal").load(_RASTER_PATH)
# Step 1: metadata; Step 2: dimensions; Step 3: optional clip/transform
result = rasters.select(
rx.rst_boundingbox("tile").alias("bbox"),
rx.rst_width("tile").alias("width"),
rx.rst_height("tile").alias("height"),
)
result.limit(2).show(truncate=30)
Example output
+--------------------+-----+------+
|bbox |width|height|
+--------------------+-----+------+
|POLYGON ((...)) |10980|10980 |
|... |... |... |
+--------------------+-----+------+
Next Steps
- GDAL CLI Integration - Preprocessing with GDAL utilities
- Library Integration - Connect with rasterio, xarray
- Examples - More real-world patterns
- Test Cases - Browse execute method examples