package swag

import (
	"fmt"
	"go/ast"
	goparser "go/parser"
	"go/token"
	"log"
	"net/http"
	"os"
	"path"
	"path/filepath"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"unicode"

	"github.com/go-openapi/jsonreference"
	"github.com/go-openapi/spec"
	"github.com/pkg/errors"
)

// Naming strategies accepted by Parser.PropNamingStrategy.
const (
	// CamelCase indicates using CamelCase strategy for struct field.
	CamelCase = "camelcase"

	// PascalCase indicates using PascalCase strategy for struct field.
	PascalCase = "pascalcase"

	// SnakeCase indicates using SnakeCase strategy for struct field.
	SnakeCase = "snakecase"
)

// Parser implements a parser for Go source files.
type Parser struct {
	// swagger represents the root document object for the API specification.
	swagger *spec.Swagger

	// files maps the path of every parsed Go source file to its AST:
	// map[real_go_file_path][astFile].
	files map[string]*ast.File

	// TypeDefinitions is a map that stores [package name][type name][*ast.TypeSpec].
	TypeDefinitions map[string]map[string]*ast.TypeSpec

	// CustomPrimitiveTypes is a map that stores custom primitive types to
	// actual golang types: [type name][string].
	CustomPrimitiveTypes map[string]string

	// registerTypes is a map that stores [refTypeName][*ast.TypeSpec].
	// ParseDefinitions treats the part of refTypeName before the first "."
	// as the package name.
	registerTypes map[string]*ast.TypeSpec

	// PropNamingStrategy selects how struct field names are turned into
	// property names: SnakeCase, PascalCase or CamelCase. Any other value
	// (including the empty string) is handled like CamelCase.
	PropNamingStrategy string

	// structStack stores full names of the structures that were already
	// parsed or are being parsed now.
	structStack []string
}

// New creates a new Parser with default properties.
func New() *Parser { parser := &Parser{ swagger: &spec.Swagger{ SwaggerProps: spec.SwaggerProps{ Info: &spec.Info{ InfoProps: spec.InfoProps{ Contact: &spec.ContactInfo{}, License: &spec.License{}, }, }, Paths: &spec.Paths{ Paths: make(map[string]spec.PathItem), }, Definitions: make(map[string]spec.Schema), }, }, files: make(map[string]*ast.File), TypeDefinitions: make(map[string]map[string]*ast.TypeSpec), CustomPrimitiveTypes: make(map[string]string), registerTypes: make(map[string]*ast.TypeSpec), } return parser } // ParseAPI parses general api info for gived searchDir and mainAPIFile func (parser *Parser) ParseAPI(searchDir string, mainAPIFile string) error { log.Println("Generate general API Info") if err := parser.getAllGoFileInfo(searchDir); err != nil { return err } parser.ParseGeneralAPIInfo(path.Join(searchDir, mainAPIFile)) for _, astFile := range parser.files { parser.ParseType(astFile) } for _, astFile := range parser.files { parser.ParseRouterAPIInfo(astFile) } parser.ParseDefinitions() return nil } // ParseGeneralAPIInfo parses general api info for gived mainAPIFile path func (parser *Parser) ParseGeneralAPIInfo(mainAPIFile string) error { fileSet := token.NewFileSet() fileTree, err := goparser.ParseFile(fileSet, mainAPIFile, nil, goparser.ParseComments) if err != nil { return errors.Wrap(err, "cannot parse soure files") } parser.swagger.Swagger = "2.0" securityMap := map[string]*spec.SecurityScheme{} // templated defaults parser.swagger.Info.Version = "{{.Version}}" parser.swagger.Info.Title = "{{.Title}}" parser.swagger.Info.Description = "{{.Description}}" parser.swagger.Host = "{{.Host}}" parser.swagger.BasePath = "{{.BasePath}}" if fileTree.Comments != nil { for _, comment := range fileTree.Comments { comments := strings.Split(comment.Text(), "\n") for _, commentLine := range comments { attribute := strings.ToLower(strings.Split(commentLine, " ")[0]) switch attribute { case "@version": parser.swagger.Info.Version = 
strings.TrimSpace(commentLine[len(attribute):]) case "@title": parser.swagger.Info.Title = strings.TrimSpace(commentLine[len(attribute):]) case "@description": if parser.swagger.Info.Description == "{{.Description}}" { parser.swagger.Info.Description = strings.TrimSpace(commentLine[len(attribute):]) } else { parser.swagger.Info.Description += "\n" + strings.TrimSpace(commentLine[len(attribute):]) } case "@termsofservice": parser.swagger.Info.TermsOfService = strings.TrimSpace(commentLine[len(attribute):]) case "@contact.name": parser.swagger.Info.Contact.Name = strings.TrimSpace(commentLine[len(attribute):]) case "@contact.email": parser.swagger.Info.Contact.Email = strings.TrimSpace(commentLine[len(attribute):]) case "@contact.url": parser.swagger.Info.Contact.URL = strings.TrimSpace(commentLine[len(attribute):]) case "@license.name": parser.swagger.Info.License.Name = strings.TrimSpace(commentLine[len(attribute):]) case "@license.url": parser.swagger.Info.License.URL = strings.TrimSpace(commentLine[len(attribute):]) case "@host": parser.swagger.Host = strings.TrimSpace(commentLine[len(attribute):]) case "@basepath": parser.swagger.BasePath = strings.TrimSpace(commentLine[len(attribute):]) case "@schemes": parser.swagger.Schemes = GetSchemes(commentLine) case "@tag.name": commentInfo := strings.TrimSpace(commentLine[len(attribute):]) parser.swagger.Tags = append(parser.swagger.Tags, spec.Tag{ TagProps: spec.TagProps{ Name: strings.TrimSpace(commentInfo), }, }) case "@tag.description": commentInfo := strings.TrimSpace(commentLine[len(attribute):]) tag := parser.swagger.Tags[len(parser.swagger.Tags)-1] tag.TagProps.Description = commentInfo replaceLastTag(parser.swagger.Tags, tag) case "@tag.docs.url": commentInfo := strings.TrimSpace(commentLine[len(attribute):]) tag := parser.swagger.Tags[len(parser.swagger.Tags)-1] tag.TagProps.ExternalDocs = &spec.ExternalDocumentation{ URL: commentInfo, } replaceLastTag(parser.swagger.Tags, tag) case "@tag.docs.description": 
commentInfo := strings.TrimSpace(commentLine[len(attribute):]) tag := parser.swagger.Tags[len(parser.swagger.Tags)-1] if tag.TagProps.ExternalDocs == nil { log.Panic("@tag.docs.description needs to come after a @tags.docs.url") } tag.TagProps.ExternalDocs.Description = commentInfo replaceLastTag(parser.swagger.Tags, tag) } } for i := 0; i < len(comments); i++ { attribute := strings.ToLower(strings.Split(comments[i], " ")[0]) switch attribute { case "@securitydefinitions.basic": securityMap[strings.TrimSpace(comments[i][len(attribute):])] = spec.BasicAuth() case "@securitydefinitions.apikey": attrMap := map[string]string{} for _, v := range comments[i+1:] { securityAttr := strings.ToLower(strings.Split(v, " ")[0]) if securityAttr == "@in" || securityAttr == "@name" { attrMap[securityAttr] = strings.TrimSpace(v[len(securityAttr):]) } // next securityDefinitions if strings.Index(securityAttr, "@securitydefinitions.") == 0 { break } } if len(attrMap) != 2 { log.Panic("@securitydefinitions.apikey is @name and @in required") } securityMap[strings.TrimSpace(comments[i][len(attribute):])] = spec.APIKeyAuth(attrMap["@name"], attrMap["@in"]) case "@securitydefinitions.oauth2.application": attrMap := map[string]string{} scopes := map[string]string{} for _, v := range comments[i+1:] { securityAttr := strings.ToLower(strings.Split(v, " ")[0]) if securityAttr == "@tokenurl" { attrMap[securityAttr] = strings.TrimSpace(v[len(securityAttr):]) } else if isExistsScope(securityAttr) { scopes[getScopeScheme(securityAttr)] = v[len(securityAttr):] } // next securityDefinitions if strings.Index(securityAttr, "@securitydefinitions.") == 0 { break } } if len(attrMap) != 1 { log.Panic("@securitydefinitions.oauth2.application is @tokenUrl required") } securityScheme := spec.OAuth2Application(attrMap["@tokenurl"]) for scope, description := range scopes { securityScheme.AddScope(scope, description) } securityMap[strings.TrimSpace(comments[i][len(attribute):])] = securityScheme case 
"@securitydefinitions.oauth2.implicit": attrMap := map[string]string{} scopes := map[string]string{} for _, v := range comments[i+1:] { securityAttr := strings.ToLower(strings.Split(v, " ")[0]) if securityAttr == "@authorizationurl" { attrMap[securityAttr] = strings.TrimSpace(v[len(securityAttr):]) } else if isExistsScope(securityAttr) { scopes[getScopeScheme(securityAttr)] = v[len(securityAttr):] } // next securityDefinitions if strings.Index(securityAttr, "@securitydefinitions.") == 0 { break } } if len(attrMap) != 1 { log.Panic("@securitydefinitions.oauth2.implicit is @authorizationUrl required") } securityScheme := spec.OAuth2Implicit(attrMap["@authorizationurl"]) for scope, description := range scopes { securityScheme.AddScope(scope, description) } securityMap[strings.TrimSpace(comments[i][len(attribute):])] = securityScheme case "@securitydefinitions.oauth2.password": attrMap := map[string]string{} scopes := map[string]string{} for _, v := range comments[i+1:] { securityAttr := strings.ToLower(strings.Split(v, " ")[0]) if securityAttr == "@tokenurl" { attrMap[securityAttr] = strings.TrimSpace(v[len(securityAttr):]) } else if isExistsScope(securityAttr) { scopes[getScopeScheme(securityAttr)] = v[len(securityAttr):] } // next securityDefinitions if strings.Index(securityAttr, "@securitydefinitions.") == 0 { break } } if len(attrMap) != 1 { log.Panic("@securitydefinitions.oauth2.password is @tokenUrl required") } securityScheme := spec.OAuth2Password(attrMap["@tokenurl"]) for scope, description := range scopes { securityScheme.AddScope(scope, description) } securityMap[strings.TrimSpace(comments[i][len(attribute):])] = securityScheme case "@securitydefinitions.oauth2.accesscode": attrMap := map[string]string{} scopes := map[string]string{} for _, v := range comments[i+1:] { securityAttr := strings.ToLower(strings.Split(v, " ")[0]) if securityAttr == "@tokenurl" || securityAttr == "@authorizationurl" { attrMap[securityAttr] = 
strings.TrimSpace(v[len(securityAttr):]) } else if isExistsScope(securityAttr) { scopes[getScopeScheme(securityAttr)] = v[len(securityAttr):] } // next securityDefinitions if strings.Index(securityAttr, "@securitydefinitions.") == 0 { break } } if len(attrMap) != 2 { log.Panic("@securitydefinitions.oauth2.accessCode is @tokenUrl and @authorizationUrl required") } securityScheme := spec.OAuth2AccessToken(attrMap["@authorizationurl"], attrMap["@tokenurl"]) for scope, description := range scopes { securityScheme.AddScope(scope, description) } securityMap[strings.TrimSpace(comments[i][len(attribute):])] = securityScheme } } } } if len(securityMap) > 0 { parser.swagger.SecurityDefinitions = securityMap } return nil } func getScopeScheme(scope string) string { scopeValue := scope[strings.Index(scope, "@scope."):] if scopeValue == "" { panic("@scope is empty") } return scope[len("@scope."):] } func isExistsScope(scope string) bool { s := strings.Fields(scope) for _, v := range s { if strings.Index(v, "@scope.") != -1 { if strings.Index(v, ",") != -1 { panic("@scope can't use comma(,) get=" + v) } } } return strings.Index(scope, "@scope.") != -1 } // GetSchemes parses swagger schemes for given commentLine func GetSchemes(commentLine string) []string { attribute := strings.ToLower(strings.Split(commentLine, " ")[0]) return strings.Split(strings.TrimSpace(commentLine[len(attribute):]), " ") } // ParseRouterAPIInfo parses router api info for given astFile func (parser *Parser) ParseRouterAPIInfo(astFile *ast.File) { for _, astDescription := range astFile.Decls { switch astDeclaration := astDescription.(type) { case *ast.FuncDecl: if astDeclaration.Doc != nil && astDeclaration.Doc.List != nil { operation := NewOperation() //for per 'function' comment, create a new 'Operation' object operation.parser = parser for _, comment := range astDeclaration.Doc.List { if err := operation.ParseComment(comment.Text, astFile); err != nil { log.Panicf("ParseComment panic:%+v", err) } } var 
pathItem spec.PathItem var ok bool if pathItem, ok = parser.swagger.Paths.Paths[operation.Path]; !ok { pathItem = spec.PathItem{} } switch strings.ToUpper(operation.HTTPMethod) { case http.MethodGet: pathItem.Get = &operation.Operation case http.MethodPost: pathItem.Post = &operation.Operation case http.MethodDelete: pathItem.Delete = &operation.Operation case http.MethodPut: pathItem.Put = &operation.Operation case http.MethodPatch: pathItem.Patch = &operation.Operation case http.MethodHead: pathItem.Head = &operation.Operation case http.MethodOptions: pathItem.Options = &operation.Operation } parser.swagger.Paths.Paths[operation.Path] = pathItem } } } } // ParseType parses type info for given astFile. func (parser *Parser) ParseType(astFile *ast.File) { if _, ok := parser.TypeDefinitions[astFile.Name.String()]; !ok { parser.TypeDefinitions[astFile.Name.String()] = make(map[string]*ast.TypeSpec) } for _, astDeclaration := range astFile.Decls { if generalDeclaration, ok := astDeclaration.(*ast.GenDecl); ok && generalDeclaration.Tok == token.TYPE { for _, astSpec := range generalDeclaration.Specs { if typeSpec, ok := astSpec.(*ast.TypeSpec); ok { typeName := fmt.Sprintf("%v", typeSpec.Type) // check if its a custom primitive type if IsGolangPrimitiveType(typeName) { parser.CustomPrimitiveTypes[typeSpec.Name.String()] = TransToValidSchemeType(typeName) } else { parser.TypeDefinitions[astFile.Name.String()][typeSpec.Name.String()] = typeSpec } } } } } } func (parser *Parser) isInStructStack(refTypeName string) bool { for _, structName := range parser.structStack { if refTypeName == structName { return true } } return false } // ParseDefinitions parses Swagger Api definitions. 
func (parser *Parser) ParseDefinitions() { for refTypeName, typeSpec := range parser.registerTypes { ss := strings.Split(refTypeName, ".") pkgName := ss[0] parser.structStack = nil parser.ParseDefinition(pkgName, typeSpec.Name.Name, typeSpec) } } // ParseDefinition parses given type spec that corresponds to the type under // given name and package, and populates swagger schema definitions registry // with a schema for the given type func (parser *Parser) ParseDefinition(pkgName, typeName string, typeSpec *ast.TypeSpec) { refTypeName := fullTypeName(pkgName, typeName) if _, isParsed := parser.swagger.Definitions[refTypeName]; isParsed { log.Println("Skipping '" + refTypeName + "', already parsed.") return } if parser.isInStructStack(refTypeName) { log.Println("Skipping '" + refTypeName + "', recursion detected.") return } parser.structStack = append(parser.structStack, refTypeName) log.Println("Generating " + refTypeName) parser.swagger.Definitions[refTypeName] = parser.parseTypeExpr(pkgName, typeName, typeSpec.Type, true) } func (parser *Parser) collectRequiredFields(pkgName string, properties map[string]spec.Schema) (requiredFields []string) { // created sorted list of properties keys so when we iterate over them it's deterministic ks := make([]string, 0, len(properties)) for k := range properties { ks = append(ks, k) } sort.Strings(ks) requiredFields = make([]string, 0) // iterate over keys list instead of map to avoid the random shuffle of the order that go does for maps for _, k := range ks { prop := properties[k] // todo find the pkgName of the property type tname := prop.SchemaProps.Type[0] if _, ok := parser.TypeDefinitions[pkgName][tname]; ok { tspec := parser.TypeDefinitions[pkgName][tname] parser.ParseDefinition(pkgName, tname, tspec) } if tname != "object" { requiredFields = append(requiredFields, prop.SchemaProps.Required...) } properties[k] = prop } return } func fullTypeName(pkgName, typeName string) string { if pkgName != "" { return pkgName + "." 
+ typeName } return typeName } // parseTypeExpr parses given type expression that corresponds to the type under // given name and package, and returns swagger schema for it. func (parser *Parser) parseTypeExpr(pkgName, typeName string, typeExpr ast.Expr, flattenRequired bool) spec.Schema { switch expr := typeExpr.(type) { // type Foo struct {...} case *ast.StructType: refTypeName := fullTypeName(pkgName, typeName) if schema, isParsed := parser.swagger.Definitions[refTypeName]; isParsed { return schema } properties := make(map[string]spec.Schema) for _, field := range expr.Fields.List { var fieldProps map[string]spec.Schema if field.Names == nil { fieldProps = parser.parseAnonymousField(pkgName, field) } else { fieldProps = parser.parseStruct(pkgName, field) } for k, v := range fieldProps { properties[k] = v } } required := parser.collectRequiredFields(pkgName, properties) // unset required from properties because we've aggregated them if flattenRequired { for k, prop := range properties { tname := prop.SchemaProps.Type[0] if tname != "object" { prop.SchemaProps.Required = make([]string, 0) } properties[k] = prop } } return spec.Schema{ SchemaProps: spec.SchemaProps{ Type: []string{"object"}, Properties: properties, Required: required, }, } // type Foo Baz case *ast.Ident: refTypeName := fullTypeName(pkgName, expr.Name) if _, isParsed := parser.swagger.Definitions[refTypeName]; !isParsed { typedef := parser.TypeDefinitions[pkgName][expr.Name] parser.ParseDefinition(pkgName, expr.Name, typedef) } return parser.swagger.Definitions[refTypeName] // type Foo *Baz case *ast.StarExpr: return parser.parseTypeExpr(pkgName, typeName, expr.X, true) // type Foo []Baz case *ast.ArrayType: itemSchema := parser.parseTypeExpr(pkgName, "", expr.Elt, true) return spec.Schema{ SchemaProps: spec.SchemaProps{ Type: []string{"array"}, Items: &spec.SchemaOrArray{ Schema: &itemSchema, }, }, } // type Foo pkg.Bar case *ast.SelectorExpr: if xIdent, ok := expr.X.(*ast.Ident); ok { pkgName = 
xIdent.Name typeName = expr.Sel.Name refTypeName := fullTypeName(pkgName, typeName) if _, isParsed := parser.swagger.Definitions[refTypeName]; !isParsed { typedef := parser.TypeDefinitions[pkgName][typeName] parser.ParseDefinition(pkgName, typeName, typedef) } return parser.swagger.Definitions[refTypeName] } // type Foo map[string]Bar // ... default: log.Printf("Type definition of type '%T' is not supported yet. Using 'object' instead.\n", typeExpr) } return spec.Schema{ SchemaProps: spec.SchemaProps{ Type: []string{"object"}, }, } } type structField struct { name string schemaType string arrayType string formatType string isRequired bool crossPkg string exampleValue interface{} maximum *float64 minimum *float64 maxLength *int64 minLength *int64 enums []interface{} defaultValue interface{} } func (parser *Parser) parseStruct(pkgName string, field *ast.Field) (properties map[string]spec.Schema) { properties = map[string]spec.Schema{} structField := parser.parseField(field) if structField.name == "" { return } var desc string if field.Doc != nil { desc = strings.TrimSpace(field.Doc.Text()) } // TODO: find package of schemaType and/or arrayType if structField.crossPkg != "" { pkgName = structField.crossPkg } if _, ok := parser.TypeDefinitions[pkgName][structField.schemaType]; ok { // user type field // write definition if not yet present parser.ParseDefinition(pkgName, structField.schemaType, parser.TypeDefinitions[pkgName][structField.schemaType]) properties[structField.name] = spec.Schema{ SchemaProps: spec.SchemaProps{ Type: []string{"object"}, // to avoid swagger validation error Description: desc, Ref: spec.Ref{ Ref: jsonreference.MustCreateRef("#/definitions/" + pkgName + "." 
+ structField.schemaType), }, }, } } else if structField.schemaType == "array" { // array field type // if defined -- ref it if _, ok := parser.TypeDefinitions[pkgName][structField.arrayType]; ok { // user type in array parser.ParseDefinition(pkgName, structField.arrayType, parser.TypeDefinitions[pkgName][structField.arrayType]) properties[structField.name] = spec.Schema{ SchemaProps: spec.SchemaProps{ Type: []string{structField.schemaType}, Description: desc, Items: &spec.SchemaOrArray{ Schema: &spec.Schema{ SchemaProps: spec.SchemaProps{ Ref: spec.Ref{ Ref: jsonreference.MustCreateRef("#/definitions/" + pkgName + "." + structField.arrayType), }, }, }, }, }, } } else { // standard type in array required := make([]string, 0) if structField.isRequired { required = append(required, structField.name) } properties[structField.name] = spec.Schema{ SchemaProps: spec.SchemaProps{ Type: []string{structField.schemaType}, Description: desc, Format: structField.formatType, Required: required, Items: &spec.SchemaOrArray{ Schema: &spec.Schema{ SchemaProps: spec.SchemaProps{ Type: []string{structField.arrayType}, Maximum: structField.maximum, Minimum: structField.minimum, MaxLength: structField.maxLength, MinLength: structField.minLength, Enum: structField.enums, Default: structField.defaultValue, }, }, }, }, SwaggerSchemaProps: spec.SwaggerSchemaProps{ Example: structField.exampleValue, }, } } } else { required := make([]string, 0) if structField.isRequired { required = append(required, structField.name) } properties[structField.name] = spec.Schema{ SchemaProps: spec.SchemaProps{ Type: []string{structField.schemaType}, Description: desc, Format: structField.formatType, Required: required, Maximum: structField.maximum, Minimum: structField.minimum, MaxLength: structField.maxLength, MinLength: structField.minLength, Enum: structField.enums, Default: structField.defaultValue, }, SwaggerSchemaProps: spec.SwaggerSchemaProps{ Example: structField.exampleValue, }, } nestStruct, ok := 
field.Type.(*ast.StructType) if ok { props := map[string]spec.Schema{} nestRequired := make([]string, 0) for _, v := range nestStruct.Fields.List { p := parser.parseStruct(pkgName, v) for k, v := range p { if v.SchemaProps.Type[0] != "object" { nestRequired = append(nestRequired, v.SchemaProps.Required...) v.SchemaProps.Required = make([]string, 0) } props[k] = v } } properties[structField.name] = spec.Schema{ SchemaProps: spec.SchemaProps{ Type: []string{structField.schemaType}, Description: desc, Format: structField.formatType, Properties: props, Required: nestRequired, Maximum: structField.maximum, Minimum: structField.minimum, MaxLength: structField.maxLength, MinLength: structField.minLength, Enum: structField.enums, Default: structField.defaultValue, }, SwaggerSchemaProps: spec.SwaggerSchemaProps{ Example: structField.exampleValue, }, } } } return } func (parser *Parser) parseAnonymousField(pkgName string, field *ast.Field) map[string]spec.Schema { properties := make(map[string]spec.Schema) fullTypeName := "" switch ftype := field.Type.(type) { case *ast.Ident: fullTypeName = ftype.Name case *ast.StarExpr: if ftypeX, ok := ftype.X.(*ast.Ident); ok { fullTypeName = ftypeX.Name } default: log.Printf("Field type of '%T' is unsupported. 
Skipping", ftype) return properties } typeName := fullTypeName if splits := strings.Split(fullTypeName, "."); len(splits) > 1 { pkgName = splits[0] typeName = splits[1] } typeSpec := parser.TypeDefinitions[pkgName][typeName] schema := parser.parseTypeExpr(pkgName, typeName, typeSpec.Type, false) schemaType := "unknown" if len(schema.SchemaProps.Type) > 0 { schemaType = schema.SchemaProps.Type[0] } switch schemaType { case "object": for k, v := range schema.SchemaProps.Properties { properties[k] = v } case "array": properties[typeName] = schema default: log.Printf("Can't extract properties from a schema of type '%s'", schemaType) } return properties } func (parser *Parser) parseField(field *ast.Field) *structField { prop := getPropertyName(field, parser) if len(prop.ArrayType) == 0 { CheckSchemaType(prop.SchemaType) } else { CheckSchemaType("array") } structField := &structField{ name: field.Names[0].Name, schemaType: prop.SchemaType, arrayType: prop.ArrayType, crossPkg: prop.CrossPkg, } switch parser.PropNamingStrategy { case SnakeCase: structField.name = toSnakeCase(structField.name) case PascalCase: //use struct field name case CamelCase: structField.name = toLowerCamelCase(structField.name) default: structField.name = toLowerCamelCase(structField.name) } if field.Tag == nil { return structField } // `json:"tag"` -> json:"tag" structTag := reflect.StructTag(strings.Replace(field.Tag.Value, "`", "", -1)) jsonTag := structTag.Get("json") // json:"tag,hoge" if strings.Contains(jsonTag, ",") { // json:",hoge" if strings.HasPrefix(jsonTag, ",") { jsonTag = "" } else { jsonTag = strings.SplitN(jsonTag, ",", 2)[0] } } if jsonTag == "-" { structField.name = "" } else if jsonTag != "" { structField.name = jsonTag } if typeTag := structTag.Get("swaggertype"); typeTag != "" { parts := strings.Split(typeTag, ",") if 0 < len(parts) && len(parts) <= 2 { newSchemaType := parts[0] newArrayType := structField.arrayType if len(parts) >= 2 && newSchemaType == "array" { newArrayType 
= parts[1] } CheckSchemaType(newSchemaType) CheckSchemaType(newArrayType) structField.schemaType = newSchemaType structField.arrayType = newArrayType } } if exampleTag := structTag.Get("example"); exampleTag != "" { structField.exampleValue = defineTypeOfExample(structField.schemaType, structField.arrayType, exampleTag) } if formatTag := structTag.Get("format"); formatTag != "" { structField.formatType = formatTag } if bindingTag := structTag.Get("binding"); bindingTag != "" { for _, val := range strings.Split(bindingTag, ",") { if val == "required" { structField.isRequired = true break } } } if validateTag := structTag.Get("validate"); validateTag != "" { for _, val := range strings.Split(validateTag, ",") { if val == "required" { structField.isRequired = true break } } } if enumsTag := structTag.Get("enums"); enumsTag != "" { enumType := structField.schemaType if structField.schemaType == "array" { enumType = structField.arrayType } for _, e := range strings.Split(enumsTag, ",") { structField.enums = append(structField.enums, defineType(enumType, e)) } } if defaultTag := structTag.Get("default"); defaultTag != "" { structField.defaultValue = defineType(structField.schemaType, defaultTag) } if IsNumericType(structField.schemaType) || IsNumericType(structField.arrayType) { structField.maximum = getFloatTag(structTag, "maximum") structField.minimum = getFloatTag(structTag, "minimum") } if structField.schemaType == "string" || structField.arrayType == "string" { structField.maxLength = getIntTag(structTag, "maxLength") structField.minLength = getIntTag(structTag, "minLength") } return structField } func replaceLastTag(slice []spec.Tag, element spec.Tag) { slice = slice[:len(slice)-1] slice = append(slice, element) } func getFloatTag(structTag reflect.StructTag, tagName string) *float64 { strValue := structTag.Get(tagName) if strValue == "" { return nil } value, err := strconv.ParseFloat(strValue, 64) if err != nil { panic(fmt.Errorf("can't parse numeric value of %q 
tag: %v", tagName, err)) } return &value } func getIntTag(structTag reflect.StructTag, tagName string) *int64 { strValue := structTag.Get(tagName) if strValue == "" { return nil } value, err := strconv.ParseInt(strValue, 10, 64) if err != nil { panic(fmt.Errorf("can't parse numeric value of %q tag: %v", tagName, err)) } return &value } func toSnakeCase(in string) string { runes := []rune(in) length := len(runes) var out []rune for i := 0; i < length; i++ { if i > 0 && unicode.IsUpper(runes[i]) && ((i+1 < length && unicode.IsLower(runes[i+1])) || unicode.IsLower(runes[i-1])) { out = append(out, '_') } out = append(out, unicode.ToLower(runes[i])) } return string(out) } func toLowerCamelCase(in string) string { runes := []rune(in) var out []rune flag := false for i, curr := range runes { if (i == 0 && unicode.IsUpper(curr)) || (flag && unicode.IsUpper(curr)) { out = append(out, unicode.ToLower(curr)) flag = true } else { out = append(out, curr) flag = false } } return string(out) } // defineTypeOfExample example value define the type (object and array unsupported) func defineTypeOfExample(schemaType, arrayType, exampleValue string) interface{} { switch schemaType { case "string": return exampleValue case "number": v, err := strconv.ParseFloat(exampleValue, 64) if err != nil { panic(fmt.Errorf("example value %s can't convert to %s err: %s", exampleValue, schemaType, err)) } return v case "integer": v, err := strconv.Atoi(exampleValue) if err != nil { panic(fmt.Errorf("example value %s can't convert to %s err: %s", exampleValue, schemaType, err)) } return v case "boolean": v, err := strconv.ParseBool(exampleValue) if err != nil { panic(fmt.Errorf("example value %s can't convert to %s err: %s", exampleValue, schemaType, err)) } return v case "array": values := strings.Split(exampleValue, ",") result := make([]interface{}, 0) for _, value := range values { result = append(result, defineTypeOfExample(arrayType, "", value)) } return result default: panic(fmt.Errorf("%s is 
unsupported type in example value", schemaType)) } } // GetAllGoFileInfo gets all Go source files information for given searchDir. func (parser *Parser) getAllGoFileInfo(searchDir string) error { return filepath.Walk(searchDir, parser.visit) } func (parser *Parser) visit(path string, f os.FileInfo, err error) error { if err := Skip(f); err != nil { return err } if ext := filepath.Ext(path); ext == ".go" { fset := token.NewFileSet() // positions are relative to fset astFile, err := goparser.ParseFile(fset, path, nil, goparser.ParseComments) if err != nil { log.Panicf("ParseFile panic:%+v", err) } parser.files[path] = astFile } return nil } // Skip returns filepath.SkipDir error if match vendor and hidden folder func Skip(f os.FileInfo) error { // exclude vendor folder if f.IsDir() && f.Name() == "vendor" { return filepath.SkipDir } // exclude all hidden folder if f.IsDir() && len(f.Name()) > 1 && f.Name()[0] == '.' { return filepath.SkipDir } return nil } // GetSwagger returns *spec.Swagger which is the root document object for the API specification. func (parser *Parser) GetSwagger() *spec.Swagger { return parser.swagger }