diff --git a/experimental/ast/zero_test.go b/experimental/ast/zero_test.go index f6e8c837..13c8fcbc 100644 --- a/experimental/ast/zero_test.go +++ b/experimental/ast/zero_test.go @@ -104,7 +104,7 @@ func testZero[Node source.Spanner](t *testing.T) { r = m.Call(nil)[0] } - assert.Zero(t, r.Interface(), "non-zero return #%d %#v of %T.%s", i, r, z, m.Name) + assert.Zero(t, r.Interface(), "non-zero return %#d %#v of %T.%s", i, r, z, m.Name) } } }) diff --git a/experimental/ir/export_test.go b/experimental/ir/export_test.go index 522a59c5..83d294a1 100644 --- a/experimental/ir/export_test.go +++ b/experimental/ir/export_test.go @@ -36,5 +36,8 @@ func GetImports(f *File) *Imports { } func (s Symbol) RawData() arena.Untyped { + if s.IsZero() { + return arena.Nil() + } return s.Raw().data } diff --git a/experimental/ir/ir_file.go b/experimental/ir/ir_file.go index 72979694..1abb004b 100644 --- a/experimental/ir/ir_file.go +++ b/experimental/ir/ir_file.go @@ -148,6 +148,9 @@ func (f *File) InternedPath() intern.ID { // google/protobuf/descriptor.proto, which is given special treatment in // the language. func (f *File) IsDescriptorProto() bool { + if f == nil { + return false + } return f.InternedPath() == f.session.builtins.DescriptorFile } @@ -176,7 +179,11 @@ func (f *File) InternedPackage() intern.ID { // Imports returns an indexer over the imports declared in this file. func (f *File) Imports() seq.Indexer[Import] { - return f.imports.Directs() + var imp imports + if f != nil { + imp = f.imports + } + return imp.Directs() } // TransitiveImports returns an indexer over the transitive imports for this @@ -184,7 +191,11 @@ func (f *File) Imports() seq.Indexer[Import] { // // This function does not report whether those imports are weak or not. func (f *File) TransitiveImports() seq.Indexer[Import] { - return f.imports.Transitive() + var imp imports + if f != nil { + imp = f.imports + } + return imp.Transitive() } // ImportFor returns import metadata for a given file, if this file imports it. @@ -199,8 +210,12 @@ func (f *File) ImportFor(that *File) Import { // Types returns the top level types of this file. func (f *File) Types() seq.Indexer[Type] { + var types []id.ID[Type] + if f != nil { + types = f.types[:f.topLevelTypesEnd] + } return seq.NewFixedSlice( - f.types[:f.topLevelTypesEnd], + types, func(_ int, p id.ID[Type]) Type { return id.Wrap(f, p) }, @@ -209,8 +224,12 @@ func (f *File) Types() seq.Indexer[Type] { // AllTypes returns all types defined in this file. func (f *File) AllTypes() seq.Indexer[Type] { + var types []id.ID[Type] + if f != nil { + types = f.types + } return seq.NewFixedSlice( - f.types, + types, func(_ int, p id.ID[Type]) Type { return id.Wrap(f, p) }, @@ -220,8 +239,12 @@ func (f *File) AllTypes() seq.Indexer[Type] { // Extensions returns the top level extensions defined in this file (i.e., // the contents of any top-level `extends` blocks). func (f *File) Extensions() seq.Indexer[Member] { + var slice []id.ID[Member] + if f != nil { + slice = f.extns[:f.topLevelExtnsEnd] + } return seq.NewFixedSlice( - f.extns[:f.topLevelExtnsEnd], + slice, func(_ int, p id.ID[Member]) Member { return id.Wrap(f, p) }, @@ -230,8 +253,12 @@ func (f *File) Extensions() seq.Indexer[Member] { // AllExtensions returns all extensions defined in this file. func (f *File) AllExtensions() seq.Indexer[Member] { + var extns []id.ID[Member] + if f != nil { + extns = f.extns + } return seq.NewFixedSlice( - f.extns, + extns, func(_ int, p id.ID[Member]) Member { return id.Wrap(f, p) }, @@ -240,8 +267,12 @@ func (f *File) AllExtensions() seq.Indexer[Member] { // Extends returns the top level extend blocks in this file. func (f *File) Extends() seq.Indexer[Extend] { + var slice []id.ID[Extend] + if f != nil { + slice = f.extends[:f.topLevelExtendsEnd] + } return seq.NewFixedSlice( - f.extends[:f.topLevelExtendsEnd], + slice, func(_ int, p id.ID[Extend]) Extend { return id.Wrap(f, p) }, @@ -250,8 +281,12 @@ func (f *File) Extends() seq.Indexer[Extend] { // AllExtends returns all extend blocks in this file. func (f *File) AllExtends() seq.Indexer[Extend] { + var extends []id.ID[Extend] + if f != nil { + extends = f.extends + } return seq.NewFixedSlice( - f.extends, + extends, func(_ int, p id.ID[Extend]) Extend { return id.Wrap(f, p) }, @@ -261,8 +296,12 @@ func (f *File) AllExtends() seq.Indexer[Extend] { // AllMembers returns all fields defined in this file, including extensions // and enum values. func (f *File) AllMembers() iter.Seq[Member] { + var raw iter.Seq[*rawMember] + if f != nil { + raw = f.arenas.members.Values() + } i := 0 - return iterx.Map(f.arenas.members.Values(), func(raw *rawMember) Member { + return iterx.Map(raw, func(raw *rawMember) Member { i++ return id.WrapRaw(f, id.ID[Member](i), raw) }) @@ -270,8 +309,12 @@ func (f *File) AllMembers() iter.Seq[Member] { // Services returns all services defined in this file. func (f *File) Services() seq.Indexer[Service] { + var services []id.ID[Service] + if f != nil { + services = f.services + } return seq.NewFixedSlice( - f.services, + services, func(_ int, p id.ID[Service]) Service { return id.Wrap(f, p) }, @@ -280,11 +323,18 @@ func (f *File) Services() seq.Indexer[Service] { // Options returns the top level options applied to this file. func (f *File) Options() MessageValue { - return id.Wrap(f, f.options).AsMessage() + var options id.ID[Value] + if f != nil { + options = f.options + } + return id.Wrap(f, options).AsMessage() } // FeatureSet returns the Editions features associated with this file. func (f *File) FeatureSet() FeatureSet { + if f == nil { + return FeatureSet{} + } return id.Wrap(f, f.features) } @@ -308,8 +358,12 @@ func (f *File) Deprecated() Value { // imported by the file. The symbols are returned in an arbitrary but fixed // order. func (f *File) Symbols() seq.Indexer[Symbol] { + var symbols []Ref[Symbol] + if f != nil { + symbols = f.imported + } return seq.NewFixedSlice( - f.imported, + symbols, func(_ int, r Ref[Symbol]) Symbol { return GetRef(f, r) }, diff --git a/experimental/ir/ir_imports.go b/experimental/ir/ir_imports.go index 3a44eabc..d6dc965e 100644 --- a/experimental/ir/ir_imports.go +++ b/experimental/ir/ir_imports.go @@ -193,14 +193,21 @@ func (i *imports) MarkUsed(file *File) { // DescriptorProto returns the file for descriptor.proto. func (i *imports) DescriptorProto() *File { + if i == nil { + return nil + } imported, _ := slicesx.Last(i.files) return imported.file } // Directs returns an indexer over the Directs imports. func (i *imports) Directs() seq.Indexer[Import] { + var slice []imported + if i != nil { + slice = i.files[:i.importEnd] + } return seq.NewFixedSlice( - i.files[:i.importEnd], + slice, func(j int, imported imported) Import { n := uint32(j) public := n < i.publicEnd @@ -223,8 +230,12 @@ func (i *imports) Directs() seq.Indexer[Import] { // // This function does not report whether those imports are weak or used. func (i *imports) Transitive() seq.Indexer[Import] { + var slice []imported + if i != nil { + slice = i.files[:max(0, len(i.files)-1)] // Exclude the implicit descriptor.proto + } return seq.NewFixedSlice( - i.files[:max(0, len(i.files)-1)], // Exclude the implicit descriptor.proto. + slice, func(j int, imported imported) Import { n := uint32(j) return Import{ diff --git a/experimental/ir/ir_member.go b/experimental/ir/ir_member.go index bab11640..edd1afe1 100644 --- a/experimental/ir/ir_member.go +++ b/experimental/ir/ir_member.go @@ -336,6 +336,10 @@ func (m Member) Oneof() Oneof { // Options returns the options applied to this member. func (m Member) Options() MessageValue { + if m.IsZero() { + return MessageValue{} + } + return id.Wrap(m.Context(), m.Raw().options).AsMessage() } @@ -587,12 +591,13 @@ func (o Oneof) Index() int { // Members returns this oneof's member fields. func (o Oneof) Members() seq.Indexer[Member] { - return seq.NewFixedSlice( - o.Raw().members, - func(_ int, p id.ID[Member]) Member { - return id.Wrap(o.Context(), p) - }, - ) + var members []id.ID[Member] + if !o.IsZero() { + members = o.Raw().members + } + return seq.NewFixedSlice(members, func(_ int, p id.ID[Member]) Member { + return id.Wrap(o.Context(), p) + }) } // Parent returns the type that this oneof is declared within,. @@ -715,6 +720,9 @@ type rawReservedName struct { // AST returns the expression that this name was evaluated from, if known. func (r ReservedName) AST() ast.ExprAny { + if r.IsZero() { + return ast.ExprAny{} + } return r.raw.ast } diff --git a/experimental/ir/ir_type.go b/experimental/ir/ir_type.go index 4a82288d..7a34d8d2 100644 --- a/experimental/ir/ir_type.go +++ b/experimental/ir/ir_type.go @@ -196,7 +196,7 @@ func (t Type) AllowsAlias() bool { // IsAny returns whether this is the type google.protobuf.Any, which gets special // treatment in the language. func (t Type) IsAny() bool { - return t.InternedFullName() == t.Context().session.builtins.AnyPath + return !t.IsZero() && t.InternedFullName() == t.Context().session.builtins.AnyPath } // Predeclared returns the predeclared type that this Type corresponds to, if any. @@ -417,7 +417,10 @@ func (t Type) Extensions() seq.Indexer[Member] { // // This does not include reserved field names; see [Type.ReservedNames]. func (t Type) AllRanges() seq.Indexer[ReservedRange] { - slice := t.Raw().ranges + var slice []id.ID[ReservedRange] + if !t.IsZero() { + slice = t.Raw().ranges + } return seq.NewFixedSlice(slice, func(_ int, p id.ID[ReservedRange]) ReservedRange { return id.Wrap(t.Context(), p) }) @@ -427,7 +430,10 @@ func (t Type) AllRanges() seq.Indexer[ReservedRange] { // // This does not include reserved field names; see [Type.ReservedNames]. func (t Type) ReservedRanges() seq.Indexer[ReservedRange] { - slice := t.Raw().ranges[:t.Raw().rangesExtnStart] + var slice []id.ID[ReservedRange] + if !t.IsZero() { + slice = t.Raw().ranges[:t.Raw().rangesExtnStart] + } return seq.NewFixedSlice(slice, func(_ int, p id.ID[ReservedRange]) ReservedRange { return id.Wrap(t.Context(), p) }) @@ -435,7 +441,10 @@ func (t Type) ReservedRanges() seq.Indexer[ReservedRange] { // ExtensionRanges returns the extension ranges declared in this type. func (t Type) ExtensionRanges() seq.Indexer[ReservedRange] { - slice := t.Raw().ranges[t.Raw().rangesExtnStart:] + var slice []id.ID[ReservedRange] + if !t.IsZero() { + slice = t.Raw().ranges[t.Raw().rangesExtnStart:] + } return seq.NewFixedSlice(slice, func(_ int, p id.ID[ReservedRange]) ReservedRange { return id.Wrap(t.Context(), p) }) @@ -443,8 +452,12 @@ func (t Type) ExtensionRanges() seq.Indexer[ReservedRange] { // ReservedNames returns the reserved named declared in this type. func (t Type) ReservedNames() seq.Indexer[ReservedName] { + var slice []rawReservedName + if !t.IsZero() { + slice = t.Raw().reservedNames + } return seq.NewFixedSlice( - t.Raw().reservedNames, + slice, func(i int, _ rawReservedName) ReservedName { return ReservedName{id.WrapContext(t.Context()), &t.Raw().reservedNames[i]} }, @@ -453,8 +466,12 @@ func (t Type) ReservedNames() seq.Indexer[ReservedName] { // Oneofs returns the options applied to this type. func (t Type) Oneofs() seq.Indexer[Oneof] { + var oneofs []id.ID[Oneof] + if !t.IsZero() { + oneofs = t.Raw().oneofs + } return seq.NewFixedSlice( - t.Raw().oneofs, + oneofs, func(_ int, p id.ID[Oneof]) Oneof { return id.Wrap(t.Context(), p) }, @@ -463,8 +480,12 @@ func (t Type) Oneofs() seq.Indexer[Oneof] { // Extends returns the options applied to this type. func (t Type) Extends() seq.Indexer[Extend] { + var extends []id.ID[Extend] + if !t.IsZero() { + extends = t.Raw().extends + } return seq.NewFixedSlice( - t.Raw().extends, + extends, func(_ int, p id.ID[Extend]) Extend { return id.Wrap(t.Context(), p) }, @@ -473,6 +494,9 @@ func (t Type) Extends() seq.Indexer[Extend] { // Options returns the options applied to this type. func (t Type) Options() MessageValue { + if t.IsZero() { + return MessageValue{} + } return id.Wrap(t.Context(), t.Raw().options).AsMessage() } diff --git a/experimental/ir/ir_value.go b/experimental/ir/ir_value.go index dd3676c2..36ec26f6 100644 --- a/experimental/ir/ir_value.go +++ b/experimental/ir/ir_value.go @@ -571,7 +571,13 @@ func (e Element) AST() ast.ExprAny { // this element, e.g. // // key := e.Value().MessageKeys().At(e.ValueNodeIndex()) +// +// If the element is empty, this returns -1. func (e Element) ValueNodeIndex() int { + if e.IsZero() { + return -1 + } + // We do O(log n) work here, because this function doesn't get called except // for diagnostics. diff --git a/experimental/ir/zero_test.go b/experimental/ir/zero_test.go new file mode 100644 index 00000000..5b3367f2 --- /dev/null +++ b/experimental/ir/zero_test.go @@ -0,0 +1,137 @@ +// Copyright 2020-2025 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ir_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/bufbuild/protocompile/experimental/id" + "github.com/bufbuild/protocompile/experimental/ir" +) + +func TestZero(t *testing.T) { + t.Parallel() + + testZeroAny[*ir.File](t) + testZeroAny[ir.Import](t) // Import embeds *ir.File + testZeroAny[*ir.Imports](t) + + testZeroNode[ir.FeatureSet](t) + testZero[ir.Feature](t) + testZero[ir.FeatureInfo](t) + + testZeroNode[ir.Member](t) + testZeroNode[ir.Extend](t) + testZeroNode[ir.Oneof](t) + testZeroNode[ir.ReservedRange](t) + testZero[ir.ReservedName](t) + + testZero[ir.Service](t) + testZero[ir.Method](t) + + testZeroNode[ir.Symbol](t) + + testZeroNode[ir.Type](t) + + testZeroNode[ir.Value](t) + testZeroNode[ir.MessageValue](t) + testZero[ir.Element](t) +} + +// zeroable is a helper interface to enforce that types implement the IsZero method. +type zeroable interface { + IsZero() bool +} + +// node is a helper interface to enforce [id.Node] types. +type node[T any] interface { + zeroable + ID() id.ID[T] +} + +// testZeroNode is a helper that validates the zero value of IR nodes and enforces the +// [nodes] interface. +func testZeroNode[T node[T]](t *testing.T) { + t.Helper() + testZero[T](t) +} + +// testZero is a helper that validates the zero value of IR structures and enforces the +// [zeroable] interface. +func testZero[T zeroable](t *testing.T) { + t.Helper() + + testZeroAny[T](t) + testZeroAny[ir.Ref[T]](t) +} + +// testZeroAny is a helper that validates the zero value of T: +// +// 1. Accessors do not panic. +// 2. The method, IsZero() bool, returns true when called with the zero value. +// 3. The method, Context() [id.Constraint], if present, returns the zero value of *ir.File, +// which is always comparable. +// 4. Other accessors return zero values. +func testZeroAny[T any](t *testing.T) { + t.Helper() + + var z T + assert.Zero(t, z) + + v := reflect.ValueOf(z) + ty := reflect.TypeOf(z) + + t.Run(fmt.Sprintf("%T", z), func(t *testing.T) { + for i := range ty.NumMethod() { + m := ty.Method(i) + // This roughly represent the "accessors" (NumIn includes the receiver). + if m.Func.Type().NumIn() != 1 || m.Func.Type().NumOut() == 0 { + continue + } + returns := m.Func.Call([]reflect.Value{v}) + switch m.Name { + case "IsZero": + assert.Len(t, returns, 1) + assert.True(t, returns[0].Bool()) + case "ValueNodeIndex": + // This is a special case for [ir.Element], since 0 is a valid index, so for the + // zero value, it returns -1. + assert.Len(t, returns, 1) + assert.Equal(t, int64(-1), returns[0].Int()) + case "Context": + assert.Len(t, returns, 1) + assert.True(t, returns[0].Type().Comparable()) + assert.True(t, returns[0].Type().AssignableTo(reflect.TypeOf(&ir.File{}))) + default: + for i, r := range returns { + if r.Type().Kind() == reflect.Func { + continue + } + // r is an indexable type, so we test that length is 0. + if m := r.MethodByName("Len"); m.IsValid() { + assert.Equal(t, 0, m.Type().NumIn()) + assert.Equal(t, 1, m.Type().NumOut()) + r = m.Call(nil)[0] + } + assert.Zero(t, r.Interface(), "non-zero return %#d %#v of %T.%s", i, r, z, m.Name) + } + } + } + }) +}