Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ demo.launch()

var gr Module

func UpdateExamples(country string) Object {
func updateExamples(country string) Object {
println("country:", country)
if country == "USA" {
return gr.Call("Dataset", KwArgs{
Expand All @@ -280,10 +280,6 @@ func main() {
Initialize()
defer Finalize()
gr = ImportModule("gradio")
fn := CreateFunc("update_examples", UpdateExamples,
"(country, /)\n--\n\nUpdate examples based on country")
// Would be (in the future):
// fn := FuncOf(UpdateExamples)
demo := With(gr.Call("Blocks"), func(v Object) {
dropdown := gr.Call("Dropdown", KwArgs{
"label": "Country",
Expand All @@ -293,7 +289,7 @@ func main() {
textbox := gr.Call("Textbox")
examples := gr.Call("Examples", [][]string{{"Chicago"}, {"Little Rock"}, {"San Francisco"}}, textbox)
dataset := examples.Attr("dataset")
dropdown.Call("change", fn, dropdown, dataset)
dropdown.Call("change", updateExamples, dropdown, dataset)
})
demo.Call("launch")
}
Expand Down
2 changes: 2 additions & 0 deletions convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ func From(from any) Object {
return fromMap(vv).Object
case reflect.Struct:
return fromStruct(vv)
case reflect.Func:
return FuncOf(vv.Interface()).Object
}
panic(fmt.Errorf("unsupported type for Python: %T\n", v))
}
Expand Down
55 changes: 55 additions & 0 deletions convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,61 @@ func TestFromSpecialCases(t *testing.T) {
t.Errorf("Object was not independent, got %d after modifying original", got)
}
}()

func() {
// Test From with functions
add := func(a, b int) int { return a + b }
obj := From(add)

// Verify it's a function type
if !obj.IsFunc() {
t.Error("From(func) did not create Function object")
}

fn := obj.AsFunc()

// Test function call
result := fn.Call(5, 3)

if !result.IsLong() {
t.Error("Function call result is not a Long")
}
if got := result.AsLong().Int64(); got != 8 {
t.Errorf("Function call = %d, want 8", got)
}
}()

func() {
// Test From with function that returns multiple values
divMod := func(a, b int) (int, int) {
return a / b, a % b
}
obj := From(divMod)
if !obj.IsFunc() {
t.Error("From(func) did not create Function object")
}

fn := obj.AsFunc()

result := fn.Call(7, 3)

// Result should be a tuple with two values
if !result.IsTuple() {
t.Error("Multiple return value function did not return a Tuple")
}

tuple := result.AsTuple()
if tuple.Len() != 2 {
t.Errorf("Expected tuple of length 2, got %d", tuple.Len())
}

quotient := tuple.Get(0).AsLong().Int64()
remainder := tuple.Get(1).AsLong().Int64()

if quotient != 2 || remainder != 1 {
t.Errorf("Got (%d, %d), want (2, 1)", quotient, remainder)
}
}()
}

func TestToValueWithCustomType(t *testing.T) {
Expand Down
8 changes: 2 additions & 6 deletions demo/gradio/gradio.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ demo.launch()

var gr Module

func UpdateExamples(country string) Object {
func updateExamples(country string) Object {
println("country:", country)
if country == "USA" {
return gr.Call("Dataset", KwArgs{
Expand All @@ -49,10 +49,6 @@ func main() {
Initialize()
defer Finalize()
gr = ImportModule("gradio")
fn := CreateFunc("update_examples", UpdateExamples,
"(country, /)\n--\n\nUpdate examples based on country")
// Would be (in the future):
// fn := FuncOf(UpdateExamples)
demo := With(gr.Call("Blocks"), func(v Object) {
dropdown := gr.Call("Dropdown", KwArgs{
"label": "Country",
Expand All @@ -62,7 +58,7 @@ func main() {
textbox := gr.Call("Textbox")
examples := gr.Call("Examples", [][]string{{"Chicago"}, {"Little Rock"}, {"San Francisco"}}, textbox)
dataset := examples.Attr("dataset")
dropdown.Call("change", fn, dropdown, dataset)
dropdown.Call("change", updateExamples, dropdown, dataset)
})
demo.Call("launch")
}
156 changes: 150 additions & 6 deletions extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ import (
"unsafe"
)

func FuncOf(fn any) Func {
return CreateFunc("", fn, "")
}

func CreateFunc(name string, fn any, doc string) Func {
m := MainModule()
return m.AddMethod(name, fn, doc)
Expand Down Expand Up @@ -583,7 +587,24 @@ func (m Module) AddMethod(name string, fn any, doc string) Func {
}
}
name = goNameToPythonName(name)
doc = name + doc

hasRecv := false
if t.NumIn() > 0 {
firstParam := t.In(0)
if firstParam.Kind() == reflect.Ptr || firstParam.Kind() == reflect.Interface {
hasRecv = true
}
}

kwargsType := reflect.TypeOf(KwArgs{})
hasKwArgs := false
if t.NumIn() > 0 && t.In(t.NumIn()-1) == kwargsType {
hasKwArgs = true
}

sig := genSig(fn, hasRecv)
fullDoc := name + sig + "\n--\n\n" + doc
cDoc := C.CString(fullDoc)

maps := getGlobalData()
meta, ok := maps.typeMetas[m.obj]
Expand All @@ -596,23 +617,25 @@ func (m Module) AddMethod(name string, fn any, doc string) Func {

methodId := uint(len(meta.methods))

methodPtr := C.wrapperMethods[methodId]
cName := C.CString(name)
cDoc := C.CString(doc)

def := (*C.PyMethodDef)(C.malloc(C.size_t(unsafe.Sizeof(C.PyMethodDef{}))))
def.ml_name = cName
def.ml_meth = C.PyCFunction(methodPtr)
def.ml_meth = C.PyCFunction(C.wrapperMethods[methodId])
def.ml_flags = C.METH_VARARGS
if hasKwArgs {
def.ml_flags |= C.METH_KEYWORDS
def.ml_meth = C.PyCFunction(C.wrapperMethodsWithKwargs[methodId])
}
def.ml_doc = cDoc

methodMeta := &slotMeta{
name: name,
methodName: name,
fn: fn,
typ: t,
doc: doc,
hasRecv: false,
doc: fullDoc,
hasRecv: hasRecv,
def: def,
}
meta.methods[methodId] = methodMeta
Expand Down Expand Up @@ -665,3 +688,124 @@ func FetchError() error {

return fmt.Errorf("python error: %s", C.GoString(cstr))
}

func genSig(fn any, hasRecv bool) string {
t := reflect.TypeOf(fn)
if t.Kind() != reflect.Func {
panic("genSig: fn must be a function")
}

var args []string
startIdx := 0
if hasRecv {
startIdx = 1 // skip receiver
}

kwargsType := reflect.TypeOf(KwArgs{})
hasKwArgs := false
lastParamIdx := t.NumIn() - 1
if lastParamIdx >= startIdx && t.In(lastParamIdx) == kwargsType {
hasKwArgs = true
lastParamIdx-- // don't include KwArgs in regular parameters
}

for i := startIdx; i <= lastParamIdx; i++ {
paramName := fmt.Sprintf("arg%d", i-startIdx)
args = append(args, paramName)
}

// add "/" separator only if there are parameters
if len(args) > 0 {
args = append(args, "/")
}

// add "**kwargs" if there are keyword arguments
if hasKwArgs {
args = append(args, "**kwargs")
}

return fmt.Sprintf("(%s)", strings.Join(args, ", "))
}

//export wrapperMethodWithKwargs
func wrapperMethodWithKwargs(self, args, kwargs *C.PyObject, methodId C.int) *C.PyObject {
key := self
if C.isModule(self) == 0 {
key = (*C.PyObject)(unsafe.Pointer(self.ob_type))
}

maps := getGlobalData()
typeMeta, ok := maps.typeMetas[key]
check(ok, fmt.Sprintf("type %v not registered", FromPy(key)))

methodMeta := typeMeta.methods[uint(methodId)]
methodType := methodMeta.typ
hasReceiver := methodMeta.hasRecv

expectedArgs := methodType.NumIn()
if hasReceiver {
expectedArgs-- // skip receiver
}
expectedArgs-- // skip KwArgs

argc := C.PyTuple_Size(args)
if int(argc) != expectedArgs {
SetTypeError(fmt.Errorf("method %s expects %d arguments, got %d", methodMeta.name, expectedArgs, argc))
return nil
}

goArgs := make([]reflect.Value, methodType.NumIn())
argIndex := 0

if hasReceiver {
wrapper := (*wrapperType)(unsafe.Pointer(self))
receiverType := methodType.In(0)
var recv reflect.Value

if receiverType.Kind() == reflect.Ptr {
recv = reflect.ValueOf(wrapper.goObj)
} else {
recv = reflect.ValueOf(wrapper.goObj).Elem()
}

goArgs[0] = recv
argIndex = 1
}

for i := 0; i < int(argc); i++ {
arg := C.PySequence_GetItem(args, C.Py_ssize_t(i))
argType := methodType.In(i + argIndex)
argPy := FromPy(arg)
goValue := reflect.New(argType).Elem()
if !ToValue(argPy, goValue) {
SetTypeError(fmt.Errorf("failed to convert argument %v to %v", argPy, argType))
return nil
}
goArgs[i+argIndex] = goValue
}

kwargsValue := make(KwArgs)
if kwargs != nil {
dict := newDict(kwargs)
dict.Items()(func(key, value Object) bool {
kwargsValue[key.String()] = value
return true
})
}
goArgs[len(goArgs)-1] = reflect.ValueOf(kwargsValue)

results := reflect.ValueOf(methodMeta.fn).Call(goArgs)

if len(results) == 0 {
return None().cpyObj()
}
if len(results) == 1 {
return From(results[0].Interface()).cpyObj()
}

tuple := MakeTupleWithLen(len(results))
for i := range results {
tuple.Set(i, From(results[i].Interface()))
}
return tuple.cpyObj()
}
Loading