diff --git a/Cargo.lock b/Cargo.lock index acc8b7aa7fe4..dd00c5d4dc51 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2287,12 +2287,18 @@ dependencies = [ "async-ffi", "async-trait", "datafusion", + "datafusion-catalog", "datafusion-common", + "datafusion-datasource", "datafusion-execution", "datafusion-expr", + "datafusion-functions", + "datafusion-functions-aggregate", "datafusion-functions-aggregate-common", + "datafusion-functions-window", "datafusion-physical-expr", "datafusion-physical-expr-common", + "datafusion-physical-plan", "datafusion-proto", "datafusion-proto-common", "doc-comment", diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml index a06a9cf1839d..8ef7e23b7e8b 100644 --- a/datafusion/ffi/Cargo.toml +++ b/datafusion/ffi/Cargo.toml @@ -47,12 +47,15 @@ arrow-schema = { workspace = true } async-ffi = { version = "0.5.0", features = ["abi_stable"] } async-trait = { workspace = true } datafusion = { workspace = true, default-features = false } +datafusion-catalog = { workspace = true } datafusion-common = { workspace = true } +datafusion-datasource = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } +datafusion-physical-plan = { workspace = true } datafusion-proto = { workspace = true } datafusion-proto-common = { workspace = true } futures = { workspace = true } @@ -63,6 +66,10 @@ tokio = { workspace = true } [dev-dependencies] datafusion = { workspace = true, default-features = false, features = ["sql"] } +datafusion-functions = { workspace = true } +datafusion-functions-aggregate = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } +datafusion-functions-window = { workspace = true } doc-comment = { workspace = true } [features] diff --git a/datafusion/ffi/src/execution_plan.rs b/datafusion/ffi/src/execution_plan.rs index 023869a6c494..8ee67d7e1823 100644 --- a/datafusion/ffi/src/execution_plan.rs +++ b/datafusion/ffi/src/execution_plan.rs @@ -301,7 +301,7 @@ impl ExecutionPlan for ForeignExecutionPlan { } #[cfg(test)] -mod tests { +pub(crate) mod tests { use super::*; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::{ diff --git a/datafusion/ffi/src/lib.rs b/datafusion/ffi/src/lib.rs index fbb45a8028fc..935c2fc504d2 100644 --- a/datafusion/ffi/src/lib.rs +++ b/datafusion/ffi/src/lib.rs @@ -34,6 +34,7 @@ pub mod expr; pub mod insert_op; pub mod physical_expr; pub mod plan_properties; +pub mod proto; pub mod record_batch_stream; pub mod schema_provider; pub mod session_config; diff --git a/datafusion/ffi/src/proto/logical_extension_codec.rs b/datafusion/ffi/src/proto/logical_extension_codec.rs new file mode 100644 index 000000000000..928a7c5bdb72 --- /dev/null +++ b/datafusion/ffi/src/proto/logical_extension_codec.rs @@ -0,0 +1,708 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::ffi::c_void;
+use std::sync::Arc;
+
+use abi_stable::std_types::{RResult, RSlice, RStr, RVec};
+use abi_stable::StableAbi;
+use arrow::datatypes::SchemaRef;
+use datafusion_catalog::TableProvider;
+use datafusion_common::error::Result;
+use datafusion_common::{not_impl_err, TableReference};
+use datafusion_datasource::file_format::FileFormatFactory;
+use datafusion_execution::TaskContext;
+use datafusion_expr::{
+    AggregateUDF, AggregateUDFImpl, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl,
+    WindowUDF, WindowUDFImpl,
+};
+use datafusion_proto::logical_plan::LogicalExtensionCodec;
+use tokio::runtime::Handle;
+
+use crate::arrow_wrappers::WrappedSchema;
+use crate::execution::FFI_TaskContextProvider;
+use crate::table_provider::FFI_TableProvider;
+use crate::udaf::FFI_AggregateUDF;
+use crate::udf::FFI_ScalarUDF;
+use crate::udwf::FFI_WindowUDF;
+use crate::util::FFIResult;
+use crate::{df_result, rresult_return};
+
+/// A stable struct for sharing [`LogicalExtensionCodec`] across FFI boundaries.
+#[repr(C)]
+#[derive(Debug, StableAbi)]
+#[allow(non_camel_case_types)]
+pub struct FFI_LogicalExtensionCodec {
+    /// Decode bytes into a table provider.
+    try_decode_table_provider: unsafe extern "C" fn(
+        &Self,
+        buf: RSlice<u8>,
+        table_ref: RStr,
+        schema: WrappedSchema,
+    ) -> FFIResult<FFI_TableProvider>,
+
+    /// Encode a table provider into bytes.
+    try_encode_table_provider: unsafe extern "C" fn(
+        &Self,
+        table_ref: RStr,
+        node: FFI_TableProvider,
+    ) -> FFIResult<RVec<u8>>,
+
+    /// Decode bytes into a user defined scalar function.
+    try_decode_udf: unsafe extern "C" fn(
+        &Self,
+        name: RStr,
+        buf: RSlice<u8>,
+    ) -> FFIResult<FFI_ScalarUDF>,
+
+    /// Encode a user defined scalar function into bytes.
+    try_encode_udf:
+        unsafe extern "C" fn(&Self, node: FFI_ScalarUDF) -> FFIResult<RVec<u8>>,
+
+    /// Decode bytes into a user defined aggregate function.
+    try_decode_udaf: unsafe extern "C" fn(
+        &Self,
+        name: RStr,
+        buf: RSlice<u8>,
+    ) -> FFIResult<FFI_AggregateUDF>,
+
+    /// Encode a user defined aggregate function into bytes.
+    try_encode_udaf:
+        unsafe extern "C" fn(&Self, node: FFI_AggregateUDF) -> FFIResult<RVec<u8>>,
+
+    /// Decode bytes into a user defined window function.
+    try_decode_udwf: unsafe extern "C" fn(
+        &Self,
+        name: RStr,
+        buf: RSlice<u8>,
+    ) -> FFIResult<FFI_WindowUDF>,
+
+    /// Encode a user defined window function into bytes.
+    try_encode_udwf:
+        unsafe extern "C" fn(&Self, node: FFI_WindowUDF) -> FFIResult<RVec<u8>>,
+
+    /// Access the current [`TaskContext`].
+    task_ctx_provider: FFI_TaskContextProvider,
+
+    /// Used to create a clone on the provider side of the codec. This should
+    /// only need to be called by the receiver of the codec.
+    pub clone: unsafe extern "C" fn(codec: &Self) -> Self,
+
+    /// Release the memory of the private data when it is no longer being used.
+    pub release: unsafe extern "C" fn(arg: &mut Self),
+
+    /// Return the major DataFusion version number of this provider.
+    pub version: unsafe extern "C" fn() -> u64,
+
+    /// Internal data. This is only to be accessed by the provider of the codec.
+    /// A [`ForeignLogicalExtensionCodec`] should never attempt to access this data.
+ pub private_data: *mut c_void, + + /// Utility to identify when FFI objects are accessed locally through + /// the foreign interface. + pub library_marker_id: extern "C" fn() -> usize, +} + +unsafe impl Send for FFI_LogicalExtensionCodec {} +unsafe impl Sync for FFI_LogicalExtensionCodec {} + +struct LogicalExtensionCodecPrivateData { + provider: Arc, + runtime: Option, +} + +impl FFI_LogicalExtensionCodec { + fn inner(&self) -> &Arc { + let private_data = self.private_data as *const LogicalExtensionCodecPrivateData; + unsafe { &(*private_data).provider } + } + + fn runtime(&self) -> &Option { + let private_data = self.private_data as *const LogicalExtensionCodecPrivateData; + unsafe { &(*private_data).runtime } + } + + fn task_ctx(&self) -> Result> { + (&self.task_ctx_provider).try_into() + } +} + +unsafe extern "C" fn try_decode_table_provider_fn_wrapper( + codec: &FFI_LogicalExtensionCodec, + buf: RSlice, + table_ref: RStr, + schema: WrappedSchema, +) -> FFIResult { + let ctx = rresult_return!(codec.task_ctx()); + let runtime = codec.runtime().clone(); + let codec = codec.inner(); + let table_ref = TableReference::from(table_ref.as_str()); + let schema: SchemaRef = schema.into(); + + let table_provider = rresult_return!(codec.try_decode_table_provider( + buf.as_ref(), + &table_ref, + schema, + ctx.as_ref() + )); + + RResult::ROk(FFI_TableProvider::new(table_provider, true, runtime)) +} + +unsafe extern "C" fn try_encode_table_provider_fn_wrapper( + codec: &FFI_LogicalExtensionCodec, + table_ref: RStr, + node: FFI_TableProvider, +) -> FFIResult> { + let table_ref = TableReference::from(table_ref.as_str()); + let table_provider: Arc = (&node).into(); + let codec = codec.inner(); + + let mut bytes = Vec::new(); + rresult_return!(codec.try_encode_table_provider( + &table_ref, + table_provider, + &mut bytes + )); + + RResult::ROk(bytes.into()) +} + +unsafe extern "C" fn try_decode_udf_fn_wrapper( + codec: &FFI_LogicalExtensionCodec, + name: RStr, + buf: RSlice, +) -> FFIResult { + let codec = codec.inner(); + + let udf = rresult_return!(codec.try_decode_udf(name.as_str(), buf.as_ref())); + let udf = FFI_ScalarUDF::from(udf); + + RResult::ROk(udf) +} + +unsafe extern "C" fn try_encode_udf_fn_wrapper( + codec: &FFI_LogicalExtensionCodec, + node: FFI_ScalarUDF, +) -> FFIResult> { + let codec = codec.inner(); + let node: Arc = (&node).into(); + let node = ScalarUDF::new_from_shared_impl(node); + + let mut bytes = Vec::new(); + rresult_return!(codec.try_encode_udf(&node, &mut bytes)); + + RResult::ROk(bytes.into()) +} + +unsafe extern "C" fn try_decode_udaf_fn_wrapper( + codec: &FFI_LogicalExtensionCodec, + name: RStr, + buf: RSlice, +) -> FFIResult { + let codec_inner = codec.inner(); + let udaf = rresult_return!(codec_inner.try_decode_udaf(name.into(), buf.as_ref())); + let udaf = FFI_AggregateUDF::from(udaf); + + RResult::ROk(udaf) +} + +unsafe extern "C" fn try_encode_udaf_fn_wrapper( + codec: &FFI_LogicalExtensionCodec, + node: FFI_AggregateUDF, +) -> FFIResult> { + let codec = codec.inner(); + let udaf: Arc = (&node).into(); + let udaf = AggregateUDF::new_from_shared_impl(udaf); + + let mut bytes = Vec::new(); + rresult_return!(codec.try_encode_udaf(&udaf, &mut bytes)); + + RResult::ROk(bytes.into()) +} + +unsafe extern "C" fn try_decode_udwf_fn_wrapper( + codec: &FFI_LogicalExtensionCodec, + name: RStr, + buf: RSlice, +) -> FFIResult { + let codec = codec.inner(); + let udwf = rresult_return!(codec.try_decode_udwf(name.into(), buf.as_ref())); + let udwf = FFI_WindowUDF::from(udwf); + + 
RResult::ROk(udwf)
+}
+
+unsafe extern "C" fn try_encode_udwf_fn_wrapper(
+    codec: &FFI_LogicalExtensionCodec,
+    node: FFI_WindowUDF,
+) -> FFIResult<RVec<u8>> {
+    let codec = codec.inner();
+    let udwf: Arc<dyn WindowUDFImpl> = (&node).into();
+    let udwf = WindowUDF::new_from_shared_impl(udwf);
+
+    let mut bytes = Vec::new();
+    rresult_return!(codec.try_encode_udwf(&udwf, &mut bytes));
+
+    RResult::ROk(bytes.into())
+}
+
+unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_LogicalExtensionCodec) {
+    let private_data =
+        Box::from_raw(provider.private_data as *mut LogicalExtensionCodecPrivateData);
+    drop(private_data);
+}
+
+unsafe extern "C" fn clone_fn_wrapper(
+    codec: &FFI_LogicalExtensionCodec,
+) -> FFI_LogicalExtensionCodec {
+    let old_codec = Arc::clone(codec.inner());
+    let runtime = codec.runtime().clone();
+
+    FFI_LogicalExtensionCodec::new(old_codec, runtime, codec.task_ctx_provider.clone())
+}
+
+impl Drop for FFI_LogicalExtensionCodec {
+    fn drop(&mut self) {
+        unsafe { (self.release)(self) }
+    }
+}
+
+impl FFI_LogicalExtensionCodec {
+    /// Creates a new [`FFI_LogicalExtensionCodec`].
+    pub fn new(
+        provider: Arc<dyn LogicalExtensionCodec>,
+        runtime: Option<Handle>,
+        task_ctx_provider: impl Into<FFI_TaskContextProvider>,
+    ) -> Self {
+        let task_ctx_provider = task_ctx_provider.into();
+        let private_data =
+            Box::new(LogicalExtensionCodecPrivateData { provider, runtime });
+
+        Self {
+            try_decode_table_provider: try_decode_table_provider_fn_wrapper,
+            try_encode_table_provider: try_encode_table_provider_fn_wrapper,
+            try_decode_udf: try_decode_udf_fn_wrapper,
+            try_encode_udf: try_encode_udf_fn_wrapper,
+            try_decode_udaf: try_decode_udaf_fn_wrapper,
+            try_encode_udaf: try_encode_udaf_fn_wrapper,
+            try_decode_udwf: try_decode_udwf_fn_wrapper,
+            try_encode_udwf: try_encode_udwf_fn_wrapper,
+            task_ctx_provider,
+
+            clone: clone_fn_wrapper,
+            release: release_fn_wrapper,
+            version: crate::version,
+            private_data: Box::into_raw(private_data) as *mut c_void,
+            library_marker_id: crate::get_library_marker_id,
+        }
+    }
+}
+
+/// This wrapper struct exists on the receiver side of the FFI interface, so it has
+/// no guarantees about being able to access the data in `private_data`. Any functions
+/// defined on this struct must only use the stable functions provided in
+/// [`FFI_LogicalExtensionCodec`] to interact with the foreign codec.
+#[derive(Debug)] +pub struct ForeignLogicalExtensionCodec(pub FFI_LogicalExtensionCodec); + +unsafe impl Send for ForeignLogicalExtensionCodec {} +unsafe impl Sync for ForeignLogicalExtensionCodec {} + +impl From<&FFI_LogicalExtensionCodec> for Arc { + fn from(provider: &FFI_LogicalExtensionCodec) -> Self { + if (provider.library_marker_id)() == crate::get_library_marker_id() { + Arc::clone(provider.inner()) + } else { + Arc::new(ForeignLogicalExtensionCodec(provider.clone())) + } + } +} + +impl Clone for FFI_LogicalExtensionCodec { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl LogicalExtensionCodec for ForeignLogicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[LogicalPlan], + _ctx: &TaskContext, + ) -> Result { + not_impl_err!("FFI does not support decode of Extensions") + } + + fn try_encode(&self, _node: &Extension, _buf: &mut Vec) -> Result<()> { + not_impl_err!("FFI does not support encode of Extensions") + } + + fn try_decode_table_provider( + &self, + buf: &[u8], + table_ref: &TableReference, + schema: SchemaRef, + _ctx: &TaskContext, + ) -> Result> { + let table_ref = table_ref.to_string(); + let schema: WrappedSchema = schema.into(); + + let ffi_table_provider = unsafe { + df_result!((self.0.try_decode_table_provider)( + &self.0, + buf.into(), + table_ref.as_str().into(), + schema + )) + }?; + + Ok((&ffi_table_provider).into()) + } + + fn try_encode_table_provider( + &self, + table_ref: &TableReference, + node: Arc, + buf: &mut Vec, + ) -> Result<()> { + let table_ref = table_ref.to_string(); + let node = FFI_TableProvider::new(node, true, None); + + let bytes = df_result!(unsafe { + (self.0.try_encode_table_provider)(&self.0, table_ref.as_str().into(), node) + })?; + + buf.extend(bytes); + + Ok(()) + } + + fn try_decode_file_format( + &self, + _buf: &[u8], + _ctx: &TaskContext, + ) -> Result> { + not_impl_err!("FFI does not support decode_file_format") + } + + fn try_encode_file_format( + &self, + _buf: &mut Vec, + _node: Arc, + ) -> Result<()> { + not_impl_err!("FFI does not support encode_file_format") + } + + fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + let udf = unsafe { + df_result!((self.0.try_decode_udf)(&self.0, name.into(), buf.into())) + }?; + let udf: Arc = (&udf).into(); + + Ok(Arc::new(ScalarUDF::new_from_shared_impl(udf))) + } + + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + let node = FFI_ScalarUDF::from(Arc::new(node.clone())); + let bytes = df_result!(unsafe { (self.0.try_encode_udf)(&self.0, node) })?; + + buf.extend(bytes); + + Ok(()) + } + + fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { + let udaf = unsafe { + df_result!((self.0.try_decode_udaf)(&self.0, name.into(), buf.into())) + }?; + let udaf: Arc = (&udaf).into(); + + Ok(Arc::new(AggregateUDF::new_from_shared_impl(udaf))) + } + + fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { + let node = Arc::new(node.clone()); + let node = FFI_AggregateUDF::from(node); + let bytes = df_result!(unsafe { (self.0.try_encode_udaf)(&self.0, node) })?; + + buf.extend(bytes); + + Ok(()) + } + + fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result> { + let udwf = unsafe { + df_result!((self.0.try_decode_udwf)(&self.0, name.into(), buf.into())) + }?; + let udwf: Arc = (&udwf).into(); + + Ok(Arc::new(WindowUDF::new_from_shared_impl(udwf))) + } + + fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec) -> Result<()> { + let node = Arc::new(node.clone()); + let node = 
FFI_WindowUDF::from(node); + let bytes = df_result!(unsafe { (self.0.try_encode_udwf)(&self.0, node) })?; + + buf.extend(bytes); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::record_batch; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; + use datafusion_catalog::{MemTable, TableProvider}; + use datafusion_common::{exec_err, Result, TableReference}; + use datafusion_datasource::file_format::FileFormatFactory; + use datafusion_execution::TaskContext; + use datafusion_expr::ptr_eq::arc_ptr_eq; + use datafusion_expr::{AggregateUDF, Extension, LogicalPlan, ScalarUDF, WindowUDF}; + use datafusion_functions::math::abs::AbsFunc; + use datafusion_functions_aggregate::sum::Sum; + use datafusion_functions_window::rank::{Rank, RankType}; + use datafusion_proto::logical_plan::LogicalExtensionCodec; + use datafusion_proto::physical_plan::PhysicalExtensionCodec; + + use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec; + use crate::proto::physical_extension_codec::tests::TestExtensionCodec; + + fn create_test_table() -> MemTable { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let rb = record_batch!(("a", Int32, [1, 2, 3])) + .expect("should be able to create a record batch"); + MemTable::try_new(schema, vec![vec![rb]]) + .expect("should be able to create an in memory table") + } + + impl LogicalExtensionCodec for TestExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[LogicalPlan], + _ctx: &TaskContext, + ) -> Result { + unimplemented!() + } + + fn try_encode(&self, _node: &Extension, _buf: &mut Vec) -> Result<()> { + unimplemented!() + } + + fn try_decode_table_provider( + &self, + buf: &[u8], + _table_ref: &TableReference, + schema: SchemaRef, + _ctx: &TaskContext, + ) -> Result> { + if buf[0] != Self::MAGIC_NUMBER { + return exec_err!( + "TestExtensionCodec input buffer does not start with magic number" + ); + } + + if schema != create_test_table().schema() { + return exec_err!("Incorrect test table schema"); + } + + if buf.len() != 2 || buf[1] != Self::MEMTABLE_SERIALIZED { + return exec_err!("TestExtensionCodec unable to decode table provider"); + } + + Ok(Arc::new(create_test_table()) as Arc) + } + + fn try_encode_table_provider( + &self, + _table_ref: &TableReference, + node: Arc, + buf: &mut Vec, + ) -> Result<()> { + buf.push(Self::MAGIC_NUMBER); + + if !node.as_any().is::() { + return exec_err!("TestExtensionCodec only expects MemTable"); + }; + + if node.schema() != create_test_table().schema() { + return exec_err!("Unexpected schema for encoding."); + } + + buf.push(Self::MEMTABLE_SERIALIZED); + + Ok(()) + } + + fn try_decode_file_format( + &self, + _buf: &[u8], + _ctx: &TaskContext, + ) -> Result> { + unimplemented!() + } + + fn try_encode_file_format( + &self, + _buf: &mut Vec, + _node: Arc, + ) -> Result<()> { + unimplemented!() + } + + fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + PhysicalExtensionCodec::try_decode_udf(self, name, buf) + } + + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + PhysicalExtensionCodec::try_encode_udf(self, node, buf) + } + + fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { + PhysicalExtensionCodec::try_decode_udaf(self, name, buf) + } + + fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { + PhysicalExtensionCodec::try_encode_udaf(self, node, buf) + } + + fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result> { + 
PhysicalExtensionCodec::try_decode_udwf(self, name, buf) + } + + fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec) -> Result<()> { + PhysicalExtensionCodec::try_encode_udwf(self, node, buf) + } + } + + #[test] + fn roundtrip_ffi_logical_extension_codec_table_provider() -> Result<()> { + let codec = Arc::new(TestExtensionCodec {}); + let (ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx(); + + let mut ffi_codec = + FFI_LogicalExtensionCodec::new(codec, None, task_ctx_provider); + ffi_codec.library_marker_id = crate::mock_foreign_marker_id; + let foreign_codec: Arc = (&ffi_codec).into(); + + let table = Arc::new(create_test_table()) as Arc; + let mut bytes = Vec::new(); + foreign_codec.try_encode_table_provider(&"my_table".into(), table, &mut bytes)?; + + let returned_table = foreign_codec.try_decode_table_provider( + &bytes, + &"my_table".into(), + create_test_table().schema(), + ctx.task_ctx().as_ref(), + )?; + + assert!(returned_table.as_any().is::()); + + Ok(()) + } + + #[test] + fn roundtrip_ffi_logical_extension_codec_udf() -> Result<()> { + let codec = Arc::new(TestExtensionCodec {}); + let (_ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx(); + + let mut ffi_codec = + FFI_LogicalExtensionCodec::new(codec, None, task_ctx_provider); + ffi_codec.library_marker_id = crate::mock_foreign_marker_id; + let foreign_codec: Arc = (&ffi_codec).into(); + + let udf = Arc::new(ScalarUDF::from(AbsFunc::new())); + let mut bytes = Vec::new(); + foreign_codec.try_encode_udf(udf.as_ref(), &mut bytes)?; + + let returned_udf = foreign_codec.try_decode_udf(udf.name(), &bytes)?; + + assert!(returned_udf.inner().as_any().is::()); + + Ok(()) + } + + #[test] + fn roundtrip_ffi_logical_extension_codec_udaf() -> Result<()> { + let codec = Arc::new(TestExtensionCodec {}); + let (_ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx(); + + let mut ffi_codec = + FFI_LogicalExtensionCodec::new(codec, None, task_ctx_provider); + ffi_codec.library_marker_id = crate::mock_foreign_marker_id; + let foreign_codec: Arc = (&ffi_codec).into(); + + let udf = Arc::new(AggregateUDF::from(Sum::new())); + let mut bytes = Vec::new(); + foreign_codec.try_encode_udaf(udf.as_ref(), &mut bytes)?; + + let returned_udf = foreign_codec.try_decode_udaf(udf.name(), &bytes)?; + + assert!(returned_udf.inner().as_any().is::()); + + Ok(()) + } + + #[test] + fn roundtrip_ffi_logical_extension_codec_udwf() -> Result<()> { + let codec = Arc::new(TestExtensionCodec {}); + let (_ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx(); + + let mut ffi_codec = + FFI_LogicalExtensionCodec::new(codec, None, task_ctx_provider); + ffi_codec.library_marker_id = crate::mock_foreign_marker_id; + let foreign_codec: Arc = (&ffi_codec).into(); + + let udf = Arc::new(WindowUDF::from(Rank::new( + "my_rank".to_owned(), + RankType::Basic, + ))); + let mut bytes = Vec::new(); + foreign_codec.try_encode_udwf(udf.as_ref(), &mut bytes)?; + + let returned_udf = foreign_codec.try_decode_udwf(udf.name(), &bytes)?; + + assert!(returned_udf.inner().as_any().is::()); + + Ok(()) + } + + #[test] + fn ffi_logical_extension_codec_local_bypass() { + let codec = + Arc::new(TestExtensionCodec {}) as Arc; + let (_ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx(); + + let mut ffi_codec = + FFI_LogicalExtensionCodec::new(Arc::clone(&codec), None, task_ctx_provider); + + let codec = codec as Arc; + // Verify local libraries can be downcast to their original + let foreign_codec: Arc = 
(&ffi_codec).into(); + assert!(arc_ptr_eq(&foreign_codec, &codec)); + + // Verify different library markers generate foreign providers + ffi_codec.library_marker_id = crate::mock_foreign_marker_id; + let foreign_codec: Arc = (&ffi_codec).into(); + assert!(!arc_ptr_eq(&foreign_codec, &codec)); + } +} diff --git a/datafusion/ffi/src/proto/mod.rs b/datafusion/ffi/src/proto/mod.rs new file mode 100644 index 000000000000..ae76027ecb64 --- /dev/null +++ b/datafusion/ffi/src/proto/mod.rs @@ -0,0 +1,19 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod logical_extension_codec; +pub mod physical_extension_codec; diff --git a/datafusion/ffi/src/proto/physical_extension_codec.rs b/datafusion/ffi/src/proto/physical_extension_codec.rs new file mode 100644 index 000000000000..8f3667400ee4 --- /dev/null +++ b/datafusion/ffi/src/proto/physical_extension_codec.rs @@ -0,0 +1,679 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::ffi::c_void; +use std::sync::Arc; + +use abi_stable::std_types::{RResult, RSlice, RStr, RVec}; +use abi_stable::StableAbi; +use datafusion_common::error::Result; +use datafusion_execution::TaskContext; +use datafusion_expr::{ + AggregateUDF, AggregateUDFImpl, ScalarUDF, ScalarUDFImpl, WindowUDF, WindowUDFImpl, +}; +use datafusion_physical_plan::ExecutionPlan; +use datafusion_proto::physical_plan::PhysicalExtensionCodec; +use tokio::runtime::Handle; + +use crate::execution::FFI_TaskContextProvider; +use crate::execution_plan::FFI_ExecutionPlan; +use crate::udaf::FFI_AggregateUDF; +use crate::udf::FFI_ScalarUDF; +use crate::udwf::FFI_WindowUDF; +use crate::util::FFIResult; +use crate::{df_result, rresult_return}; + +/// A stable struct for sharing [`PhysicalExtensionCodec`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_PhysicalExtensionCodec { + /// Decode bytes into an execution plan. 
+    try_decode: unsafe extern "C" fn(
+        &Self,
+        buf: RSlice<u8>,
+        inputs: RVec<FFI_ExecutionPlan>,
+    ) -> FFIResult<FFI_ExecutionPlan>,
+
+    /// Encode an execution plan into bytes.
+    try_encode:
+        unsafe extern "C" fn(&Self, node: FFI_ExecutionPlan) -> FFIResult<RVec<u8>>,
+
+    /// Decode bytes into a user defined scalar function.
+    try_decode_udf: unsafe extern "C" fn(
+        &Self,
+        name: RStr,
+        buf: RSlice<u8>,
+    ) -> FFIResult<FFI_ScalarUDF>,
+
+    /// Encode a user defined scalar function into bytes.
+    try_encode_udf:
+        unsafe extern "C" fn(&Self, node: FFI_ScalarUDF) -> FFIResult<RVec<u8>>,
+
+    /// Decode bytes into a user defined aggregate function.
+    try_decode_udaf: unsafe extern "C" fn(
+        &Self,
+        name: RStr,
+        buf: RSlice<u8>,
+    ) -> FFIResult<FFI_AggregateUDF>,
+
+    /// Encode a user defined aggregate function into bytes.
+    try_encode_udaf:
+        unsafe extern "C" fn(&Self, node: FFI_AggregateUDF) -> FFIResult<RVec<u8>>,
+
+    /// Decode bytes into a user defined window function.
+    try_decode_udwf: unsafe extern "C" fn(
+        &Self,
+        name: RStr,
+        buf: RSlice<u8>,
+    ) -> FFIResult<FFI_WindowUDF>,
+
+    /// Encode a user defined window function into bytes.
+    try_encode_udwf:
+        unsafe extern "C" fn(&Self, node: FFI_WindowUDF) -> FFIResult<RVec<u8>>,
+
+    /// Access the current [`TaskContext`].
+    task_ctx_provider: FFI_TaskContextProvider,
+
+    /// Used to create a clone on the provider side of the codec. This should
+    /// only need to be called by the receiver of the codec.
+    pub clone: unsafe extern "C" fn(codec: &Self) -> Self,
+
+    /// Release the memory of the private data when it is no longer being used.
+    pub release: unsafe extern "C" fn(arg: &mut Self),
+
+    /// Return the major DataFusion version number of this provider.
+    pub version: unsafe extern "C" fn() -> u64,
+
+    /// Internal data. This is only to be accessed by the provider of the codec.
+    /// A [`ForeignPhysicalExtensionCodec`] should never attempt to access this data.
+    pub private_data: *mut c_void,
+
+    /// Utility to identify when FFI objects are accessed locally through
+    /// the foreign interface.
+ pub library_marker_id: extern "C" fn() -> usize, +} + +unsafe impl Send for FFI_PhysicalExtensionCodec {} +unsafe impl Sync for FFI_PhysicalExtensionCodec {} + +struct PhysicalExtensionCodecPrivateData { + provider: Arc, + runtime: Option, +} + +impl FFI_PhysicalExtensionCodec { + fn inner(&self) -> &Arc { + let private_data = self.private_data as *const PhysicalExtensionCodecPrivateData; + unsafe { &(*private_data).provider } + } + + fn runtime(&self) -> &Option { + let private_data = self.private_data as *const PhysicalExtensionCodecPrivateData; + unsafe { &(*private_data).runtime } + } +} + +unsafe extern "C" fn try_decode_fn_wrapper( + codec: &FFI_PhysicalExtensionCodec, + buf: RSlice, + inputs: RVec, +) -> FFIResult { + let task_ctx: Arc = + rresult_return!((&codec.task_ctx_provider).try_into()); + let codec = codec.inner(); + let inputs = inputs + .into_iter() + .map(|plan| >::try_from(&plan)) + .collect::>>(); + let inputs = rresult_return!(inputs); + + let plan = + rresult_return!(codec.try_decode(buf.as_ref(), &inputs, task_ctx.as_ref())); + + RResult::ROk(FFI_ExecutionPlan::new(plan, task_ctx, None)) +} + +unsafe extern "C" fn try_encode_fn_wrapper( + codec: &FFI_PhysicalExtensionCodec, + node: FFI_ExecutionPlan, +) -> FFIResult> { + let codec = codec.inner(); + + let plan: Arc = rresult_return!((&node).try_into()); + + let mut bytes = Vec::new(); + rresult_return!(codec.try_encode(plan, &mut bytes)); + + RResult::ROk(bytes.into()) +} + +unsafe extern "C" fn try_decode_udf_fn_wrapper( + codec: &FFI_PhysicalExtensionCodec, + name: RStr, + buf: RSlice, +) -> FFIResult { + let codec = codec.inner(); + + let udf = rresult_return!(codec.try_decode_udf(name.as_str(), buf.as_ref())); + let udf = FFI_ScalarUDF::from(udf); + + RResult::ROk(udf) +} + +unsafe extern "C" fn try_encode_udf_fn_wrapper( + codec: &FFI_PhysicalExtensionCodec, + node: FFI_ScalarUDF, +) -> FFIResult> { + let codec = codec.inner(); + let node: Arc = (&node).into(); + let node = ScalarUDF::new_from_shared_impl(node); + + let mut bytes = Vec::new(); + rresult_return!(codec.try_encode_udf(&node, &mut bytes)); + + RResult::ROk(bytes.into()) +} + +unsafe extern "C" fn try_decode_udaf_fn_wrapper( + codec: &FFI_PhysicalExtensionCodec, + name: RStr, + buf: RSlice, +) -> FFIResult { + let codec_inner = codec.inner(); + let udaf = rresult_return!(codec_inner.try_decode_udaf(name.into(), buf.as_ref())); + let udaf = FFI_AggregateUDF::from(udaf); + + RResult::ROk(udaf) +} + +unsafe extern "C" fn try_encode_udaf_fn_wrapper( + codec: &FFI_PhysicalExtensionCodec, + node: FFI_AggregateUDF, +) -> FFIResult> { + let codec = codec.inner(); + let udaf: Arc = (&node).into(); + let udaf = AggregateUDF::new_from_shared_impl(udaf); + + let mut bytes = Vec::new(); + rresult_return!(codec.try_encode_udaf(&udaf, &mut bytes)); + + RResult::ROk(bytes.into()) +} + +unsafe extern "C" fn try_decode_udwf_fn_wrapper( + codec: &FFI_PhysicalExtensionCodec, + name: RStr, + buf: RSlice, +) -> FFIResult { + let codec = codec.inner(); + let udwf = rresult_return!(codec.try_decode_udwf(name.into(), buf.as_ref())); + let udwf = FFI_WindowUDF::from(udwf); + + RResult::ROk(udwf) +} + +unsafe extern "C" fn try_encode_udwf_fn_wrapper( + codec: &FFI_PhysicalExtensionCodec, + node: FFI_WindowUDF, +) -> FFIResult> { + let codec = codec.inner(); + let udwf: Arc = (&node).into(); + let udwf = WindowUDF::new_from_shared_impl(udwf); + + let mut bytes = Vec::new(); + rresult_return!(codec.try_encode_udwf(&udwf, &mut bytes)); + + RResult::ROk(bytes.into()) +} + 
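+// Usage sketch (illustrative only, not part of the test suite; `MyCodec` is a
+// hypothetical type implementing `PhysicalExtensionCodec`): the provider side
+// wraps its codec and the receiver reconstructs a usable one, given an
+// `FFI_TaskContextProvider` for the current session:
+//
+//     let ffi_codec =
+//         FFI_PhysicalExtensionCodec::new(Arc::new(MyCodec), None, task_ctx_provider);
+//     let codec: Arc<dyn PhysicalExtensionCodec> = (&ffi_codec).into();
+//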
+unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_PhysicalExtensionCodec) { + let private_data = + Box::from_raw(provider.private_data as *mut PhysicalExtensionCodecPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper( + codec: &FFI_PhysicalExtensionCodec, +) -> FFI_PhysicalExtensionCodec { + let old_codec = Arc::clone(codec.inner()); + let runtime = codec.runtime().clone(); + + FFI_PhysicalExtensionCodec::new(old_codec, runtime, codec.task_ctx_provider.clone()) +} + +impl Drop for FFI_PhysicalExtensionCodec { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +impl FFI_PhysicalExtensionCodec { + /// Creates a new [`FFI_PhysicalExtensionCodec`]. + pub fn new( + provider: Arc, + runtime: Option, + task_ctx_provider: impl Into, + ) -> Self { + let task_ctx_provider = task_ctx_provider.into(); + let private_data = + Box::new(PhysicalExtensionCodecPrivateData { provider, runtime }); + + Self { + try_decode: try_decode_fn_wrapper, + try_encode: try_encode_fn_wrapper, + try_decode_udf: try_decode_udf_fn_wrapper, + try_encode_udf: try_encode_udf_fn_wrapper, + try_decode_udaf: try_decode_udaf_fn_wrapper, + try_encode_udaf: try_encode_udaf_fn_wrapper, + try_decode_udwf: try_decode_udwf_fn_wrapper, + try_encode_udwf: try_encode_udwf_fn_wrapper, + task_ctx_provider, + + clone: clone_fn_wrapper, + release: release_fn_wrapper, + version: crate::version, + private_data: Box::into_raw(private_data) as *mut c_void, + library_marker_id: crate::get_library_marker_id, + } + } +} + +/// This wrapper struct exists on the receiver side of the FFI interface, so it has +/// no guarantees about being able to access the data in `private_data`. Any functions +/// defined on this struct must only use the stable functions provided in +/// FFI_PhysicalExtensionCodec to interact with the foreign table provider. 
+#[derive(Debug)] +pub struct ForeignPhysicalExtensionCodec(pub FFI_PhysicalExtensionCodec); + +unsafe impl Send for ForeignPhysicalExtensionCodec {} +unsafe impl Sync for ForeignPhysicalExtensionCodec {} + +impl From<&FFI_PhysicalExtensionCodec> for Arc { + fn from(provider: &FFI_PhysicalExtensionCodec) -> Self { + if (provider.library_marker_id)() == crate::get_library_marker_id() { + Arc::clone(provider.inner()) + } else { + Arc::new(ForeignPhysicalExtensionCodec(provider.clone())) + } + } +} + +impl Clone for FFI_PhysicalExtensionCodec { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl PhysicalExtensionCodec for ForeignPhysicalExtensionCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[Arc], + _ctx: &TaskContext, + ) -> Result> { + let task_ctx = (&self.0.task_ctx_provider).try_into()?; + let inputs = inputs + .iter() + .map(|plan| { + FFI_ExecutionPlan::new(Arc::clone(plan), Arc::clone(&task_ctx), None) + }) + .collect(); + + let plan = + df_result!(unsafe { (self.0.try_decode)(&self.0, buf.into(), inputs) })?; + let plan: Arc = (&plan).try_into()?; + + Ok(plan) + } + + fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { + let task_ctx = (&self.0.task_ctx_provider).try_into()?; + let plan = FFI_ExecutionPlan::new(node, task_ctx, None); + let bytes = df_result!(unsafe { (self.0.try_encode)(&self.0, plan) })?; + + buf.extend(bytes); + Ok(()) + } + + fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + let udf = unsafe { + df_result!((self.0.try_decode_udf)(&self.0, name.into(), buf.into())) + }?; + let udf: Arc = (&udf).into(); + + Ok(Arc::new(ScalarUDF::new_from_shared_impl(udf))) + } + + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + let node = FFI_ScalarUDF::from(Arc::new(node.clone())); + let bytes = df_result!(unsafe { (self.0.try_encode_udf)(&self.0, node) })?; + + buf.extend(bytes); + + Ok(()) + } + + fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { + let udaf = unsafe { + df_result!((self.0.try_decode_udaf)(&self.0, name.into(), buf.into())) + }?; + let udaf: Arc = (&udaf).into(); + + Ok(Arc::new(AggregateUDF::new_from_shared_impl(udaf))) + } + + fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { + let node = Arc::new(node.clone()); + let node = FFI_AggregateUDF::from(node); + let bytes = df_result!(unsafe { (self.0.try_encode_udaf)(&self.0, node) })?; + + buf.extend(bytes); + + Ok(()) + } + + fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result> { + let udwf = unsafe { + df_result!((self.0.try_decode_udwf)(&self.0, name.into(), buf.into())) + }?; + let udwf: Arc = (&udwf).into(); + + Ok(Arc::new(WindowUDF::new_from_shared_impl(udwf))) + } + + fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec) -> Result<()> { + let node = Arc::new(node.clone()); + let node = FFI_WindowUDF::from(node); + let bytes = df_result!(unsafe { (self.0.try_encode_udwf)(&self.0, node) })?; + + buf.extend(bytes); + + Ok(()) + } +} + +#[cfg(test)] +pub(crate) mod tests { + use std::sync::Arc; + + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::{exec_err, Result}; + use datafusion_execution::TaskContext; + use datafusion_expr::ptr_eq::arc_ptr_eq; + use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF, WindowUDFImpl}; + use datafusion_functions::math::abs::AbsFunc; + use datafusion_functions_aggregate::sum::Sum; + use datafusion_functions_window::rank::{Rank, RankType}; + use datafusion_physical_plan::ExecutionPlan; + use 
datafusion_proto::physical_plan::PhysicalExtensionCodec; + + use crate::execution_plan::tests::EmptyExec; + use crate::proto::physical_extension_codec::FFI_PhysicalExtensionCodec; + + #[derive(Debug)] + pub(crate) struct TestExtensionCodec; + + impl TestExtensionCodec { + pub(crate) const MAGIC_NUMBER: u8 = 127; + pub(crate) const EMPTY_EXEC_SERIALIZED: u8 = 1; + pub(crate) const ABS_FUNC_SERIALIZED: u8 = 2; + pub(crate) const SUM_UDAF_SERIALIZED: u8 = 3; + pub(crate) const RANK_UDWF_SERIALIZED: u8 = 4; + pub(crate) const MEMTABLE_SERIALIZED: u8 = 5; + } + + impl PhysicalExtensionCodec for TestExtensionCodec { + fn try_decode( + &self, + buf: &[u8], + _inputs: &[Arc], + _ctx: &TaskContext, + ) -> Result> { + if buf[0] != Self::MAGIC_NUMBER { + return exec_err!( + "TestExtensionCodec input buffer does not start with magic number" + ); + } + + if buf.len() != 2 || buf[1] != Self::EMPTY_EXEC_SERIALIZED { + return exec_err!("TestExtensionCodec unable to decode execution plan"); + } + + Ok(create_test_exec()) + } + + fn try_encode( + &self, + node: Arc, + buf: &mut Vec, + ) -> Result<()> { + buf.push(Self::MAGIC_NUMBER); + + let Some(_) = node.as_any().downcast_ref::() else { + return exec_err!("TestExtensionCodec only expects EmptyExec"); + }; + + buf.push(Self::EMPTY_EXEC_SERIALIZED); + + Ok(()) + } + + fn try_decode_udf(&self, _name: &str, buf: &[u8]) -> Result> { + if buf[0] != Self::MAGIC_NUMBER { + return exec_err!( + "TestExtensionCodec input buffer does not start with magic number" + ); + } + + if buf.len() != 2 || buf[1] != Self::ABS_FUNC_SERIALIZED { + return exec_err!("TestExtensionCodec unable to decode udf"); + } + + Ok(Arc::new(ScalarUDF::from(AbsFunc::new()))) + } + + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + buf.push(Self::MAGIC_NUMBER); + + let udf = node.inner(); + if !udf.as_any().is::() { + return exec_err!("TestExtensionCodec only expects Abs UDF"); + }; + + buf.push(Self::ABS_FUNC_SERIALIZED); + + Ok(()) + } + + fn try_decode_udaf(&self, _name: &str, buf: &[u8]) -> Result> { + if buf[0] != Self::MAGIC_NUMBER { + return exec_err!( + "TestExtensionCodec input buffer does not start with magic number" + ); + } + + if buf.len() != 2 || buf[1] != Self::SUM_UDAF_SERIALIZED { + return exec_err!("TestExtensionCodec unable to decode udaf"); + } + + Ok(Arc::new(AggregateUDF::from(Sum::new()))) + } + + fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { + buf.push(Self::MAGIC_NUMBER); + + let udf = node.inner(); + let Some(_udf) = udf.as_any().downcast_ref::() else { + return exec_err!("TestExtensionCodec only expects Sum UDAF"); + }; + + buf.push(Self::SUM_UDAF_SERIALIZED); + + Ok(()) + } + + fn try_decode_udwf(&self, _name: &str, buf: &[u8]) -> Result> { + if buf[0] != Self::MAGIC_NUMBER { + return exec_err!( + "TestExtensionCodec input buffer does not start with magic number" + ); + } + + if buf.len() != 2 || buf[1] != Self::RANK_UDWF_SERIALIZED { + return exec_err!("TestExtensionCodec unable to decode udwf"); + } + + Ok(Arc::new(WindowUDF::from(Rank::new( + "my_rank".to_owned(), + RankType::Basic, + )))) + } + + fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec) -> Result<()> { + buf.push(Self::MAGIC_NUMBER); + + let udf = node.inner(); + let Some(udf) = udf.as_any().downcast_ref::() else { + return exec_err!("TestExtensionCodec only expects Rank UDWF"); + }; + + if udf.name() != "my_rank" { + return exec_err!("TestExtensionCodec only expects my_rank UDWF name"); + } + + buf.push(Self::RANK_UDWF_SERIALIZED); + + 
Ok(()) + } + } + + fn create_test_exec() -> Arc { + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + Arc::new(EmptyExec::new(schema)) as Arc + } + + #[test] + fn roundtrip_ffi_physical_extension_codec_exec_plan() -> Result<()> { + let codec = Arc::new(TestExtensionCodec {}); + let (ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx(); + + let mut ffi_codec = + FFI_PhysicalExtensionCodec::new(codec, None, task_ctx_provider); + ffi_codec.library_marker_id = crate::mock_foreign_marker_id; + let foreign_codec: Arc = (&ffi_codec).into(); + + let exec = create_test_exec(); + let input_execs = [create_test_exec()]; + let mut bytes = Vec::new(); + foreign_codec.try_encode(Arc::clone(&exec), &mut bytes)?; + + let returned_exec = + foreign_codec.try_decode(&bytes, &input_execs, ctx.task_ctx().as_ref())?; + + assert!(returned_exec.as_any().is::()); + + Ok(()) + } + + #[test] + fn roundtrip_ffi_physical_extension_codec_udf() -> Result<()> { + let codec = Arc::new(TestExtensionCodec {}); + let (_ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx(); + + let mut ffi_codec = + FFI_PhysicalExtensionCodec::new(codec, None, task_ctx_provider); + ffi_codec.library_marker_id = crate::mock_foreign_marker_id; + let foreign_codec: Arc = (&ffi_codec).into(); + + let udf = Arc::new(ScalarUDF::from(AbsFunc::new())); + let mut bytes = Vec::new(); + foreign_codec.try_encode_udf(udf.as_ref(), &mut bytes)?; + + let returned_udf = foreign_codec.try_decode_udf(udf.name(), &bytes)?; + + assert!(returned_udf.inner().as_any().is::()); + + Ok(()) + } + + #[test] + fn roundtrip_ffi_physical_extension_codec_udaf() -> Result<()> { + let codec = Arc::new(TestExtensionCodec {}); + let (_ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx(); + + let mut ffi_codec = + FFI_PhysicalExtensionCodec::new(codec, None, task_ctx_provider); + ffi_codec.library_marker_id = crate::mock_foreign_marker_id; + let foreign_codec: Arc = (&ffi_codec).into(); + + let udf = Arc::new(AggregateUDF::from(Sum::new())); + let mut bytes = Vec::new(); + foreign_codec.try_encode_udaf(udf.as_ref(), &mut bytes)?; + + let returned_udf = foreign_codec.try_decode_udaf(udf.name(), &bytes)?; + + assert!(returned_udf.inner().as_any().is::()); + + Ok(()) + } + + #[test] + fn roundtrip_ffi_physical_extension_codec_udwf() -> Result<()> { + let codec = Arc::new(TestExtensionCodec {}); + let (_ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx(); + + let mut ffi_codec = + FFI_PhysicalExtensionCodec::new(codec, None, task_ctx_provider); + ffi_codec.library_marker_id = crate::mock_foreign_marker_id; + let foreign_codec: Arc = (&ffi_codec).into(); + + let udf = Arc::new(WindowUDF::from(Rank::new( + "my_rank".to_owned(), + RankType::Basic, + ))); + let mut bytes = Vec::new(); + foreign_codec.try_encode_udwf(udf.as_ref(), &mut bytes)?; + + let returned_udf = foreign_codec.try_decode_udwf(udf.name(), &bytes)?; + + assert!(returned_udf.inner().as_any().is::()); + + Ok(()) + } + + #[test] + fn ffi_physical_extension_codec_local_bypass() { + let codec = + Arc::new(TestExtensionCodec {}) as Arc; + let (_ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx(); + + let mut ffi_codec = + FFI_PhysicalExtensionCodec::new(Arc::clone(&codec), None, task_ctx_provider); + + let codec = codec as Arc; + // Verify local libraries can be downcast to their original + let foreign_codec: Arc = (&ffi_codec).into(); + assert!(arc_ptr_eq(&foreign_codec, &codec)); + + // Verify different 
library markers generate foreign providers + ffi_codec.library_marker_id = crate::mock_foreign_marker_id; + let foreign_codec: Arc = (&ffi_codec).into(); + assert!(!arc_ptr_eq(&foreign_codec, &codec)); + } +} diff --git a/datafusion/ffi/src/util.rs b/datafusion/ffi/src/util.rs index 640da7c04292..330286e3d562 100644 --- a/datafusion/ffi/src/util.rs +++ b/datafusion/ffi/src/util.rs @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. -use crate::arrow_wrappers::WrappedSchema; +use std::sync::Arc; + use abi_stable::std_types::{RResult, RString, RVec}; -use arrow::datatypes::Field; -use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; +use arrow::datatypes::{DataType, Field}; +use arrow::ffi::FFI_ArrowSchema; use arrow_schema::FieldRef; -use std::sync::Arc; + +use crate::arrow_wrappers::WrappedSchema; /// Convenience type for results passed through the FFI boundary. Since the /// `DataFusionError` enum is complex and little value is gained from creating @@ -124,10 +126,25 @@ pub fn rvec_wrapped_to_vec_datatype( } #[cfg(test)] -mod tests { - use crate::util::FFIResult; +pub(crate) mod tests { + use std::sync::Arc; + use abi_stable::std_types::{RResult, RString}; use datafusion::error::DataFusionError; + use datafusion::prelude::SessionContext; + use datafusion_execution::TaskContextProvider; + + use crate::execution::FFI_TaskContextProvider; + use crate::util::FFIResult; + + pub(crate) fn test_session_and_ctx() -> (Arc, FFI_TaskContextProvider) + { + let ctx = Arc::new(SessionContext::new()); + let task_ctx_provider = Arc::clone(&ctx) as Arc; + let task_ctx_provider = FFI_TaskContextProvider::from(&task_ctx_provider); + + (ctx, task_ctx_provider) + } fn wrap_result(result: Result) -> FFIResult { RResult::ROk(rresult_return!(result))