深入探索:Go语言中的protoc-gen-validate源码

发表时间: 2023-03-07 00:23

业务代码中有很多参数校验的代码,如果手动实现,会非常繁琐,
https://github.com/go-playground/validator 是一个非常不错的选择(可参考《echo 源码分析(validator)》一文),但是对于 gRPC 来说,在定义 proto 的时候直接声明参数的校验规则是一种更合理、更优雅的方式。插件
https://github.com/bufbuild/protoc-gen-validate 就是来帮助我们实现这一功能的(该项目原先位于 github.com/envoyproxy/protoc-gen-validate,下文安装时使用的仍是这个旧路径)。kratos 框架也用到了这个插件。下面详细介绍如何安装和使用。

首先,按照 GitHub 上给出的方式安装后,生成的代码里并没有校验规则,我们只会得到类似下面的注释:

  // no validation rules for Id
  // no validation rules for Email

这是因为这个包的 main 分支是不稳定版本,按照官方的方式安装并不好使。我们可以改为安装一个稳定的 tag 版本:

go install github.com/envoyproxy/protoc-gen-validate@v0.1.0

然后可以在 $GOPATH/bin 下看到这个插件:

 % ls $GOPATH/bin/protoc-gen-validate
 xxx/bin/protoc-gen-validate

对应地,我们使用的 protoc 版本如下:

% protoc --version
libprotoc 3.19.4

然后,可以定义我们的proto文件

syntax = "proto3";

package examplepb;

option go_package = "./example";

import "validate/validate.proto";

// Person demonstrates a selection of protoc-gen-validate rules.
message Person {
  // Must be greater than 999.
  uint64 id    = 1 [(validate.rules).uint64.gt = 999];
  // Must be a syntactically valid email address.
  string email = 2 [(validate.rules).string.email = true];
  // NOTE(review): RE2 (used by the generated Go code) parses `[^[0-9]` as a
  // single negated character class and the following `A-Za-z]` as literal
  // text, so this pattern rejects ordinary names such as "Protocol Buffer".
  // Something like "^[A-Za-z]+( [A-Za-z]+)*$" was probably intended; if you
  // change it, regenerate the Go code.
  string name  = 3 [(validate.rules).string = {
                      pattern:   "^[^[0-9]A-Za-z]+( [^[0-9]A-Za-z]+)*$",
                      max_bytes: 256,
                   }];
  // Must be set; its own rules are validated recursively.
  Location home = 4 [(validate.rules).message.required = true];
  // Must be greater than 0.
  int64 ids = 5 [(validate.rules).int64 = {gt: 0}];
  // Must be in the range (0, 120].
  int32 age = 6 [(validate.rules).int32 = {gt: 0, lte: 120}];
  // Must be one of 1, 2 or 3.
  uint32 code = 7 [(validate.rules).uint32 = {in: [1, 2, 3]}];
  // Must not be 0 or 99.99.
  float score = 8 [(validate.rules).float = {not_in: [0, 99.99]}];

  message Location {
    double lat = 1 [(validate.rules).double = { gte: -90,  lte: 90 }];
    double lng = 2 [(validate.rules).double = { gte: -180, lte: 180 }];
  }
}

使用如下命令生成 Go 文件(protoc 不会自动创建输出目录,需事先创建 ./generated):

# -I . lets protoc find example.proto; the second -I points at the module
# downloaded by `go install ...@v0.1.0` so that "validate/validate.proto"
# can be imported.
% protoc \
  -I . \
  --plugin=$GOPATH/bin/protoc-gen-validate \
  -I $GOPATH/pkg/mod/github.com/envoyproxy/protoc-gen-validate@v0.1.0/ \
  --go_out=":./generated" \
  --validate_out="lang=go:./generated" \
  example.proto

相应地,我们得到了两个文件:

learn/pgv/generated/example/example.pb.go

// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// 	protoc-gen-go v1.28.1
// 	protoc        v3.19.4
// source: example.proto

package example

import (
	_ "github.com/envoyproxy/protoc-gen-validate/validate"
	protoreflect "google.golang.org/protobuf/reflect/protoreflect"
	protoimpl "google.golang.org/protobuf/runtime/protoimpl"
	reflect "reflect"
	sync "sync"
)

const (
	// Verify that this generated code is sufficiently up-to-date.
	_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
	// Verify that runtime/protoimpl is sufficiently up-to-date.
	_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)

// Person is the Go type generated for the examplepb.Person message. The
// validation rules declared in example.proto are not enforced here; they are
// checked by the Validate method generated by protoc-gen-validate.
type Person struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	Id    uint64           `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"`
	Email string           `protobuf:"bytes,2,opt,name=email,proto3" json:"email,omitempty"`
	Name  string           `protobuf:"bytes,3,opt,name=name,proto3" json:"name,omitempty"`
	Home  *Person_Location `protobuf:"bytes,4,opt,name=home,proto3" json:"home,omitempty"`
	// Must be greater than 0.
	Ids int64 `protobuf:"varint,5,opt,name=ids,proto3" json:"ids,omitempty"`
	// Must be in the range (0, 120].
	Age int32 `protobuf:"varint,6,opt,name=age,proto3" json:"age,omitempty"`
	// Must be one of 1, 2 or 3.
	Code uint32 `protobuf:"varint,7,opt,name=code,proto3" json:"code,omitempty"`
	// Must not be 0 or 99.99.
	Score float32 `protobuf:"fixed32,8,opt,name=score,proto3" json:"score,omitempty"`
}

func (x *Person) Reset() {
	*x = Person{}
	if protoimpl.UnsafeEnabled {
		mi := &file_example_proto_msgTypes[0]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

func (x *Person) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*Person) ProtoMessage() {}

func (x *Person) ProtoReflect() protoreflect.Message {
	mi := &file_example_proto_msgTypes[0]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use Person.ProtoReflect.Descriptor instead.
func (*Person) Descriptor() ([]byte, []int) {
	return file_example_proto_rawDescGZIP(), []int{0}
}

// The getters below are nil-safe: on a nil *Person they return the zero value.

func (x *Person) GetId() uint64 {
	if x != nil {
		return x.Id
	}
	return 0
}

func (x *Person) GetEmail() string {
	if x != nil {
		return x.Email
	}
	return ""
}

func (x *Person) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

func (x *Person) GetHome() *Person_Location {
	if x != nil {
		return x.Home
	}
	return nil
}

func (x *Person) GetIds() int64 {
	if x != nil {
		return x.Ids
	}
	return 0
}

func (x *Person) GetAge() int32 {
	if x != nil {
		return x.Age
	}
	return 0
}

func (x *Person) GetCode() uint32 {
	if x != nil {
		return x.Code
	}
	return 0
}

func (x *Person) GetScore() float32 {
	if x != nil {
		return x.Score
	}
	return 0
}

// Person_Location is the Go type generated for the nested
// examplepb.Person.Location message.
type Person_Location struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	Lat float64 `protobuf:"fixed64,1,opt,name=lat,proto3" json:"lat,omitempty"`
	Lng float64 `protobuf:"fixed64,2,opt,name=lng,proto3" json:"lng,omitempty"`
}

func (x *Person_Location) Reset() {
	*x = Person_Location{}
	if protoimpl.UnsafeEnabled {
		mi := &file_example_proto_msgTypes[1]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

func (x *Person_Location) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*Person_Location) ProtoMessage() {}

func (x *Person_Location) ProtoReflect() protoreflect.Message {
	mi := &file_example_proto_msgTypes[1]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use Person_Location.ProtoReflect.Descriptor instead.
func (*Person_Location) Descriptor() ([]byte, []int) {
	return file_example_proto_rawDescGZIP(), []int{0, 0}
}

func (x *Person_Location) GetLat() float64 {
	if x != nil {
		return x.Lat
	}
	return 0
}

func (x *Person_Location) GetLng() float64 {
	if x != nil {
		return x.Lng
	}
	return 0
}

var File_example_proto protoreflect.FileDescriptor

// file_example_proto_rawDesc is the raw descriptor handed to
// protoimpl.DescBuilder in file_example_proto_init. It embeds the
// validate.rules options verbatim (e.g. the Name regex bytes), so changing a
// rule requires regenerating this file rather than editing it.
var file_example_proto_rawDesc = []byte{
	0x0a, 0x0d, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12,
	0x09, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x70, 0x62, 0x1a, 0x17, 0x76, 0x61, 0x6c, 0x69,
	0x64, 0x61, 0x74, 0x65, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2e, 0x70, 0x72,
	0x6f, 0x74, 0x6f, 0x22, 0xb5, 0x03, 0x0a, 0x06, 0x50, 0x65, 0x72, 0x73, 0x6f, 0x6e, 0x12, 0x1a,
	0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x42, 0x0a, 0xba, 0xe9, 0xc0, 0x03,
	0x05, 0x32, 0x03, 0x20, 0xe7, 0x07, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1f, 0x0a, 0x05, 0x65, 0x6d,
	0x61, 0x69, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x09, 0xba, 0xe9, 0xc0, 0x03, 0x04,
	0x72, 0x02, 0x60, 0x01, 0x52, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x12, 0x44, 0x0a, 0x04, 0x6e,
	0x61, 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x42, 0x30, 0xba, 0xe9, 0xc0, 0x03, 0x2b,
	0x72, 0x29, 0x28, 0x80, 0x02, 0x32, 0x24, 0x5e, 0x5b, 0x5e, 0x5b, 0x30, 0x2d, 0x39, 0x5d, 0x41,
	0x2d, 0x5a, 0x61, 0x2d, 0x7a, 0x5d, 0x2b, 0x28, 0x20, 0x5b, 0x5e, 0x5b, 0x30, 0x2d, 0x39, 0x5d,
	0x41, 0x2d, 0x5a, 0x61, 0x2d, 0x7a, 0x5d, 0x2b, 0x29, 0x2a, 0x24, 0x52, 0x04, 0x6e, 0x61, 0x6d,
	0x65, 0x12, 0x3a, 0x0a, 0x04, 0x68, 0x6f, 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32,
	0x1a, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x70, 0x62, 0x2e, 0x50, 0x65, 0x72, 0x73,
	0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x42, 0x0a, 0xba, 0xe9, 0xc0,
	0x03, 0x05, 0x8a, 0x01, 0x02, 0x10, 0x01, 0x52, 0x04, 0x68, 0x6f, 0x6d, 0x65, 0x12, 0x1b, 0x0a,
	0x03, 0x69, 0x64, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x42, 0x09, 0xba, 0xe9, 0xc0, 0x03,
	0x04, 0x22, 0x02, 0x20, 0x00, 0x52, 0x03, 0x69, 0x64, 0x73, 0x12, 0x1d, 0x0a, 0x03, 0x61, 0x67,
	0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x05, 0x42, 0x0b, 0xba, 0xe9, 0xc0, 0x03, 0x06, 0x1a, 0x04,
	0x18, 0x78, 0x20, 0x00, 0x52, 0x03, 0x61, 0x67, 0x65, 0x12, 0x21, 0x0a, 0x04, 0x63, 0x6f, 0x64,
	0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0d, 0x42, 0x0d, 0xba, 0xe9, 0xc0, 0x03, 0x08, 0x2a, 0x06,
	0x30, 0x01, 0x30, 0x02, 0x30, 0x03, 0x52, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x12, 0x27, 0x0a, 0x05,
	0x73, 0x63, 0x6f, 0x72, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x02, 0x42, 0x11, 0xba, 0xe9, 0xc0,
	0x03, 0x0c, 0x0a, 0x0a, 0x3d, 0x00, 0x00, 0x00, 0x00, 0x3d, 0xe1, 0xfa, 0xc7, 0x42, 0x52, 0x05,
	0x73, 0x63, 0x6f, 0x72, 0x65, 0x1a, 0x64, 0x0a, 0x08, 0x4c, 0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f,
	0x6e, 0x12, 0x2b, 0x0a, 0x03, 0x6c, 0x61, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x01, 0x42, 0x19,
	0xba, 0xe9, 0xc0, 0x03, 0x14, 0x12, 0x12, 0x19, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x56, 0x40,
	0x29, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x56, 0xc0, 0x52, 0x03, 0x6c, 0x61, 0x74, 0x12, 0x2b,
	0x0a, 0x03, 0x6c, 0x6e, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x01, 0x42, 0x19, 0xba, 0xe9, 0xc0,
	0x03, 0x14, 0x12, 0x12, 0x19, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x66, 0x40, 0x29, 0x00, 0x00,
	0x00, 0x00, 0x00, 0x80, 0x66, 0xc0, 0x52, 0x03, 0x6c, 0x6e, 0x67, 0x42, 0x0b, 0x5a, 0x09, 0x2e,
	0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}

var (
	file_example_proto_rawDescOnce sync.Once
	file_example_proto_rawDescData = file_example_proto_rawDesc
)

// file_example_proto_rawDescGZIP compresses the raw descriptor on first use
// (guarded by sync.Once) and returns the cached compressed bytes afterwards.
func file_example_proto_rawDescGZIP() []byte {
	file_example_proto_rawDescOnce.Do(func() {
		file_example_proto_rawDescData = protoimpl.X.CompressGZIP(file_example_proto_rawDescData)
	})
	return file_example_proto_rawDescData
}

var file_example_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_example_proto_goTypes = []interface{}{
	(*Person)(nil),          // 0: examplepb.Person
	(*Person_Location)(nil), // 1: examplepb.Person.Location
}
var file_example_proto_depIdxs = []int32{
	1, // 0: examplepb.Person.home:type_name -> examplepb.Person.Location
	1, // [1:1] is the sub-list for method output_type
	1, // [1:1] is the sub-list for method input_type
	1, // [1:1] is the sub-list for extension type_name
	1, // [1:1] is the sub-list for extension extendee
	0, // [0:1] is the sub-list for field type_name
}

func init() { file_example_proto_init() }

// file_example_proto_init builds and registers the file descriptor once;
// afterwards the raw descriptor and type tables are released (set to nil).
func file_example_proto_init() {
	if File_example_proto != nil {
		return
	}
	if !protoimpl.UnsafeEnabled {
		file_example_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*Person); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_example_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*Person_Location); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
	}
	type x struct{}
	out := protoimpl.TypeBuilder{
		File: protoimpl.DescBuilder{
			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
			RawDescriptor: file_example_proto_rawDesc,
			NumEnums:      0,
			NumMessages:   2,
			NumExtensions: 0,
			NumServices:   0,
		},
		GoTypes:           file_example_proto_goTypes,
		DependencyIndexes: file_example_proto_depIdxs,
		MessageInfos:      file_example_proto_msgTypes,
	}.Build()
	File_example_proto = out.File
	file_example_proto_rawDesc = nil
	file_example_proto_goTypes = nil
	file_example_proto_depIdxs = nil
}

learn/pgv/generated/example/example.pb.validate.go

// Code generated by protoc-gen-validate. DO NOT EDIT.
// source: example.proto

package example

import (
	"bytes"
	"errors"
	"fmt"
	"net"
	"net/mail"
	"net/url"
	"regexp"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/golang/protobuf/ptypes"
)

// ensure the imports are used
var (
	_ = bytes.MinRead
	_ = errors.New("")
	_ = fmt.Print
	_ = utf8.UTFMax
	_ = (*regexp.Regexp)(nil)
	_ = (*strings.Reader)(nil)
	_ = net.IPv4len
	_ = time.Duration(0)
	_ = (*url.URL)(nil)
	_ = (*mail.Address)(nil)
	_ = ptypes.DynamicAny{}
)

// Validate checks the field values on Person with the rules defined in the
// proto definition for this message. If any rules are violated, an error is returned.
//
// Rules are checked in field-declaration order and the FIRST violation is
// returned as a PersonValidationError; later fields are not inspected.
// A nil *Person is considered valid.
func (m *Person) Validate() error {
	if m == nil {
		return nil
	}

	if m.GetId() <= 999 {
		return PersonValidationError{
			field:  "Id",
			reason: "value must be greater than 999",
		}
	}

	if err := m._validateEmail(m.GetEmail()); err != nil {
		return PersonValidationError{
			field:  "Email",
			reason: "value must be a valid email address",
			cause:  err,
		}
	}

	// max_bytes is checked on the byte length, not the rune count.
	if len(m.GetName()) > 256 {
		return PersonValidationError{
			field:  "Name",
			reason: "value length must be at most 256 bytes",
		}
	}

	if !_Person_Name_Pattern.MatchString(m.GetName()) {
		return PersonValidationError{
			field:  "Name",
			reason: "value does not match regex pattern \"^[^[0-9]A-Za-z]+( [^[0-9]A-Za-z]+)*$\"",
		}
	}

	if m.GetHome() == nil {
		return PersonValidationError{
			field:  "Home",
			reason: "value is required",
		}
	}

	// Recurse into the embedded message; its error is kept as the cause.
	if v, ok := interface{}(m.GetHome()).(interface{ Validate() error }); ok {
		if err := v.Validate(); err != nil {
			return PersonValidationError{
				field:  "Home",
				reason: "embedded message failed validation",
				cause:  err,
			}
		}
	}

	if m.GetIds() <= 0 {
		return PersonValidationError{
			field:  "Ids",
			reason: "value must be greater than 0",
		}
	}

	if val := m.GetAge(); val <= 0 || val > 120 {
		return PersonValidationError{
			field:  "Age",
			reason: "value must be inside range (0, 120]",
		}
	}

	if _, ok := _Person_Code_InLookup[m.GetCode()]; !ok {
		return PersonValidationError{
			field:  "Code",
			reason: "value must be in list [1 2 3]",
		}
	}

	// Exact float32 equality against the not_in set.
	if _, ok := _Person_Score_NotInLookup[m.GetScore()]; ok {
		return PersonValidationError{
			field:  "Score",
			reason: "value must not be in list [0 99.99]",
		}
	}

	return nil
}

// _validateHostname checks the domain part of an email address: at most 253
// characters, dot-separated labels of 1-63 characters that neither start nor
// end with '-' and contain only [a-z0-9-] after lower-casing.
// NOTE(review): the 253-character limit is applied to the original host
// (possibly with a trailing dot), not to the normalized s.
func (m *Person) _validateHostname(host string) error {
	s := strings.ToLower(strings.TrimSuffix(host, "."))

	if len(host) > 253 {
		return errors.New("hostname cannot exceed 253 characters")
	}

	for _, part := range strings.Split(s, ".") {
		if l := len(part); l == 0 || l > 63 {
			return errors.New("hostname part must be non-empty and cannot exceed 63 characters")
		}

		if part[0] == '-' {
			return errors.New("hostname parts cannot begin with hyphens")
		}

		if part[len(part)-1] == '-' {
			return errors.New("hostname parts cannot end with hyphens")
		}

		for _, r := range part {
			if (r < 'a' || r > 'z') && (r < '0' || r > '9') && r != '-' {
				return fmt.Errorf("hostname parts can only contain alphanumeric characters or hyphens, got %q", string(r))
			}
		}
	}

	return nil
}

// _validateEmail parses addr with net/mail, then limits the address to 254
// characters and the local part to 64, and validates the domain with
// _validateHostname. Parse errors from net/mail are returned unchanged.
func (m *Person) _validateEmail(addr string) error {
	a, err := mail.ParseAddress(addr)
	if err != nil {
		return err
	}
	addr = a.Address

	if len(addr) > 254 {
		return errors.New("email addresses cannot exceed 254 characters")
	}

	parts := strings.SplitN(addr, "@", 2)

	if len(parts[0]) > 64 {
		return errors.New("email address local phrase cannot exceed 64 characters")
	}

	return m._validateHostname(parts[1])
}

// PersonValidationError is the validation error returned by Person.Validate if
// the designated constraints aren't met.
type PersonValidationError struct {
	field  string
	reason string
	cause  error
	key    bool
}

// Field function returns field value.
func (e PersonValidationError) Field() string { return e.field }

// Reason function returns reason value.
func (e PersonValidationError) Reason() string { return e.reason }

// Cause function returns cause value.
func (e PersonValidationError) Cause() error { return e.cause }

// Key function returns key value.
func (e PersonValidationError) Key() bool { return e.key }

// ErrorName returns error name.
func (e PersonValidationError) ErrorName() string { return "PersonValidationError" }

// Error satisfies the builtin error interface
func (e PersonValidationError) Error() string {
	cause := ""
	if e.cause != nil {
		cause = fmt.Sprintf(" | caused by: %v", e.cause)
	}

	key := ""
	if e.key {
		key = "key for "
	}

	return fmt.Sprintf(
		"invalid %sPerson.%s: %s%s",
		key,
		e.field,
		e.reason,
		cause)
}

var _ error = PersonValidationError{}

var _ interface {
	Field() string
	Reason() string
	Key() bool
	Cause() error
	ErrorName() string
} = PersonValidationError{}

// NOTE(review): RE2 parses `[^[0-9]` as one negated class (not '[' and not a
// digit) and `A-Za-z]+` as the literal text "A-Za-z" followed by one or more
// ']', so this pattern rejects ordinary names such as "Protocol Buffer".
// Fix the pattern in example.proto and regenerate rather than editing here.
var _Person_Name_Pattern = regexp.MustCompile("^[^[0-9]A-Za-z]+( [^[0-9]A-Za-z]+)*$")

var _Person_Code_InLookup = map[uint32]struct{}{
	1: {},
	2: {},
	3: {},
}

var _Person_Score_NotInLookup = map[float32]struct{}{
	0:     {},
	99.99: {},
}

// Validate checks the field values on Person_Location with the rules defined
// in the proto definition for this message. If any rules are violated, an
// error is returned.
func (m *Person_Location) Validate() error {
	if m == nil {
		return nil
	}

	if val := m.GetLat(); val < -90 || val > 90 {
		return Person_LocationValidationError{
			field:  "Lat",
			reason: "value must be inside range [-90, 90]",
		}
	}

	if val := m.GetLng(); val < -180 || val > 180 {
		return Person_LocationValidationError{
			field:  "Lng",
			reason: "value must be inside range [-180, 180]",
		}
	}

	return nil
}

// Person_LocationValidationError is the validation error returned by
// Person_Location.Validate if the designated constraints aren't met.
type Person_LocationValidationError struct {
	field  string
	reason string
	cause  error
	key    bool
}

// Field function returns field value.
func (e Person_LocationValidationError) Field() string { return e.field }

// Reason function returns reason value.
func (e Person_LocationValidationError) Reason() string { return e.reason }

// Cause function returns cause value.
func (e Person_LocationValidationError) Cause() error { return e.cause }

// Key function returns key value.
func (e Person_LocationValidationError) Key() bool { return e.key }

// ErrorName returns error name.
func (e Person_LocationValidationError) ErrorName() string { return "Person_LocationValidationError" }

// Error satisfies the builtin error interface
func (e Person_LocationValidationError) Error() string {
	cause := ""
	if e.cause != nil {
		cause = fmt.Sprintf(" | caused by: %v", e.cause)
	}

	key := ""
	if e.key {
		key = "key for "
	}

	return fmt.Sprintf(
		"invalid %sPerson_Location.%s: %s%s",
		key,
		e.field,
		e.reason,
		cause)
}

var _ error = Person_LocationValidationError{}

var _ interface {
	Field() string
	Reason() string
	Key() bool
	Cause() error
	ErrorName() string
} = Person_LocationValidationError{}

然后就可以调用生成的 Validate 方法进行校验。注意 Validate 按字段声明顺序逐条检查,遇到第一个不满足的规则就立即返回对应的错误:

// This program walks through the rules declared in example.proto, fixing one
// violation at a time and printing what the generated Validate method reports.
// Validate returns only the first violated rule, so each step surfaces the
// next field in declaration order.
package main

import (
	"fmt"

	"learn/pgv/generated/example"
)

func main() {
	p := new(example.Person)
	// invalid Person.Id: value must be greater than 999
	fmt.Println(p.Validate())

	p.Id = 1000
	// invalid Person.Email: value must be a valid email address | caused by: mail: no address
	fmt.Println(p.Validate())

	p.Email = "example@bufbuild.com"
	// invalid Person.Name: value does not match regex pattern "^[^[0-9]A-Za-z]+( [^[0-9]A-Za-z]+)*$"
	fmt.Println(p.Validate())

	// NOTE: with the pattern exactly as declared in example.proto, RE2 reads
	// `A-Za-z]` as literal text, so "Protocol Buffer" still fails the Name
	// check here and in every call below. The expected outputs that follow
	// assume the pattern has been fixed (e.g. "^[A-Za-z]+( [A-Za-z]+)*$") and
	// the code regenerated.
	p.Name = "Protocol Buffer"
	// invalid Person.Home: value is required
	fmt.Println(p.Validate())

	p.Home = &example.Person_Location{Lat: 37.7, Lng: 999}
	// invalid Person.Home: embedded message failed validation | caused by: invalid Person_Location.Lng: value must be inside range [-180, 180]
	fmt.Println(p.Validate())

	p.Home.Lng = -122.4
	// invalid Person.Ids: value must be greater than 0
	fmt.Println(p.Validate())

	p.Ids = 1
	// invalid Person.Age: value must be inside range (0, 120]
	fmt.Println(p.Validate())

	p.Age = 30
	// invalid Person.Code: value must be in list [1 2 3]
	fmt.Println(p.Validate())

	p.Code = 1
	// invalid Person.Score: value must not be in list [0 99.99]
	fmt.Println(p.Validate())

	p.Score = 60
	// <nil>: every rule is now satisfied.
	fmt.Println(p.Validate())
}

运行效果如下

% go run main.go
invalid Person.Id: value must be greater than 999

通过 proto 的自定义选项(custom options)扩展,配合这个插件,我们可以非常方便地实现参数校验能力,真正把 IDL 当作沟通交流的完备工具,有效提升开发效率。

[(validate.rules).uint32 = {in: [1,2,3]}];