Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions lib/aws_codegen/rest_service.ex
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ defmodule AWS.CodeGen.RestService do
"#{if action.language == :elixir, do: ":", else: ""}#{result}"
end

def url_path(action) do
def url_path(context, action) do
Enum.reduce(action.url_parameters, action.request_uri, fn parameter, acc ->
multi_segment = Parameter.multi_segment?(parameter, acc)

Expand All @@ -54,7 +54,11 @@ defmodule AWS.CodeGen.RestService do
if multi_segment do
Enum.join(["\", aws_util:encode_multi_segment_uri(", parameter.code_name, "), \""])
else
Enum.join(["\", aws_util:encode_uri(", parameter.code_name, "), \""])
if context.module_name == "aws_cloudfront_keyvaluestore" do
Enum.join(["\", aws_util:encode_uri(", parameter.code_name, ", full), \""])
else
Enum.join(["\", aws_util:encode_uri(", parameter.code_name, "), \""])
end
end
end

Expand Down
53 changes: 41 additions & 12 deletions priv/rest.erl.eex
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,15 @@ end) %>
<%= AWS.CodeGen.Types.return_type(context.language, action)%>.
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, ApiId, Stage<% end %><%= AWS.CodeGen.RestService.required_function_parameters(action) %>, QueryMap, HeadersMap, Options0)
when is_map(Client), is_map(QueryMap), is_map(HeadersMap), is_list(Options0) ->
Path = ["<%= if context.module_name == "aws_apigatewaymanagementapi" do %>/", Stage, "<% end %><%= AWS.CodeGen.RestService.Action.url_path(action) %>"],<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>
Path = ["<%= if context.module_name == "aws_apigatewaymanagementapi" do %>/", Stage, "<% end %><%= AWS.CodeGen.RestService.Action.url_path(context, action) %>"],<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>
<%= if !String.contains?("Bucket", AWS.CodeGen.RestService.required_function_parameters(action)) do %><% else %> Bucket = undefined,<% end %><% end %>
SuccessStatusCode = <%= inspect(action.success_status_code) %>,
{SendBodyAsBinary, Options1} = proplists_take(send_body_as_binary, Options0, <%= action.send_body_as_binary? %>),
{SendBodyAsBinary, Options1} = proplists_take(send_body_as_binary, Options0, <%= if context.module_name == "aws_cloudfront_keyvaluestore" do %>true<% else %><%= action.send_body_as_binary? %><% end %>),
{ReceiveBodyAsBinary, Options2} = proplists_take(receive_body_as_binary, Options1, <%= action.receive_body_as_binary? %>),
Options = [{send_body_as_binary, SendBodyAsBinary},
{receive_body_as_binary, ReceiveBodyAsBinary}
{receive_body_as_binary, ReceiveBodyAsBinary}<%= if context.module_name == "aws_cloudfront_keyvaluestore" do %>,
{account, get_account_id(KvsARN)},
{sign_with_v4a, true}<% end %>
| Options2],
<%= if length(action.request_header_parameters) > 0 do %>
Headers0 =
Expand Down Expand Up @@ -122,14 +124,16 @@ end) %>
<%= AWS.CodeGen.Types.return_type(context.language, action)%>.
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, ApiId, Stage<% end %><%= AWS.CodeGen.RestService.function_parameters(action) %>, Input0, Options0) ->
Method = <%= AWS.CodeGen.RestService.Action.method(action) %>,
Path = ["<%= if context.module_name == "aws_apigatewaymanagementapi" do %>/", Stage, "<% end %><%= AWS.CodeGen.RestService.Action.url_path(action) %>"],<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>
Path = ["<%= if context.module_name == "aws_apigatewaymanagementapi" do %>/", Stage, "<% end %><%= AWS.CodeGen.RestService.Action.url_path(context, action) %>"],<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>
<%= if !String.contains?("Bucket", AWS.CodeGen.RestService.required_function_parameters(action)) do %><% else %> Bucket = undefined,<% end %><% end %>
SuccessStatusCode = <%= inspect(action.success_status_code) %>,
{SendBodyAsBinary, Options1} = proplists_take(send_body_as_binary, Options0, <%= action.send_body_as_binary? %>),
{SendBodyAsBinary, Options1} = proplists_take(send_body_as_binary, Options0, <%= if context.module_name == "aws_cloudfront_keyvaluestore" do %>true<% else %><%= action.send_body_as_binary? %><% end %>),
{ReceiveBodyAsBinary, Options2} = proplists_take(receive_body_as_binary, Options1, <%= action.receive_body_as_binary? %>),
Options = [{send_body_as_binary, SendBodyAsBinary},
{receive_body_as_binary, ReceiveBodyAsBinary},
{append_sha256_content_hash, <%= Enum.member?(["put_bucket_cors", "put_bucket_lifecycle", "put_bucket_tagging", "delete_objects"], action.function_name) %>}
{append_sha256_content_hash, <%= Enum.member?(["put_bucket_cors", "put_bucket_lifecycle", "put_bucket_tagging", "delete_objects"], action.function_name) %>}<%= if context.module_name == "aws_cloudfront_keyvaluestore" do %>,
{account, get_account_id(KvsARN)},
{sign_with_v4a, true}<% end %>
| Options2],
<%= if length(action.request_header_parameters) > 0 do %>
HeadersMapping = [<%= for parameter <- Enum.drop(action.request_header_parameters, -1) do %>
Expand Down Expand Up @@ -209,7 +213,7 @@ do_request(Client, Method, Path, Query, Headers0, Input, Options, SuccessStatusC
Client1 = Client#{service => <<"<%= context.signing_name %>">><%= if context.is_global do %>,
region => <<"<%= context.credential_scope %>">><% end %>},
<%= if context.endpoint_prefix == "s3-control" do %>AccountId = proplists:get_value(<<"x-amz-account-id">>, Headers0),
DefaultHost = build_host(AccountId, <<"<%= context.endpoint_prefix %>">>, Client1),<% else %><%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>DefaultHost = build_host(<<"<%= context.endpoint_prefix %>">>, Client1, Bucket),<%else %>DefaultHost = build_host(<<"<%= context.endpoint_prefix %>">>, Client1),<% end %><% end %>
DefaultHost = build_host(AccountId, <<"<%= context.endpoint_prefix %>">>, Client1),<% else %><%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>DefaultHost = build_host(<<"<%= context.endpoint_prefix %>">>, Client1, Bucket),<% else %><%= if context.module_name == "aws_cloudfront_keyvaluestore" do %>DefaultHost = build_host(proplists:get_value(account, Options), <<"cloudfront-kvs">>, Client1),<% else %>DefaultHost = build_host(<<"<%= context.endpoint_prefix %>">>, Client1),<% end %><% end %><% end %>
URL0 = build_url(DefaultHost, Path, Client1<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>, Bucket<% end %>),
PathBin = erlang:iolist_to_binary(Path),
{URL1, Host} = aws_util:apply_endpoint_url_override(URL0, DefaultHost, PathBin, <<"<%= context.endpoint_url_env_var %>">>),
Expand All @@ -219,8 +223,12 @@ do_request(Client, Method, Path, Query, Headers0, Input, Options, SuccessStatusC
],
Payload =
case proplists:get_value(send_body_as_binary, Options) of
true ->
maps:get(<<"Body">>, Input, <<"">>);
true when is_list(Input) ->
proplists:get_value(<<"Body">>, Input, <<"">>);
true when Input =:= undefined ->
<<"">>;
true ->
maps:get(<<"Body">>, Input, <<"">>);
false ->
encode_payload(Input)
end,
Expand All @@ -233,7 +241,7 @@ do_request(Client, Method, Path, Query, Headers0, Input, Options, SuccessStatusC
Headers1 = aws_request:add_headers(AdditionalHeaders, Headers0),

MethodBin = aws_request:method_to_binary(Method),
SignedHeaders = aws_request:sign_request(Client1, MethodBin, URL, Headers1, Payload<%= if context.module_name == "aws_apigatewaymanagementapi" or String.contains?(context.module_name, "aws_bedrock") do %>, [{uri_encode_path, true}]<% else %><% end %>),
SignedHeaders = aws_request:sign_request(Client1, MethodBin, URL, Headers1, Payload<%= if context.module_name == "aws_apigatewaymanagementapi" or String.contains?(context.module_name, "aws_bedrock") do %>, [{uri_encode_path, true}]<% else %><%= if context.module_name == "aws_cloudfront_keyvaluestore" do %>, [{sign_with_v4a, true}, {uri_encode_path, false}]<% end %><% end %>),
Response = hackney:request(Method, URL, SignedHeaders, Payload, Options),
DecodeBody = not proplists:get_value(receive_body_as_binary, Options),
handle_response(Response, SuccessStatusCode, DecodeBody).
Expand Down Expand Up @@ -305,6 +313,22 @@ build_host(undefined, _EndpointPrefix, _Client) ->
error(missing_account_id);
build_host(AccountId, EndpointPrefix, #{region := Region, endpoint := Endpoint}) ->
aws_util:binary_join([AccountId, EndpointPrefix, Region, Endpoint],
<<".">>).<% else %><%= if context.module_name == "aws_cloudfront_keyvaluestore" do %>
build_host(_AccountPrefix, _EndpointPrefix, #{region := <<"local">>, endpoint := Endpoint}) ->
Endpoint;
build_host(_AccountPrefix, _EndpointPrefix, #{region := <<"local">>}) ->
<<"localhost">>;
build_host(AccountPrefix, EndpointPrefix, #{region := <<"global">>, endpoint := Endpoint}) ->
aws_util:binary_join([AccountPrefix, EndpointPrefix, <<"global">>, Endpoint], <<".">>).
<% else %><%= if context.endpoint_prefix == "s3-control" do %>
build_host(_AccountPrefix, _EndpointPrefix, #{region := <<"local">>, endpoint := Endpoint}) ->
Endpoint;
build_host(_AccountPrefix, _EndpointPrefix, #{region := <<"local">>}) ->
<<"localhost">>;
build_host(undefined, _EndpointPrefix, _Client) ->
error(missing_account_id);
build_host(AccountPrefix, EndpointPrefix, #{region := Region, endpoint := Endpoint}) ->
aws_util:binary_join([_AccountPrefix, EndpointPrefix, Region, Endpoint],
<<".">>).<% else %>
<%= if context.endpoint_prefix == "s3" do %><%= if context.is_global do %>
build_host(EndpointPrefix, #{endpoint := Endpoint}, undefined) ->
Expand Down Expand Up @@ -333,7 +357,7 @@ build_host(_EndpointPrefix, #{region := <<"local">>}) ->
build_host(EndpointPrefix, #{endpoint := Endpoint}) ->
aws_util:binary_join([EndpointPrefix, Endpoint], <<".">>).<% else %>
build_host(EndpointPrefix, #{region := Region, endpoint := Endpoint}) ->
aws_util:binary_join([EndpointPrefix, Region, Endpoint], <<".">>).<% end %><% end %><% end %><% end %>
aws_util:binary_join([EndpointPrefix, Region, Endpoint], <<".">>).<% end %><% end %><% end %><% end %><% end %><% end %>
<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>build_url(Host0, Path0, Client, Bucket) ->
Proto = aws_client:proto(Client),
%% Mocks are notoriously bad with host-style requests, just skip it and use path-style for anything local
Expand All @@ -353,7 +377,8 @@ build_host(EndpointPrefix, #{region := Region, endpoint := Endpoint}) ->
Host1
end,
Port = aws_client:port(Client),
aws_util:binary_join([Proto, <<"://">>, Host, <<":">>, Port, Path], <<"">>).<% else %>build_url(Host, Path0, Client) ->
aws_util:binary_join([Proto, <<"://">>, Host, <<":">>, Port, Path], <<"">>).<% else %>
build_url(Host, Path0, Client) ->
Proto = aws_client:proto(Client),
Path = erlang:iolist_to_binary(Path0),
Port = aws_client:port(Client),
Expand All @@ -364,3 +389,7 @@ encode_payload(undefined) ->
<<>>;
encode_payload(Input) ->
<%= context.encode %>.
<%= if context.module_name == "aws_cloudfront_keyvaluestore" do %>
get_account_id(Arn) ->
[<<"arn">>, <<"aws">>, <<"cloudfront">>, <<>>, AccountId, _Rest] = binary:split(Arn, <<":">>, [global]),
AccountId.<% end %>
4 changes: 2 additions & 2 deletions priv/rest.ex.eex
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ defmodule <%= context.module_name %> do
"""<% end %><%= if action.method == "GET" do %>
@spec <%= action.function_name %>(map()<%= if context.module_name == "AWS.ApiGatewayManagementApi" do %>, String.t() | atom()<% end %><%= AWS.CodeGen.Types.function_parameter_types(action.method, action, false)%>, list()) :: <%= AWS.CodeGen.Types.return_type(context.language, action)%>
def <%= action.function_name %>(%Client{} = client<%= if context.module_name == "AWS.ApiGatewayManagementApi" do %>, stage<% end %><%= AWS.CodeGen.RestService.function_parameters(action) %>, options \\ []) do
url_path = "<%= if context.module_name == "AWS.ApiGatewayManagementApi" do %>/#{stage}<% end %><%= AWS.CodeGen.RestService.Action.url_path(action) %>"
url_path = "<%= if context.module_name == "AWS.ApiGatewayManagementApi" do %>/#{stage}<% end %><%= AWS.CodeGen.RestService.Action.url_path(context, action) %>"
headers = []<%= for parameter <- action.request_header_parameters do %>
headers = if !is_nil(<%= parameter.code_name %>) do
[{"<%= parameter.location_name %>", <%= parameter.code_name %>} | headers]
Expand Down Expand Up @@ -117,7 +117,7 @@ defmodule <%= context.module_name %> do
Request.request_rest(client, meta, :get, url_path, query_params, headers, nil, options, <%= inspect(action.success_status_code) %>)<% else %>
@spec <%= action.function_name %>(map()<%= AWS.CodeGen.Types.function_parameter_types(action.method, action, false)%>, <%= if context.module_name == "AWS.ApiGatewayManagementApi" do %> String.t() | atom(), <% end %><%= AWS.CodeGen.Types.function_argument_type(context.language, action)%>, list()) :: <%= AWS.CodeGen.Types.return_type(context.language, action)%>
def <%= action.function_name %>(%Client{} = client<%= AWS.CodeGen.RestService.function_parameters(action) %>, <%= if context.module_name == "AWS.ApiGatewayManagementApi" do %> stage, <% end %>input, options \\ []) do
url_path = "<%= if context.module_name == "AWS.ApiGatewayManagementApi" do %>/#{stage}<% end %><%= AWS.CodeGen.RestService.Action.url_path(action) %>"<%= if length(action.request_header_parameters) > 0 do %>
url_path = "<%= if context.module_name == "AWS.ApiGatewayManagementApi" do %>/#{stage}<% end %><%= AWS.CodeGen.RestService.Action.url_path(context, action) %>"<%= if length(action.request_header_parameters) > 0 do %>
{headers, input} =
[<%= for parameter <- action.request_header_parameters do %>
{"<%= parameter.name %>", "<%= parameter.location_name %>"},<% end %>
Expand Down
Loading