Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions cmd/iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,31 @@ func (a *MethodAws) InitIamCommand() {
return
}

excludeDefaultRoles, err := cmd.Flags().GetBool("exclude-default-roles")
excludeAWSManagedRoles, err := cmd.Flags().GetBool("exclude-aws-managed-roles")
if err != nil {
a.OutputSignal.AddError(err)
return
}

// Get Config
config := getIamEnumerateConfig(accountID, excludeDefaultRoles)
config := getIamEnumerateConfig(accountID, excludeAWSManagedRoles)

// Get Report
report := iamInternal.EnumerateIam(cmd.Context(), *a.AwsConfig, config)
a.OutputSignal.Content = report
},
}

enumerateCmd.Flags().Bool("exclude-default-roles", false, "Exclude default roles")
enumerateCmd.Flags().Bool("exclude-aws-managed-roles", false, "Exclude AWS managed roles")

iamCmd.AddCommand(enumerateCmd)
a.RootCmd.AddCommand(iamCmd)
}

// getIamEnumerateConfig returns an IamEnumerateConfig with the given regions and account ID
func getIamEnumerateConfig(accountID string, excludeDefaultRoles bool) iam.IamEnumerateConfig {
func getIamEnumerateConfig(accountID string, excludeAWSManagedRoles bool) iam.IamEnumerateConfig {
return iam.IamEnumerateConfig{
AccountId: accountID,
ExcludeDefaultRoles: excludeDefaultRoles,
AccountId: accountID,
ExcludeAwsManagedRoles: excludeAWSManagedRoles,
}
}
6 changes: 3 additions & 3 deletions cmd/waf.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
// Generated
waffern "github.com/Method-Security/methodaws/generated/go/waf"
// Internal
"github.com/Method-Security/methodaws/internal/sts"
"github.com/Method-Security/methodaws/internal/waf"
"github.com/Method-Security/methodaws/utils"

// External
"github.com/spf13/cobra"
Expand All @@ -28,14 +28,14 @@ func (a *MethodAws) InitWAFCommand() {
Long: `Enumerate WAFs in your AWS account.`,
Run: func(cmd *cobra.Command, args []string) {
// Get Account ID
accountID, err := sts.GetAccountID(cmd.Context(), *a.AwsConfig)
accountID, err := utils.GetAccountID(cmd.Context(), *a.AwsConfig)
if err != nil {
a.OutputSignal.AddError(err)
return
}

// Config
config := getWafEnumerateConfig(*accountID, a.RootFlags.Regions)
config := getWafEnumerateConfig(accountID, a.RootFlags.Regions)

// Report
report := waf.EnumerateWAF(cmd.Context(), *a.AwsConfig, config)
Expand Down
3 changes: 1 addition & 2 deletions fern/definition/iam/enumerate.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ types:
IamEnumerateConfig:
properties:
accountId: string
excludeDefaultRoles: boolean
excludeAWSManagedRoles: boolean
# Supporting Structs
RoleLastUsed:
properties:
Expand All @@ -20,7 +20,6 @@ types:
AttachedPolicyConfigurationInfo:
properties:
policyDocument: optional<string>
isCustomerManaged: optional<boolean>
AttachedPolicy:
properties:
identification: AttachedPolicyIdentificationInfo
Expand Down
87 changes: 52 additions & 35 deletions internal/apigateway/enumerate/apigwv1.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,16 +197,23 @@ func getRestAPIRoutes(ctx context.Context, client *apigateway.Client, apiID, reg
var routes []*apigatewayfern.Route
var errors []string

resources, err := client.GetResources(ctx, &apigateway.GetResourcesInput{RestApiId: &apiID})
if err != nil {
log.Warn("Failed to get resources for API",
svc1log.SafeParam("region", region),
svc1log.SafeParam("apiId", apiID),
svc1log.Stacktrace(err))
return routes, []string{fmt.Sprintf("GetResources failed for API %s: %s", apiID, err.Error())}
// Paginate through all resources
var allResources []types.Resource
paginator := apigateway.NewGetResourcesPaginator(client, &apigateway.GetResourcesInput{RestApiId: &apiID})
for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
if err != nil {
log.Warn("Failed to get resources for API",
svc1log.SafeParam("region", region),
svc1log.SafeParam("apiId", apiID),
svc1log.Stacktrace(err))
errors = append(errors, fmt.Sprintf("GetResources failed for API %s: %s", apiID, err.Error()))
break
}
allResources = append(allResources, page.Items...)
}

for _, resource := range resources.Items {
for _, resource := range allResources {
for methodName := range resource.ResourceMethods {
method, err := client.GetMethod(ctx, &apigateway.GetMethodInput{
RestApiId: &apiID,
Expand Down Expand Up @@ -457,15 +464,21 @@ func getAPICertificates(ctx context.Context, client *apigateway.Client) ([]*apig
var certificates []*apigatewayfern.Certificate
var errors []string

// Get domain names associated with the API
domainNames, err := client.GetDomainNames(ctx, &apigateway.GetDomainNamesInput{})
if err != nil {
log.Warn("Failed to get domain names",
svc1log.Stacktrace(err))
return certificates, []string{fmt.Sprintf("GetDomainNames failed: %s", err.Error())}
// Get domain names associated with the API with pagination
var allDomains []types.DomainName
domainPaginator := apigateway.NewGetDomainNamesPaginator(client, &apigateway.GetDomainNamesInput{})
for domainPaginator.HasMorePages() {
page, err := domainPaginator.NextPage(ctx)
if err != nil {
log.Warn("Failed to get domain names",
svc1log.Stacktrace(err))
errors = append(errors, fmt.Sprintf("GetDomainNames failed: %s", err.Error()))
break
}
allDomains = append(allDomains, page.Items...)
}

for _, domain := range domainNames.Items {
for _, domain := range allDomains {
if domain.CertificateArn != nil {
cert := &apigatewayfern.Certificate{
Arn: *domain.CertificateArn,
Expand Down Expand Up @@ -498,33 +511,37 @@ func getAPIKeysAndUsagePlans(ctx context.Context, client *apigateway.Client, api
var usagePlans []string
var errors []string

// Get API keys - filter for those associated with this API
keysResult, err := client.GetApiKeys(ctx, &apigateway.GetApiKeysInput{})
if err != nil {
log.Warn("Failed to get API keys",
svc1log.SafeParam("apiId", apiID),
svc1log.Stacktrace(err))
errors = append(errors, fmt.Sprintf("GetApiKeys failed: %s", err.Error()))
} else {
for _, key := range keysResult.Items {
// Get API keys with pagination
keysPaginator := apigateway.NewGetApiKeysPaginator(client, &apigateway.GetApiKeysInput{})
for keysPaginator.HasMorePages() {
page, err := keysPaginator.NextPage(ctx)
if err != nil {
log.Warn("Failed to get API keys",
svc1log.SafeParam("apiId", apiID),
svc1log.Stacktrace(err))
errors = append(errors, fmt.Sprintf("GetApiKeys failed: %s", err.Error()))
break
}
for _, key := range page.Items {
if key.Id != nil {
// Check if this API key is associated with our API by checking usage plans
apiKeys = append(apiKeys, *key.Id)
}
}
}

// Get usage plans associated with this API
plansResult, err := client.GetUsagePlans(ctx, &apigateway.GetUsagePlansInput{})
if err != nil {
log.Warn("Failed to get usage plans",
svc1log.SafeParam("apiId", apiID),
svc1log.Stacktrace(err))
errors = append(errors, fmt.Sprintf("GetUsagePlans failed: %s", err.Error()))
} else {
for _, plan := range plansResult.Items {
// Get usage plans with pagination
plansPaginator := apigateway.NewGetUsagePlansPaginator(client, &apigateway.GetUsagePlansInput{})
for plansPaginator.HasMorePages() {
page, err := plansPaginator.NextPage(ctx)
if err != nil {
log.Warn("Failed to get usage plans",
svc1log.SafeParam("apiId", apiID),
svc1log.Stacktrace(err))
errors = append(errors, fmt.Sprintf("GetUsagePlans failed: %s", err.Error()))
break
}
for _, plan := range page.Items {
if plan.Id != nil && plan.ApiStages != nil {
// Check if this usage plan is associated with our API
for _, apiStage := range plan.ApiStages {
if apiStage.ApiId != nil && *apiStage.ApiId == apiID {
usagePlans = append(usagePlans, *plan.Id)
Expand Down
54 changes: 42 additions & 12 deletions internal/apigateway/enumerate/apigwv2.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func convertV2HttpAPIToFern(ctx context.Context, client *apigatewayv2.Client, ap
var errors []string

if api.ApiId == nil {
log.Warn("API Gateway API ID is nil", svc1log.SafeParam("apiId", *api.ApiId))
log.Warn("API Gateway API ID is nil")
errors = append(errors, "api gateway API ID is nil")
return nil, errors
}
Expand Down Expand Up @@ -201,13 +201,26 @@ func getHTTPAPIRoutes(ctx context.Context, client *apigatewayv2.Client, apiID, r
var routes []*apigatewayfern.Route
var errors []string

// Get routes for the API
routesResult, err := client.GetRoutes(ctx, &apigatewayv2.GetRoutesInput{ApiId: &apiID})
if err != nil {
return routes, []string{err.Error()}
// Get routes for the API with pagination
var allRoutes []types.Route
var nextToken *string
for {
routesResult, err := client.GetRoutes(ctx, &apigatewayv2.GetRoutesInput{
ApiId: &apiID,
NextToken: nextToken,
})
if err != nil {
errors = append(errors, fmt.Sprintf("GetRoutes failed for API %s: %s", apiID, err.Error()))
break
}
allRoutes = append(allRoutes, routesResult.Items...)
if routesResult.NextToken == nil {
break
}
nextToken = routesResult.NextToken
}

for _, route := range routesResult.Items {
for _, route := range allRoutes {
// Get integration for this route if it exists
var integration *apigatewayfern.Integration
if route.Target != nil {
Expand Down Expand Up @@ -403,13 +416,25 @@ func getHTTPAPICertificates(ctx context.Context, client *apigatewayv2.Client, ap
var certificates []*apigatewayfern.Certificate
var errors []string

// Get domain names for this API
domainNames, err := client.GetDomainNames(ctx, &apigatewayv2.GetDomainNamesInput{})
if err != nil {
return certificates, []string{err.Error()}
// Get domain names for this API with pagination
var allDomains []types.DomainName
var domainNextToken *string
for {
domainResult, err := client.GetDomainNames(ctx, &apigatewayv2.GetDomainNamesInput{
NextToken: domainNextToken,
})
if err != nil {
errors = append(errors, fmt.Sprintf("GetDomainNames failed: %s", err.Error()))
break
}
allDomains = append(allDomains, domainResult.Items...)
if domainResult.NextToken == nil {
break
}
domainNextToken = domainResult.NextToken
}

for _, domain := range domainNames.Items {
for _, domain := range allDomains {
// Check if this domain is associated with our API
mappings, err := client.GetApiMappings(ctx, &apigatewayv2.GetApiMappingsInput{
DomainName: domain.DomainName,
Expand All @@ -421,13 +446,18 @@ func getHTTPAPICertificates(ctx context.Context, client *apigatewayv2.Client, ap
// Check if any mapping is for our API
for _, mapping := range mappings.Items {
if mapping.ApiId != nil && *mapping.ApiId == apiID {
if len(domain.DomainNameConfigurations) == 0 || domain.DomainNameConfigurations[0].CertificateArn == nil {
errors = append(errors, fmt.Sprintf("Domain %s has no certificate configuration", aws.ToString(domain.DomainName)))
break
}

cert := &apigatewayfern.Certificate{
Arn: *domain.DomainNameConfigurations[0].CertificateArn,
DomainName: domain.DomainName,
}

// Convert security policy
if len(domain.DomainNameConfigurations) > 0 && domain.DomainNameConfigurations[0].SecurityPolicy != "" {
if domain.DomainNameConfigurations[0].SecurityPolicy != "" {
switch domain.DomainNameConfigurations[0].SecurityPolicy {
case types.SecurityPolicyTls10:
policy := apigatewayfern.SecurityPolicyTls10
Expand Down
15 changes: 7 additions & 8 deletions internal/ec2/enumerate/enumerate.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,8 @@ func enumerateEc2ForRegion(ctx context.Context, awsConfig aws.Config, region str
client := ec2aws.NewFromConfig(regionConfig)

// Get all instances
awsInstances, err := getAllInstances(ctx, client, region)
if err != nil {
errors = append(errors, err.Error())
return instances, errors
}
awsInstances, errs := getAllInstances(ctx, client, region)
errors = append(errors, errs...)

log.Info("Processing ec2aws instances",
svc1log.SafeParam("region", region),
Expand All @@ -103,14 +100,16 @@ func enumerateEc2ForRegion(ctx context.Context, awsConfig aws.Config, region str
}

// getAllInstances retrieves all ec2aws instances in a region
func getAllInstances(ctx context.Context, client *ec2aws.Client, region string) ([]types.Instance, error) {
func getAllInstances(ctx context.Context, client *ec2aws.Client, region string) ([]types.Instance, []string) {
var instances []types.Instance
var errors []string

paginator := ec2aws.NewDescribeInstancesPaginator(client, &ec2aws.DescribeInstancesInput{})
for paginator.HasMorePages() {
result, err := paginator.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list instances in region %s: %w", region, err)
errors = append(errors, fmt.Sprintf("failed to list instances in region %s: %s", region, err.Error()))
break
}

// Extract instances from reservations
Expand All @@ -119,7 +118,7 @@ func getAllInstances(ctx context.Context, client *ec2aws.Client, region string)
}
}

return instances, nil
return instances, errors
}

// processInstance converts an AWS instance to Fern format
Expand Down
5 changes: 4 additions & 1 deletion internal/eks/enumerate/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ func CredsEks(ctx context.Context, cfg aws.Config, clusterName string) (*eksfern
}

expiration := tok.Expiration
caCert := aws.ToString(clusterOutput.Cluster.CertificateAuthority.Data)
var caCert string
if clusterOutput.Cluster.CertificateAuthority != nil {
caCert = aws.ToString(clusterOutput.Cluster.CertificateAuthority.Data)
}
encodedToken := base64.StdEncoding.EncodeToString([]byte(tok.Token))
credInfo := eksfern.CredentialInfo{
Url: aws.ToString(clusterOutput.Cluster.Endpoint),
Expand Down
Loading
Loading