From e58eaa10af2faa489ad547a53de0cf8b9f0ab5a3 Mon Sep 17 00:00:00 2001 From: Mark Kraus Date: Sat, 9 Dec 2017 07:56:54 -0600 Subject: [PATCH] [Feature] Sort and Merge Classes in WebCmdlets\Common --- .../BasicHtmlWebResponseObject.Common.cs | 142 +- .../WebCmdlet/Common/ContentHelper.Common.cs | 179 +- .../Common/InvokeRestMethodCommand.Common.cs | 403 ++-- .../Common/WebRequestPSCmdlet.Common.cs | 1786 ++++++++--------- .../Common/WebResponseObject.Common.cs | 120 +- 5 files changed, 1294 insertions(+), 1336 deletions(-) diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/BasicHtmlWebResponseObject.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/BasicHtmlWebResponseObject.Common.cs index 22709a2bc99..88d8374bd9b 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/BasicHtmlWebResponseObject.Common.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/BasicHtmlWebResponseObject.Common.cs @@ -3,14 +3,14 @@ --********************************************************************/ using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; using System.Management.Automation; using System.Net; -using System.IO; +using System.Net.Http; using System.Text; using System.Text.RegularExpressions; -using System.Collections.Generic; -using System.Diagnostics; -using System.Net.Http; namespace Microsoft.PowerShell.Commands { @@ -128,8 +128,65 @@ public WebCmdletElementCollection Images #endregion Private Fields + #region Constructors + + /// + /// Constructor for BasicHtmlWebResponseObject + /// + /// + public BasicHtmlWebResponseObject(HttpResponseMessage response) + : this(response, null) + { } + + /// + /// Constructor for HtmlWebResponseObject with memory stream + /// + /// + /// + public BasicHtmlWebResponseObject(HttpResponseMessage response, Stream contentStream) + : base(response, contentStream) + { + EnsureHtmlParser(); + InitializeContent(); + InitializeRawContent(response); + } + + #endregion Constructors + #region Methods + /// + /// Reads the response content from the web response. + /// + protected void InitializeContent() + { + string contentType = ContentHelper.GetContentType(BaseResponse); + if (ContentHelper.IsText(contentType)) + { + Encoding encoding = null; + // fill the Content buffer + string characterSet = WebResponseHelper.GetCharacterSet(BaseResponse); + this.Content = StreamHelper.DecodeStream(RawContentStream, characterSet, out encoding); + this.Encoding = encoding; + } + else + { + this.Content = string.Empty; + } + } + + private PSObject CreateHtmlObject(string html, string tagName) + { + PSObject elementObject = new PSObject(); + + elementObject.Properties.Add(new PSNoteProperty("outerHTML", html)); + elementObject.Properties.Add(new PSNoteProperty("tagName", tagName)); + + ParseAttributes(html, elementObject); + + return elementObject; + } + private void EnsureHtmlParser() { if (s_tagRegex == null) @@ -169,16 +226,11 @@ private void EnsureHtmlParser() } } - private PSObject CreateHtmlObject(string html, string tagName) + private void InitializeRawContent(HttpResponseMessage baseResponse) { - PSObject elementObject = new PSObject(); - - elementObject.Properties.Add(new PSNoteProperty("outerHTML", html)); - elementObject.Properties.Add(new PSNoteProperty("tagName", tagName)); - - ParseAttributes(html, elementObject); - - return elementObject; + StringBuilder raw = ContentHelper.GetRawContentHeader(baseResponse); + raw.Append(Content); + this.RawContent = raw.ToString(); } private void ParseAttributes(string outerHtml, PSObject elementObject) @@ -223,70 +275,6 @@ private void ParseAttributes(string outerHtml, PSObject elementObject) } } - /// - /// Reads the response content from the web response. - /// - protected void InitializeContent() - { - string contentType = ContentHelper.GetContentType(BaseResponse); - if (ContentHelper.IsText(contentType)) - { - Encoding encoding = null; - // fill the Content buffer - string characterSet = WebResponseHelper.GetCharacterSet(BaseResponse); - this.Content = StreamHelper.DecodeStream(RawContentStream, characterSet, out encoding); - this.Encoding = encoding; - } - else - { - this.Content = string.Empty; - } - } - - #endregion Methods - } - - // TODO: Merge Partials - - // - /// Response object for html content without DOM parsing - /// - public partial class BasicHtmlWebResponseObject : WebResponseObject - { - #region Constructors - - /// - /// Constructor for BasicHtmlWebResponseObject - /// - /// - public BasicHtmlWebResponseObject(HttpResponseMessage response) - : this(response, null) - { } - - /// - /// Constructor for HtmlWebResponseObject with memory stream - /// - /// - /// - public BasicHtmlWebResponseObject(HttpResponseMessage response, Stream contentStream) - : base(response, contentStream) - { - EnsureHtmlParser(); - InitializeContent(); - InitializeRawContent(response); - } - - #endregion Constructors - - #region Methods - - private void InitializeRawContent(HttpResponseMessage baseResponse) - { - StringBuilder raw = ContentHelper.GetRawContentHeader(baseResponse); - raw.Append(Content); - this.RawContent = raw.ToString(); - } - #endregion Methods } } diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/ContentHelper.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/ContentHelper.Common.cs index 8bc5546b181..4fb73ef57f3 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/ContentHelper.Common.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/ContentHelper.Common.cs @@ -3,12 +3,12 @@ --********************************************************************/ using System; -using System.Management.Automation; -using System.Text; -using Microsoft.Win32; using System.Linq; +using System.Management.Automation; using System.Net.Http; using System.Net.Http.Headers; +using System.Text; +using Microsoft.Win32; namespace Microsoft.PowerShell.Commands { @@ -26,22 +26,22 @@ internal static partial class ContentHelper #region Internal Methods - internal static bool IsText(string contentType) + internal static string GetContentType(HttpResponseMessage response) { - contentType = GetContentTypeSignature(contentType); - return CheckIsText(contentType); + // ContentType may not exist in response header. Return null if not. + return response.Content.Headers.ContentType?.MediaType; } - internal static bool IsXml(string contentType) + internal static Encoding GetDefaultEncoding() { - contentType = GetContentTypeSignature(contentType); - return CheckIsXml(contentType); + return GetEncodingOrDefault((string)null); } - internal static bool IsJson(string contentType) + internal static Encoding GetEncoding(HttpResponseMessage response) { - contentType = GetContentTypeSignature(contentType); - return CheckIsJson(contentType); + // ContentType may not exist in response header. + string charSet = response.Content.Headers.ContentType?.CharSet; + return GetEncodingOrDefault(charSet); } internal static Encoding GetEncodingOrDefault(string characterSet) @@ -63,22 +63,87 @@ internal static Encoding GetEncodingOrDefault(string characterSet) return encoding; } - internal static Encoding GetDefaultEncoding() + internal static StringBuilder GetRawContentHeader(HttpResponseMessage response) { - return GetEncodingOrDefault((string)null); + StringBuilder raw = new StringBuilder(); + + string protocol = WebResponseHelper.GetProtocol(response); + if (!string.IsNullOrEmpty(protocol)) + { + int statusCode = WebResponseHelper.GetStatusCode(response); + string statusDescription = WebResponseHelper.GetStatusDescription(response); + raw.AppendFormat("{0} {1} {2}", protocol, statusCode, statusDescription); + raw.AppendLine(); + } + + HttpHeaders[] headerCollections = + { + response.Headers, + response.Content == null ? null : response.Content.Headers + }; + + foreach (var headerCollection in headerCollections) + { + if (headerCollection == null) + { + continue; + } + foreach (var header in headerCollection) + { + // Headers may have multiple entries with different values + foreach (var headerValue in header.Value) + { + raw.Append(header.Key); + raw.Append(": "); + raw.Append(headerValue); + raw.AppendLine(); + } + } + } + + raw.AppendLine(); + return raw; + } + + internal static bool IsJson(string contentType) + { + contentType = GetContentTypeSignature(contentType); + return CheckIsJson(contentType); + } + + internal static bool IsText(string contentType) + { + contentType = GetContentTypeSignature(contentType); + return CheckIsText(contentType); + } + + internal static bool IsXml(string contentType) + { + contentType = GetContentTypeSignature(contentType); + return CheckIsXml(contentType); } #endregion Internal Methods #region Private Helper Methods - private static string GetContentTypeSignature(string contentType) + private static bool CheckIsJson(string contentType) { if (String.IsNullOrEmpty(contentType)) - return null; + return false; - string sig = contentType.Split(s_contentTypeParamSeparator, 2)[0].ToUpperInvariant(); - return (sig); + // the correct type for JSON content, as specified in RFC 4627 + bool isJson = contentType.Equals("application/json", StringComparison.OrdinalIgnoreCase); + + // add in these other "javascript" related types that + // sometimes get sent down as the mime type for JSON content + isJson |= contentType.Equals("text/json", StringComparison.OrdinalIgnoreCase) + || contentType.Equals("application/x-javascript", StringComparison.OrdinalIgnoreCase) + || contentType.Equals("text/x-javascript", StringComparison.OrdinalIgnoreCase) + || contentType.Equals("application/javascript", StringComparison.OrdinalIgnoreCase) + || contentType.Equals("text/javascript", StringComparison.OrdinalIgnoreCase); + + return (isJson); } private static bool CheckIsText(string contentType) @@ -132,85 +197,15 @@ private static bool CheckIsXml(string contentType) return (isXml); } - private static bool CheckIsJson(string contentType) + private static string GetContentTypeSignature(string contentType) { if (String.IsNullOrEmpty(contentType)) - return false; - - // the correct type for JSON content, as specified in RFC 4627 - bool isJson = contentType.Equals("application/json", StringComparison.OrdinalIgnoreCase); - - // add in these other "javascript" related types that - // sometimes get sent down as the mime type for JSON content - isJson |= contentType.Equals("text/json", StringComparison.OrdinalIgnoreCase) - || contentType.Equals("application/x-javascript", StringComparison.OrdinalIgnoreCase) - || contentType.Equals("text/x-javascript", StringComparison.OrdinalIgnoreCase) - || contentType.Equals("application/javascript", StringComparison.OrdinalIgnoreCase) - || contentType.Equals("text/javascript", StringComparison.OrdinalIgnoreCase); + return null; - return (isJson); + string sig = contentType.Split(s_contentTypeParamSeparator, 2)[0].ToUpperInvariant(); + return (sig); } #endregion Internal Helper Methods } - - // TODO: merge Partials - - internal static partial class ContentHelper - { - internal static Encoding GetEncoding(HttpResponseMessage response) - { - // ContentType may not exist in response header. - string charSet = response.Content.Headers.ContentType?.CharSet; - return GetEncodingOrDefault(charSet); - } - - internal static string GetContentType(HttpResponseMessage response) - { - // ContentType may not exist in response header. Return null if not. - return response.Content.Headers.ContentType?.MediaType; - } - - internal static StringBuilder GetRawContentHeader(HttpResponseMessage response) - { - StringBuilder raw = new StringBuilder(); - - string protocol = WebResponseHelper.GetProtocol(response); - if (!string.IsNullOrEmpty(protocol)) - { - int statusCode = WebResponseHelper.GetStatusCode(response); - string statusDescription = WebResponseHelper.GetStatusDescription(response); - raw.AppendFormat("{0} {1} {2}", protocol, statusCode, statusDescription); - raw.AppendLine(); - } - - HttpHeaders[] headerCollections = - { - response.Headers, - response.Content == null ? null : response.Content.Headers - }; - - foreach (var headerCollection in headerCollections) - { - if (headerCollection == null) - { - continue; - } - foreach (var header in headerCollection) - { - // Headers may have multiple entries with different values - foreach (var headerValue in header.Value) - { - raw.Append(header.Key); - raw.Append(": "); - raw.Append(headerValue); - raw.AppendLine(); - } - } - } - - raw.AppendLine(); - return raw; - } - } } diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs index 7c6a149cc7f..dc14fae5d92 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs @@ -3,17 +3,25 @@ --********************************************************************/ using System; -using System.Management.Automation; using System.IO; +using System.Management.Automation; +using System.Net.Http; +using System.Text; using System.Xml; using Newtonsoft.Json; using Newtonsoft.Json.Linq; -using System.Net.Http; -using System.Text; namespace Microsoft.PowerShell.Commands { - public partial class InvokeRestMethodCommand + /// + /// The Invoke-RestMethod command + /// This command makes an HTTP or HTTPS request to a web service, + /// and returns the response in an appropriate way. + /// Intended to work against the wide spectrum of "RESTful" web services + /// currently deployed across the web. + /// + [Cmdlet(VerbsLifecycle.Invoke, "RestMethod", HelpUri = "https://go.microsoft.com/fwlink/?LinkID=217034", DefaultParameterSetName = "StandardMethod")] + public partial class InvokeRestMethodCommand : WebRequestPSCmdlet { #region Parameters @@ -73,63 +81,138 @@ public int MaximumFollowRelLink #endregion Parameters - #region Helper Methods + #region Enums - private bool TryProcessFeedStream(BufferingStreamReader responseStream) + /// + /// enum for rest return type. + /// + public enum RestReturnType { - bool isRssOrFeed = false; + /// + /// Return type not defined in response, + /// best effort detect + /// + Detect, - try - { - XmlReaderSettings readerSettings = GetSecureXmlReaderSettings(); - XmlReader reader = XmlReader.Create(responseStream, readerSettings); + /// + /// Json return type + /// + [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Naming", "CA1704:IdentifiersShouldBeSpelledCorrectly")] + Json, - // See if the reader contained an "RSS" or "Feed" in the first 10 elements (RSS and Feed are normally 2 or 3) - int readCount = 0; - while ((readCount < 10) && reader.Read()) + /// + /// Xml return type + /// + Xml, + } + + #endregion Enums + + #region Virtual Method Overrides + + /// + /// Process the web response and output corresponding objects. + /// + /// + internal override void ProcessResponse(HttpResponseMessage response) + { + if (null == response) { throw new ArgumentNullException("response"); } + + using (BufferingStreamReader responseStream = new BufferingStreamReader(StreamHelper.GetResponseStream(response))) + { + if (ShouldWriteToPipeline) { - if (String.Equals("rss", reader.Name, StringComparison.OrdinalIgnoreCase) || - String.Equals("feed", reader.Name, StringComparison.OrdinalIgnoreCase)) + // First see if it is an RSS / ATOM feed, in which case we can + // stream it - unless the user has overridden it with a return type of "XML" + if (TryProcessFeedStream(responseStream)) { - isRssOrFeed = true; - break; + // Do nothing, content has been processed. } + else + { + // determine the response type + RestReturnType returnType = CheckReturnType(response); - readCount++; - } + // Try to get the response encoding from the ContentType header. + Encoding encoding = null; + string charSet = response.Content.Headers.ContentType?.CharSet; + if (!string.IsNullOrEmpty(charSet)) + { + // NOTE: Don't use ContentHelper.GetEncoding; it returns a + // default which bypasses checking for a meta charset value. + StreamHelper.TryGetEncoding(charSet, out encoding); + } - if (isRssOrFeed) - { - XmlDocument workingDocument = new XmlDocument(); - // performing a Read() here to avoid rrechecking - // "rss" or "feed" items - reader.Read(); - while (!reader.EOF) - { - // If node is Element and it's the 'Item' or 'Entry' node, emit that node. - if ((reader.NodeType == XmlNodeType.Element) && - (string.Equals("Item", reader.Name, StringComparison.OrdinalIgnoreCase) || - string.Equals("Entry", reader.Name, StringComparison.OrdinalIgnoreCase)) - ) + object obj = null; + Exception ex = null; + + string str = StreamHelper.DecodeStream(responseStream, ref encoding); + // NOTE: Tests use this verbose output to verify the encoding. + WriteVerbose(string.Format + ( + System.Globalization.CultureInfo.InvariantCulture, + "Content encoding: {0}", + string.IsNullOrEmpty(encoding.HeaderName) ? encoding.EncodingName : encoding.HeaderName) + ); + bool convertSuccess = false; + + if (returnType == RestReturnType.Json) { - // this one will do reader.Read() internally - XmlNode result = workingDocument.ReadNode(reader); - WriteObject(result); + convertSuccess = TryConvertToJson(str, out obj, ref ex) || TryConvertToXml(str, out obj, ref ex); } + // default to try xml first since it's more common else { - reader.Read(); + convertSuccess = TryConvertToXml(str, out obj, ref ex) || TryConvertToJson(str, out obj, ref ex); + } + + if (!convertSuccess) + { + // fallback to string + obj = str; } + + WriteObject(obj); } } + + if (ShouldSaveToOutFile) + { + StreamHelper.SaveStreamToFile(responseStream, QualifiedOutFile, this); + } + + if (!String.IsNullOrEmpty(ResponseHeadersVariable)) + { + PSVariableIntrinsics vi = SessionState.PSVariable; + vi.Set(ResponseHeadersVariable, WebResponseHelper.GetHeadersDictionary(response)); + } } - catch (XmlException) { } - finally + } + + #endregion Virtual Method Overrides + + #region Helper Methods + + private RestReturnType CheckReturnType(HttpResponseMessage response) + { + if (null == response) { throw new ArgumentNullException("response"); } + + RestReturnType rt = RestReturnType.Detect; + string contentType = ContentHelper.GetContentType(response); + if (string.IsNullOrEmpty(contentType)) { - responseStream.Seek(0, SeekOrigin.Begin); + rt = RestReturnType.Detect; + } + else if (ContentHelper.IsJson(contentType)) + { + rt = RestReturnType.Json; + } + else if (ContentHelper.IsXml(contentType)) + { + rt = RestReturnType.Xml; } - return isRssOrFeed; + return (rt); } // Mostly cribbed from Serialization.cs#GetXmlReaderSettingsForCliXml() @@ -149,27 +232,6 @@ private XmlReaderSettings GetSecureXmlReaderSettings() return xrs; } - private bool TryConvertToXml(string xml, out object doc, ref Exception exRef) - { - try - { - XmlReaderSettings settings = GetSecureXmlReaderSettings(); - XmlReader xmlReader = XmlReader.Create(new StringReader(xml), settings); - - var xmlDoc = new XmlDocument(); - xmlDoc.PreserveWhitespace = true; - xmlDoc.Load(xmlReader); - - doc = xmlDoc; - } - catch (XmlException ex) - { - exRef = ex; - doc = null; - } - return (null != doc); - } - private bool TryConvertToJson(string json, out object obj, ref Exception exRef) { bool converted = false; @@ -214,31 +276,88 @@ private bool TryConvertToJson(string json, out object obj, ref Exception exRef) return converted; } - #endregion + private bool TryConvertToXml(string xml, out object doc, ref Exception exRef) + { + try + { + XmlReaderSettings settings = GetSecureXmlReaderSettings(); + XmlReader xmlReader = XmlReader.Create(new StringReader(xml), settings); + + var xmlDoc = new XmlDocument(); + xmlDoc.PreserveWhitespace = true; + xmlDoc.Load(xmlReader); - /// - /// enum for rest return type. - /// - public enum RestReturnType + doc = xmlDoc; + } + catch (XmlException ex) + { + exRef = ex; + doc = null; + } + return (null != doc); + } + + private bool TryProcessFeedStream(BufferingStreamReader responseStream) { - /// - /// Return type not defined in response, - /// best effort detect - /// - Detect, + bool isRssOrFeed = false; - /// - /// Json return type - /// - [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Naming", "CA1704:IdentifiersShouldBeSpelledCorrectly")] - Json, + try + { + XmlReaderSettings readerSettings = GetSecureXmlReaderSettings(); + XmlReader reader = XmlReader.Create(responseStream, readerSettings); - /// - /// Xml return type - /// - Xml, + // See if the reader contained an "RSS" or "Feed" in the first 10 elements (RSS and Feed are normally 2 or 3) + int readCount = 0; + while ((readCount < 10) && reader.Read()) + { + if (String.Equals("rss", reader.Name, StringComparison.OrdinalIgnoreCase) || + String.Equals("feed", reader.Name, StringComparison.OrdinalIgnoreCase)) + { + isRssOrFeed = true; + break; + } + + readCount++; + } + + if (isRssOrFeed) + { + XmlDocument workingDocument = new XmlDocument(); + // performing a Read() here to avoid rrechecking + // "rss" or "feed" items + reader.Read(); + while (!reader.EOF) + { + // If node is Element and it's the 'Item' or 'Entry' node, emit that node. + if ((reader.NodeType == XmlNodeType.Element) && + (string.Equals("Item", reader.Name, StringComparison.OrdinalIgnoreCase) || + string.Equals("Entry", reader.Name, StringComparison.OrdinalIgnoreCase)) + ) + { + // this one will do reader.Read() internally + XmlNode result = workingDocument.ReadNode(reader); + WriteObject(result); + } + else + { + reader.Read(); + } + } + } + } + catch (XmlException) { } + finally + { + responseStream.Seek(0, SeekOrigin.Begin); + } + + return isRssOrFeed; } + #endregion Helper Methods + + #region Classes + internal class BufferingStreamReader : Stream { internal BufferingStreamReader(Stream baseStream) @@ -344,127 +463,7 @@ public override void Write(byte[] buffer, int offset, int count) throw new NotSupportedException(); } } - } - - // TODO: Merge Partials - - /// - /// The Invoke-RestMethod command - /// This command makes an HTTP or HTTPS request to a web service, - /// and returns the response in an appropriate way. - /// Intended to work against the wide spectrum of "RESTful" web services - /// currently deployed across the web. - /// - [Cmdlet(VerbsLifecycle.Invoke, "RestMethod", HelpUri = "https://go.microsoft.com/fwlink/?LinkID=217034", DefaultParameterSetName = "StandardMethod")] - public partial class InvokeRestMethodCommand : WebRequestPSCmdlet - { - #region Virtual Method Overrides - - /// - /// Process the web response and output corresponding objects. - /// - /// - internal override void ProcessResponse(HttpResponseMessage response) - { - if (null == response) { throw new ArgumentNullException("response"); } - - using (BufferingStreamReader responseStream = new BufferingStreamReader(StreamHelper.GetResponseStream(response))) - { - if (ShouldWriteToPipeline) - { - // First see if it is an RSS / ATOM feed, in which case we can - // stream it - unless the user has overridden it with a return type of "XML" - if (TryProcessFeedStream(responseStream)) - { - // Do nothing, content has been processed. - } - else - { - // determine the response type - RestReturnType returnType = CheckReturnType(response); - - // Try to get the response encoding from the ContentType header. - Encoding encoding = null; - string charSet = response.Content.Headers.ContentType?.CharSet; - if (!string.IsNullOrEmpty(charSet)) - { - // NOTE: Don't use ContentHelper.GetEncoding; it returns a - // default which bypasses checking for a meta charset value. - StreamHelper.TryGetEncoding(charSet, out encoding); - } - - object obj = null; - Exception ex = null; - - string str = StreamHelper.DecodeStream(responseStream, ref encoding); - // NOTE: Tests use this verbose output to verify the encoding. - WriteVerbose(string.Format - ( - System.Globalization.CultureInfo.InvariantCulture, - "Content encoding: {0}", - string.IsNullOrEmpty(encoding.HeaderName) ? encoding.EncodingName : encoding.HeaderName) - ); - bool convertSuccess = false; - - if (returnType == RestReturnType.Json) - { - convertSuccess = TryConvertToJson(str, out obj, ref ex) || TryConvertToXml(str, out obj, ref ex); - } - // default to try xml first since it's more common - else - { - convertSuccess = TryConvertToXml(str, out obj, ref ex) || TryConvertToJson(str, out obj, ref ex); - } - - if (!convertSuccess) - { - // fallback to string - obj = str; - } - - WriteObject(obj); - } - } - - if (ShouldSaveToOutFile) - { - StreamHelper.SaveStreamToFile(responseStream, QualifiedOutFile, this); - } - - if (!String.IsNullOrEmpty(ResponseHeadersVariable)) - { - PSVariableIntrinsics vi = SessionState.PSVariable; - vi.Set(ResponseHeadersVariable, WebResponseHelper.GetHeadersDictionary(response)); - } - } - } - #endregion Virtual Method Overrides - - #region Helper Methods - - private RestReturnType CheckReturnType(HttpResponseMessage response) - { - if (null == response) { throw new ArgumentNullException("response"); } - - RestReturnType rt = RestReturnType.Detect; - string contentType = ContentHelper.GetContentType(response); - if (string.IsNullOrEmpty(contentType)) - { - rt = RestReturnType.Detect; - } - else if (ContentHelper.IsJson(contentType)) - { - rt = RestReturnType.Json; - } - else if (ContentHelper.IsXml(contentType)) - { - rt = RestReturnType.Xml; - } - - return (rt); - } - - #endregion Helper Methods + #endregion Classes } } diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs index 2b2dc46a6c3..6316804bc8b 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs @@ -3,28 +3,49 @@ --********************************************************************/ using System; +using System.Collections; +using System.Collections.Generic; using System.Collections.ObjectModel; +using System.Globalization; +using System.IO; +using System.Linq; using System.Management.Automation; using System.Net; -using System.IO; -using System.Text; -using System.Collections; -using System.Globalization; +using System.Net.Http; +using System.Net.Http.Headers; using System.Security; using System.Security.Authentication; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; -using Microsoft.Win32; -using System.Net.Http; -using System.Net.Http.Headers; +using System.Text; +using System.Text.RegularExpressions; using System.Threading; using System.Xml; -using System.Collections.Generic; -using System.Text.RegularExpressions; -using System.Linq; +using Microsoft.Win32; namespace Microsoft.PowerShell.Commands { + /// + /// Exception class for webcmdlets to enable returning HTTP error response + /// + public sealed class HttpResponseException : HttpRequestException + { + /// + /// Constructor for HttpResponseException + /// + /// Message for the exception + /// Response from the HTTP server + public HttpResponseException (string message, HttpResponseMessage response) : base(message) + { + Response = response; + } + + /// + /// HTTP error response + /// + public HttpResponseMessage Response { get; private set; } + } + /// /// The valid values for the -Authentication parameter for Invoke-RestMethod and Invoke-WebRequest /// @@ -85,6 +106,36 @@ public enum WebSslProtocol /// public abstract partial class WebRequestPSCmdlet : PSCmdlet { + + #region Fields + + /// + /// Automatically follow Rel Links + /// + internal bool _followRelLink = false; + + /// + /// Maximum number of Rel Links to follow + /// + internal int _maximumFollowRelLink = Int32.MaxValue; + + /// + /// Parse Rel Links + /// + internal bool _parseRelLink = false; + + /// + /// Automatically follow Rel Links + /// + internal Dictionary _relationLink = null; + + /// + /// Cancellation token source + /// + private CancellationTokenSource _cancelToken = null; + + #endregion Fields + #region Virtual Properties #region URI @@ -183,6 +234,21 @@ public abstract partial class WebRequestPSCmdlet : PSCmdlet [Parameter] public virtual SecureString Token { get; set; } + /// + /// gets or sets the PreserveAuthorizationOnRedirect property + /// + /// + /// This property overrides compatibility with web requests on Windows. + /// On FullCLR (WebRequest), authorization headers are stripped during redirect. + /// CoreCLR (HTTPClient) does not have this behavior so web requests that work on + /// PowerShell/FullCLR can fail with PowerShell/CoreCLR. To provide compatibility, + /// we'll detect requests with an Authorization header and automatically strip + /// the header when the first redirect occurs. This switch turns off this logic for + /// edge cases where the authorization header needs to be preserved across redirects. + /// + [Parameter] + public virtual SwitchParameter PreserveAuthorizationOnRedirect { get; set; } + #endregion #region Headers @@ -213,6 +279,15 @@ public abstract partial class WebRequestPSCmdlet : PSCmdlet [Parameter] public virtual IDictionary Headers { get; set; } + /// + /// gets or sets the SkipHeaderValidation property + /// + /// + /// This property adds headers to the request's header collection without validation. + /// + [Parameter] + public virtual SwitchParameter SkipHeaderValidation { get; set; } + #endregion #region Redirect @@ -351,1050 +426,848 @@ public virtual string CustomMethod #endregion Virtual Properties - #region Virtual Methods + #region Helper Properties - internal virtual void ValidateParameters() + internal string QualifiedOutFile { - // sessions - if ((null != WebSession) && (null != SessionVariable)) - { - ErrorRecord error = GetValidationError(WebCmdletStrings.SessionConflict, - "WebCmdletSessionConflictException"); - ThrowTerminatingError(error); - } + get { return (QualifyFilePath(OutFile)); } + } - // Authentication - if (UseDefaultCredentials && (Authentication != WebAuthenticationType.None)) - { - ErrorRecord error = GetValidationError(WebCmdletStrings.AuthenticationConflict, - "WebCmdletAuthenticationConflictException"); - ThrowTerminatingError(error); - } - if ((Authentication != WebAuthenticationType.None) && (null != Token) && (null != Credential)) - { - ErrorRecord error = GetValidationError(WebCmdletStrings.AuthenticationTokenConflict, - "WebCmdletAuthenticationTokenConflictException"); - ThrowTerminatingError(error); - } - if ((Authentication == WebAuthenticationType.Basic) && (null == Credential)) - { - ErrorRecord error = GetValidationError(WebCmdletStrings.AuthenticationCredentialNotSupplied, - "WebCmdletAuthenticationCredentialNotSuppliedException"); - ThrowTerminatingError(error); - } - if ((Authentication == WebAuthenticationType.OAuth || Authentication == WebAuthenticationType.Bearer) && (null == Token)) - { - ErrorRecord error = GetValidationError(WebCmdletStrings.AuthenticationTokenNotSupplied, - "WebCmdletAuthenticationTokenNotSuppliedException"); - ThrowTerminatingError(error); - } - if (!AllowUnencryptedAuthentication && (Authentication != WebAuthenticationType.None) && (Uri.Scheme != "https")) - { - ErrorRecord error = GetValidationError(WebCmdletStrings.AllowUnencryptedAuthenticationRequired, - "WebCmdletAllowUnencryptedAuthenticationRequiredException"); - ThrowTerminatingError(error); - } - if (!AllowUnencryptedAuthentication && (null != Credential || UseDefaultCredentials) && (Uri.Scheme != "https")) - { - ErrorRecord error = GetValidationError(WebCmdletStrings.AllowUnencryptedAuthenticationRequired, - "WebCmdletAllowUnencryptedAuthenticationRequiredException"); - ThrowTerminatingError(error); - } + internal bool ShouldSaveToOutFile + { + get { return (!string.IsNullOrEmpty(OutFile)); } + } - // credentials - if (UseDefaultCredentials && (null != Credential)) - { - ErrorRecord error = GetValidationError(WebCmdletStrings.CredentialConflict, - "WebCmdletCredentialConflictException"); - ThrowTerminatingError(error); - } + internal bool ShouldWriteToPipeline + { + get { return (!ShouldSaveToOutFile || PassThru); } + } - // Proxy server - if (ProxyUseDefaultCredentials && (null != ProxyCredential)) - { - ErrorRecord error = GetValidationError(WebCmdletStrings.ProxyCredentialConflict, - "WebCmdletProxyCredentialConflictException"); - ThrowTerminatingError(error); - } - else if ((null == Proxy) && ((null != ProxyCredential) || ProxyUseDefaultCredentials)) - { - ErrorRecord error = GetValidationError(WebCmdletStrings.ProxyUriNotSupplied, - "WebCmdletProxyUriNotSuppliedException"); - ThrowTerminatingError(error); - } + #endregion Helper Properties - // request body content - if ((null != Body) && (null != InFile)) - { - ErrorRecord error = GetValidationError(WebCmdletStrings.BodyConflict, - "WebCmdletBodyConflictException"); - ThrowTerminatingError(error); - } + #region Abstract Methods - // validate InFile path - if (InFile != null) + /// + /// Read the supplied WebResponse object and push the + /// resulting output into the pipeline. + /// + /// Instance of a WebResponse object to be processed + internal abstract void ProcessResponse(HttpResponseMessage response); + + #endregion Abstract Methods + + #region Overrides + + /// + /// the main execution method for cmdlets derived from WebRequestPSCmdlet. + /// + protected override void ProcessRecord() + { + try { - ProviderInfo provider = null; - ErrorRecord errorRecord = null; + // Set cmdlet context for write progress + ValidateParameters(); + PrepareSession(); - try - { - Collection providerPaths = GetResolvedProviderPathFromPSPath(InFile, out provider); + // if the request contains an authorization header and PreserveAuthorizationOnRedirect is not set, + // it needs to be stripped on the first redirect. + bool stripAuthorization = null != WebSession + && + null != WebSession.Headers + && + !PreserveAuthorizationOnRedirect.IsPresent + && + WebSession.Headers.ContainsKey(HttpKnownHeaderNames.Authorization.ToString()); - if (!provider.Name.Equals(FileSystemProvider.ProviderName, StringComparison.OrdinalIgnoreCase)) - { - errorRecord = GetValidationError(WebCmdletStrings.NotFilesystemPath, - "WebCmdletInFileNotFilesystemPathException", InFile); - } - else + using (HttpClient client = GetHttpClient(stripAuthorization)) + { + int followedRelLink = 0; + Uri uri = Uri; + do { - if (providerPaths.Count > 1) - { - errorRecord = GetValidationError(WebCmdletStrings.MultiplePathsResolved, - "WebCmdletInFileMultiplePathsResolvedException", InFile); - } - else if (providerPaths.Count == 0) + if (followedRelLink > 0) { - errorRecord = GetValidationError(WebCmdletStrings.NoPathResolved, - "WebCmdletInFileNoPathResolvedException", InFile); + string linkVerboseMsg = string.Format(CultureInfo.CurrentCulture, + WebCmdletStrings.FollowingRelLinkVerboseMsg, + uri.AbsoluteUri); + WriteVerbose(linkVerboseMsg); } - else + + using (HttpRequestMessage request = GetRequest(uri, stripAuthorization:false)) { - if (Directory.Exists(providerPaths[0])) + FillRequestStream(request); + try { - errorRecord = GetValidationError(WebCmdletStrings.DirectoryPathSpecified, - "WebCmdletInFileNotFilePathException", InFile); - } - _originalFilePath = InFile; - InFile = providerPaths[0]; - } - } - } - catch (ItemNotFoundException pathNotFound) - { - errorRecord = new ErrorRecord(pathNotFound.ErrorRecord, pathNotFound); - } - catch (ProviderNotFoundException providerNotFound) - { - errorRecord = new ErrorRecord(providerNotFound.ErrorRecord, providerNotFound); - } - catch (System.Management.Automation.DriveNotFoundException driveNotFound) - { - errorRecord = new ErrorRecord(driveNotFound.ErrorRecord, driveNotFound); - } + long requestContentLength = 0; + if (request.Content != null) + requestContentLength = request.Content.Headers.ContentLength.Value; - if (errorRecord != null) - { - ThrowTerminatingError(errorRecord); + string reqVerboseMsg = String.Format(CultureInfo.CurrentCulture, + WebCmdletStrings.WebMethodInvocationVerboseMsg, + request.Method, + request.RequestUri, + requestContentLength); + WriteVerbose(reqVerboseMsg); + + HttpResponseMessage response = GetResponse(client, request, stripAuthorization); + + string contentType = ContentHelper.GetContentType(response); + string respVerboseMsg = string.Format(CultureInfo.CurrentCulture, + WebCmdletStrings.WebResponseVerboseMsg, + response.Content.Headers.ContentLength, + contentType); + WriteVerbose(respVerboseMsg); + + if (!response.IsSuccessStatusCode) + { + string message = String.Format(CultureInfo.CurrentCulture, WebCmdletStrings.ResponseStatusCodeFailure, + (int)response.StatusCode, response.ReasonPhrase); + HttpResponseException httpEx = new HttpResponseException(message, response); + ErrorRecord er = new ErrorRecord(httpEx, "WebCmdletWebResponseException", ErrorCategory.InvalidOperation, request); + string detailMsg = ""; + StreamReader reader = null; + try + { + reader = new StreamReader(StreamHelper.GetResponseStream(response)); + // remove HTML tags making it easier to read + detailMsg = System.Text.RegularExpressions.Regex.Replace(reader.ReadToEnd(), "<[^>]*>",""); + } + catch (Exception) + { + // catch all + } + finally + { + if (reader != null) + { + reader.Dispose(); + } + } + if (!String.IsNullOrEmpty(detailMsg)) + { + er.ErrorDetails = new ErrorDetails(detailMsg); + } + ThrowTerminatingError(er); + } + + if (_parseRelLink || _followRelLink) + { + ParseLinkHeader(response, uri); + } + ProcessResponse(response); + UpdateSession(response); + + // If we hit our maximum redirection count, generate an error. + // Errors with redirection counts of greater than 0 are handled automatically by .NET, but are + // impossible to detect programmatically when we hit this limit. By handling this ourselves + // (and still writing out the result), users can debug actual HTTP redirect problems. + if (WebSession.MaximumRedirection == 0) // Indicate "HttpClientHandler.AllowAutoRedirect == false" + { + if (response.StatusCode == HttpStatusCode.Found || + response.StatusCode == HttpStatusCode.Moved || + response.StatusCode == HttpStatusCode.MovedPermanently) + { + ErrorRecord er = new ErrorRecord(new InvalidOperationException(), "MaximumRedirectExceeded", ErrorCategory.InvalidOperation, request); + er.ErrorDetails = new ErrorDetails(WebCmdletStrings.MaximumRedirectionCountExceeded); + WriteError(er); + } + } + } + catch (HttpRequestException ex) + { + ErrorRecord er = new ErrorRecord(ex, "WebCmdletWebResponseException", ErrorCategory.InvalidOperation, request); + if (ex.InnerException != null) + { + er.ErrorDetails = new ErrorDetails(ex.InnerException.Message); + } + ThrowTerminatingError(er); + } + + if (_followRelLink) + { + if (!_relationLink.ContainsKey("next")) + { + return; + } + uri = new Uri(_relationLink["next"]); + followedRelLink++; + } + } + } + while (_followRelLink && (followedRelLink < _maximumFollowRelLink)); } } - - // output ?? - if (PassThru && (OutFile == null)) + catch (CryptographicException ex) { - ErrorRecord error = GetValidationError(WebCmdletStrings.OutFileMissing, - "WebCmdletOutFileMissingException"); - ThrowTerminatingError(error); + ErrorRecord er = new ErrorRecord(ex, "WebCmdletCertificateException", ErrorCategory.SecurityError, null); + ThrowTerminatingError(er); + } + catch (NotSupportedException ex) + { + ErrorRecord er = new ErrorRecord(ex, "WebCmdletIEDomNotSupportedException", ErrorCategory.NotImplemented, null); + ThrowTerminatingError(er); } } - internal virtual void PrepareSession() + /// + /// Implementing ^C, after start the BeginGetResponse + /// + protected override void StopProcessing() { - // make sure we have a valid WebRequestSession object to work with - if (null == WebSession) + if (_cancelToken != null) { - WebSession = new WebRequestSession(); + _cancelToken.Cancel(); } + } - if (null != SessionVariable) - { - // save the session back to the PS environment if requested - PSVariableIntrinsics vi = SessionState.PSVariable; - vi.Set(SessionVariable, WebSession); - } + #endregion Overrides - // - // handle credentials - // - if (null != Credential && Authentication == WebAuthenticationType.None) - { - // get the relevant NetworkCredential - NetworkCredential netCred = Credential.GetNetworkCredential(); - WebSession.Credentials = netCred; + #region Virtual Methods - // supplying a credential overrides the UseDefaultCredentials setting - WebSession.UseDefaultCredentials = false; - } - else if ((null != Credential || null!= Token) && Authentication != WebAuthenticationType.None) + internal virtual void FillRequestStream(HttpRequestMessage request) + { + if (null == request) { throw new ArgumentNullException("request"); } + + // set the content type + if (ContentType != null) { - ProcessAuthentication(); + WebSession.ContentHeaders[HttpKnownHeaderNames.ContentType] = ContentType; + //request } - else if (UseDefaultCredentials) + // ContentType == null + else if (Method == WebRequestMethod.Post || (IsCustomMethodSet() && CustomMethod.ToUpperInvariant() == "POST")) { - WebSession.UseDefaultCredentials = true; + // Win8:545310 Invoke-WebRequest does not properly set MIME type for POST + string contentType = null; + WebSession.ContentHeaders.TryGetValue(HttpKnownHeaderNames.ContentType, out contentType); + if (string.IsNullOrEmpty(contentType)) + { + WebSession.ContentHeaders[HttpKnownHeaderNames.ContentType] = "application/x-www-form-urlencoded"; + } } - - if (null != CertificateThumbprint) + // coerce body into a usable form + if (Body != null) { - X509Store store = new X509Store(StoreName.My, StoreLocation.CurrentUser); - store.Open(OpenFlags.ReadOnly | OpenFlags.OpenExistingOnly); - X509Certificate2Collection collection = (X509Certificate2Collection)store.Certificates; - X509Certificate2Collection tbCollection = (X509Certificate2Collection)collection.Find(X509FindType.FindByThumbprint, CertificateThumbprint, false); - if (tbCollection.Count == 0) + object content = Body; + + // make sure we're using the base object of the body, not the PSObject wrapper + PSObject psBody = Body as PSObject; + if (psBody != null) { - CryptographicException ex = new CryptographicException(WebCmdletStrings.ThumbprintNotFound); - throw ex; + content = psBody.BaseObject; } - foreach (X509Certificate2 tbCert in tbCollection) + + if (content is FormObject) { - X509Certificate certificate = (X509Certificate)tbCert; - WebSession.AddCertificate(certificate); + FormObject form = content as FormObject; + SetRequestContent(request, form.Fields); + } + else if (content is IDictionary && request.Method != HttpMethod.Get) + { + IDictionary dictionary = content as IDictionary; + SetRequestContent(request, dictionary); + } + else if (content is XmlNode) + { + XmlNode xmlNode = content as XmlNode; + SetRequestContent(request, xmlNode); + } + else if (content is Stream) + { + Stream stream = content as Stream; + SetRequestContent(request, stream); + } + else if (content is byte[]) + { + byte[] bytes = content as byte[]; + SetRequestContent(request, bytes); + } + else if (content is MultipartFormDataContent multipartFormDataContent) + { + WebSession.ContentHeaders.Clear(); + SetRequestContent(request, multipartFormDataContent); + } + else + { + SetRequestContent(request, + (string)LanguagePrimitives.ConvertTo(content, typeof(string), CultureInfo.InvariantCulture)); } } - - if (null != Certificate) + else if (InFile != null) // copy InFile data { - WebSession.AddCertificate(Certificate); + try + { + // open the input file + SetRequestContent(request, new FileStream(InFile, FileMode.Open)); + } + catch (UnauthorizedAccessException) + { + string msg = string.Format(CultureInfo.InvariantCulture, WebCmdletStrings.AccessDenied, + _originalFilePath); + throw new UnauthorizedAccessException(msg); + } } - // - // handle the user agent - // - if (null != UserAgent) + // Add the content headers + if (request.Content != null) { - // store the UserAgent string - WebSession.UserAgent = UserAgent; + foreach (var entry in WebSession.ContentHeaders) + { + request.Content.Headers.Add(entry.Key, entry.Value); + } } + } - if (null != Proxy) + // NOTE: Only pass true for handleRedirect if the original request has an authorization header + // and PreserveAuthorizationOnRedirect is NOT set. + internal virtual HttpClient GetHttpClient(bool handleRedirect) + { + // By default the HttpClientHandler will automatically decompress GZip and Deflate content + HttpClientHandler handler = new HttpClientHandler(); + handler.CookieContainer = WebSession.Cookies; + + // set the credentials used by this request + if (WebSession.UseDefaultCredentials) { - WebProxy webProxy = new WebProxy(Proxy); - webProxy.BypassProxyOnLocal = false; - if (null != ProxyCredential) - { - webProxy.Credentials = ProxyCredential.GetNetworkCredential(); - } - else if (ProxyUseDefaultCredentials) - { - // If both ProxyCredential and ProxyUseDefaultCredentials are passed, - // UseDefaultCredentials will overwrite the supplied credentials. - webProxy.UseDefaultCredentials = true; - } - WebSession.Proxy = webProxy; + // the UseDefaultCredentials flag overrides other supplied credentials + handler.UseDefaultCredentials = true; } - - if (-1 < MaximumRedirection) + else if (WebSession.Credentials != null) { - WebSession.MaximumRedirection = MaximumRedirection; + handler.Credentials = WebSession.Credentials; } - // store the other supplied headers - if (null != Headers) + if (NoProxy) { - foreach (string key in Headers.Keys) - { - // add the header value (or overwrite it if already present) - WebSession.Headers[key] = Headers[key].ToString(); - } + handler.UseProxy = false; + } + else if (WebSession.Proxy != null) + { + handler.Proxy = WebSession.Proxy; } - } - - #endregion Virtual Methods - - #region Helper Properties - - internal string QualifiedOutFile - { - get { return (QualifyFilePath(OutFile)); } - } - - internal bool ShouldSaveToOutFile - { - get { return (!string.IsNullOrEmpty(OutFile)); } - } - - internal bool ShouldWriteToPipeline - { - get { return (!ShouldSaveToOutFile || PassThru); } - } - - #endregion Helper Properties - #region Helper Methods + if (null != WebSession.Certificates) + { + handler.ClientCertificates.AddRange(WebSession.Certificates); + } - private Uri PrepareUri(Uri uri) - { - uri = CheckProtocol(uri); + if (SkipCertificateCheck) + { + handler.ServerCertificateCustomValidationCallback = HttpClientHandler.DangerousAcceptAnyServerCertificateValidator; + handler.ClientCertificateOptions = ClientCertificateOption.Manual; + } - // before creating the web request, - // preprocess Body if content is a dictionary and method is GET (set as query) - IDictionary bodyAsDictionary; - LanguagePrimitives.TryConvertTo(Body, out bodyAsDictionary); - if ((null != bodyAsDictionary) - && ((IsStandardMethodSet() && (Method == WebRequestMethod.Default || Method == WebRequestMethod.Get)) - || (IsCustomMethodSet() && CustomMethod.ToUpperInvariant() == "GET"))) + // This indicates GetResponse will handle redirects. + if (handleRedirect) { - UriBuilder uriBuilder = new UriBuilder(uri); - if (uriBuilder.Query != null && uriBuilder.Query.Length > 1) + handler.AllowAutoRedirect = false; + } + else if (WebSession.MaximumRedirection > -1) + { + if (WebSession.MaximumRedirection == 0) { - uriBuilder.Query = uriBuilder.Query.Substring(1) + "&" + FormatDictionary(bodyAsDictionary); + handler.AllowAutoRedirect = false; } else { - uriBuilder.Query = FormatDictionary(bodyAsDictionary); + handler.MaxAutomaticRedirections = WebSession.MaximumRedirection; } - uri = uriBuilder.Uri; - // set body to null to prevent later FillRequestStream - Body = null; } - return uri; - } + handler.SslProtocols = (SslProtocols)SslProtocol; - private Uri CheckProtocol(Uri uri) - { - if (null == uri) { throw new ArgumentNullException("uri"); } - if (!uri.IsAbsoluteUri) + HttpClient httpClient = new HttpClient(handler); + + // check timeout setting (in seconds instead of milliseconds as in HttpWebRequest) + if (TimeoutSec == 0) { - uri = new Uri("http://" + uri.OriginalString); + // A zero timeout means infinite + httpClient.Timeout = TimeSpan.FromMilliseconds(Timeout.Infinite); + } + else if (TimeoutSec > 0) + { + httpClient.Timeout = new TimeSpan(0, 0, TimeoutSec); } - return (uri); - } - private string QualifyFilePath(string path) - { - string resolvedFilePath = PathUtils.ResolveFilePath(path, this, false); - return resolvedFilePath; + return httpClient; } - private string FormatDictionary(IDictionary content) + internal virtual HttpRequestMessage GetRequest(Uri uri, bool stripAuthorization) { - if (content == null) - throw new ArgumentNullException("content"); + Uri requestUri = PrepareUri(uri); + HttpMethod httpMethod = null; - StringBuilder bodyBuilder = new StringBuilder(); - foreach (string key in content.Keys) + switch (ParameterSetName) { - if (0 < bodyBuilder.Length) + case "StandardMethodNoProxy": + goto case "StandardMethod"; + case "StandardMethod": + // set the method if the parameter was provided + httpMethod = GetHttpMethod(Method); + break; + case "CustomMethodNoProxy": + goto case "CustomMethod"; + case "CustomMethod": + if (!string.IsNullOrEmpty(CustomMethod)) + { + // set the method if the parameter was provided + httpMethod = new HttpMethod(CustomMethod.ToString().ToUpperInvariant()); + } + break; + } + + // create the base WebRequest object + var request = new HttpRequestMessage(httpMethod, requestUri); + + // pull in session data + if (WebSession.Headers.Count > 0) + { + WebSession.ContentHeaders.Clear(); + foreach (var entry in WebSession.Headers) { - bodyBuilder.Append("&"); + if (HttpKnownHeaderNames.ContentHeaders.Contains(entry.Key)) + { + WebSession.ContentHeaders.Add(entry.Key, entry.Value); + } + else + { + if (stripAuthorization + && + String.Equals(entry.Key, HttpKnownHeaderNames.Authorization.ToString(), StringComparison.OrdinalIgnoreCase) + ) + { + continue; + } + + if (SkipHeaderValidation) + { + request.Headers.TryAddWithoutValidation(entry.Key, entry.Value); + } + else + { + request.Headers.Add(entry.Key, entry.Value); + } + } } + } - object value = content[key]; + // Set 'Transfer-Encoding: chunked' if 'Transfer-Encoding' is specified + if (WebSession.Headers.ContainsKey(HttpKnownHeaderNames.TransferEncoding)) + { + request.Headers.TransferEncodingChunked = true; + } - // URLEncode the key and value - string encodedKey = WebUtility.UrlEncode(key); - string encodedValue = String.Empty; - if (null != value) + // Set 'User-Agent' if WebSession.Headers doesn't already contain it + string userAgent = null; + if (WebSession.Headers.TryGetValue(HttpKnownHeaderNames.UserAgent, out userAgent)) + { + WebSession.UserAgent = userAgent; + } + else + { + if (SkipHeaderValidation) { - encodedValue = WebUtility.UrlEncode(value.ToString()); + request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.UserAgent, WebSession.UserAgent); + } + else + { + request.Headers.Add(HttpKnownHeaderNames.UserAgent, WebSession.UserAgent); } - bodyBuilder.AppendFormat("{0}={1}", encodedKey, encodedValue); } - return bodyBuilder.ToString(); - } - private ErrorRecord GetValidationError(string msg, string errorId) - { - var ex = new ValidationMetadataException(msg); - var error = new ErrorRecord(ex, errorId, ErrorCategory.InvalidArgument, this); - return (error); - } + // Set 'Keep-Alive' to false. This means set the Connection to 'Close'. + if (DisableKeepAlive) + { + request.Headers.Add(HttpKnownHeaderNames.Connection, "Close"); + } - private ErrorRecord GetValidationError(string msg, string errorId, params object[] args) - { - msg = string.Format(CultureInfo.InvariantCulture, msg, args); - var ex = new ValidationMetadataException(msg); - var error = new ErrorRecord(ex, errorId, ErrorCategory.InvalidArgument, this); - return (error); - } + // Set 'Transfer-Encoding' + if (TransferEncoding != null) + { + request.Headers.TransferEncodingChunked = true; + var headerValue = new TransferCodingHeaderValue(TransferEncoding); + if (!request.Headers.TransferEncoding.Contains(headerValue)) + { + request.Headers.TransferEncoding.Add(headerValue); + } + } - private bool IsStandardMethodSet() - { - return (ParameterSetName == "StandardMethod"); - } + // Some web sites (e.g. Twitter) will return exception on POST when Expect100 is sent + // Default behavior is continue to send body content anyway after a short period + // Here it send the two part as a whole. + request.Headers.ExpectContinue = false; - private bool IsCustomMethodSet() - { - return (ParameterSetName == "CustomMethod"); + return (request); } - private string GetBasicAuthorizationHeader() + internal virtual HttpResponseMessage GetResponse(HttpClient client, HttpRequestMessage request, bool stripAuthorization) { - string unencoded = String.Format("{0}:{1}", Credential.UserName, Credential.GetNetworkCredential().Password); - Byte[] bytes = Encoding.UTF8.GetBytes(unencoded); - return String.Format("Basic {0}", Convert.ToBase64String(bytes)); - } + if (client == null) { throw new ArgumentNullException("client"); } + if (request == null) { throw new ArgumentNullException("request"); } - private string GetBearerAuthorizationHeader() - { - return String.Format("Bearer {0}", new NetworkCredential(String.Empty, Token).Password); - } + _cancelToken = new CancellationTokenSource(); + HttpResponseMessage response = client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, _cancelToken.Token).GetAwaiter().GetResult(); - private void ProcessAuthentication() - { - if(Authentication == WebAuthenticationType.Basic) - { - WebSession.Headers["Authorization"] = GetBasicAuthorizationHeader(); - } - else if (Authentication == WebAuthenticationType.Bearer || Authentication == WebAuthenticationType.OAuth) - { - WebSession.Headers["Authorization"] = GetBearerAuthorizationHeader(); - } - else + if (stripAuthorization && IsRedirectCode(response.StatusCode)) { - Diagnostics.Assert(false, String.Format("Unrecognized Authentication value: {0}", Authentication)); - } - } + _cancelToken.Cancel(); + _cancelToken = null; - #endregion Helper Methods - } + // if explicit count was provided, reduce it for this redirection. + if (WebSession.MaximumRedirection > 0) + { + WebSession.MaximumRedirection--; + } + // For selected redirects that used POST, GET must be used with the + // redirected Location. + // Since GET is the default; POST only occurs when -Method POST is used. + if (Method == WebRequestMethod.Post && IsRedirectToGet(response.StatusCode)) + { + // See https://msdn.microsoft.com/en-us/library/system.net.httpstatuscode(v=vs.110).aspx + Method = WebRequestMethod.Get; + } - // TODO: Merge Partials + // recreate the HttpClient with redirection enabled since the first call suppressed redirection + using (client = GetHttpClient(false)) + using (HttpRequestMessage redirectRequest = GetRequest(response.Headers.Location, stripAuthorization:true)) + { + FillRequestStream(redirectRequest); + _cancelToken = new CancellationTokenSource(); + response = client.SendAsync(redirectRequest, HttpCompletionOption.ResponseHeadersRead, _cancelToken.Token).GetAwaiter().GetResult(); + } + } + return response; + } - /// - /// Exception class for webcmdlets to enable returning HTTP error response - /// - public sealed class HttpResponseException : HttpRequestException - { - /// - /// Constructor for HttpResponseException - /// - /// Message for the exception - /// Response from the HTTP server - public HttpResponseException (string message, HttpResponseMessage response) : base(message) + // Returns true if the status code is one of the supported redirection codes. + static bool IsRedirectCode(HttpStatusCode code) { - Response = response; + int intCode = (int) code; + return + ( + (intCode >= 300 && intCode < 304) + || + intCode == 307 + ); } - /// - /// HTTP error response - /// - public HttpResponseMessage Response { get; private set; } - } - - /// - /// Base class for Invoke-RestMethod and Invoke-WebRequest commands. - /// - public abstract partial class WebRequestPSCmdlet : PSCmdlet - { - - /// - /// gets or sets the PreserveAuthorizationOnRedirect property - /// - /// - /// This property overrides compatibility with web requests on Windows. - /// On FullCLR (WebRequest), authorization headers are stripped during redirect. - /// CoreCLR (HTTPClient) does not have this behavior so web requests that work on - /// PowerShell/FullCLR can fail with PowerShell/CoreCLR. To provide compatibility, - /// we'll detect requests with an Authorization header and automatically strip - /// the header when the first redirect occurs. This switch turns off this logic for - /// edge cases where the authorization header needs to be preserved across redirects. - /// - [Parameter] - public virtual SwitchParameter PreserveAuthorizationOnRedirect { get; set; } - - /// - /// gets or sets the SkipHeaderValidation property - /// - /// - /// This property adds headers to the request's header collection without validation. - /// - [Parameter] - public virtual SwitchParameter SkipHeaderValidation { get; set; } - - #region Abstract Methods - - /// - /// Read the supplied WebResponse object and push the - /// resulting output into the pipeline. - /// - /// Instance of a WebResponse object to be processed - internal abstract void ProcessResponse(HttpResponseMessage response); - - #endregion Abstract Methods - - /// - /// Cancellation token source - /// - private CancellationTokenSource _cancelToken = null; - - /// - /// Parse Rel Links - /// - internal bool _parseRelLink = false; - - /// - /// Automatically follow Rel Links - /// - internal bool _followRelLink = false; - - /// - /// Automatically follow Rel Links - /// - internal Dictionary _relationLink = null; - - /// - /// Maximum number of Rel Links to follow - /// - internal int _maximumFollowRelLink = Int32.MaxValue; - - private HttpMethod GetHttpMethod(WebRequestMethod method) + // Returns true if the status code is a redirection code and the action requires switching from POST to GET on redirection. + // NOTE: Some of these status codes map to the same underlying value but spelling them out for completeness. + static bool IsRedirectToGet(HttpStatusCode code) { - switch (Method) - { - case WebRequestMethod.Default: - case WebRequestMethod.Get: - return HttpMethod.Get; - case WebRequestMethod.Head: - return HttpMethod.Head; - case WebRequestMethod.Post: - return HttpMethod.Post; - case WebRequestMethod.Put: - return HttpMethod.Put; - case WebRequestMethod.Delete: - return HttpMethod.Delete; - case WebRequestMethod.Trace: - return HttpMethod.Trace; - case WebRequestMethod.Options: - return HttpMethod.Options; - default: - // Merge and Patch - return new HttpMethod(Method.ToString().ToUpperInvariant()); - } + return + ( + code == HttpStatusCode.Found + || + code == HttpStatusCode.Moved + || + code == HttpStatusCode.Redirect + || + code == HttpStatusCode.RedirectMethod + || + code == HttpStatusCode.TemporaryRedirect + || + code == HttpStatusCode.RedirectKeepVerb + || + code == HttpStatusCode.SeeOther + ); } - #region Virtual Methods - - // NOTE: Only pass true for handleRedirect if the original request has an authorization header - // and PreserveAuthorizationOnRedirect is NOT set. - internal virtual HttpClient GetHttpClient(bool handleRedirect) + internal virtual void PrepareSession() { - // By default the HttpClientHandler will automatically decompress GZip and Deflate content - HttpClientHandler handler = new HttpClientHandler(); - handler.CookieContainer = WebSession.Cookies; - - // set the credentials used by this request - if (WebSession.UseDefaultCredentials) + // make sure we have a valid WebRequestSession object to work with + if (null == WebSession) { - // the UseDefaultCredentials flag overrides other supplied credentials - handler.UseDefaultCredentials = true; + WebSession = new WebRequestSession(); } - else if (WebSession.Credentials != null) + + if (null != SessionVariable) { - handler.Credentials = WebSession.Credentials; + // save the session back to the PS environment if requested + PSVariableIntrinsics vi = SessionState.PSVariable; + vi.Set(SessionVariable, WebSession); } - if (NoProxy) + // + // handle credentials + // + if (null != Credential && Authentication == WebAuthenticationType.None) { - handler.UseProxy = false; + // get the relevant NetworkCredential + NetworkCredential netCred = Credential.GetNetworkCredential(); + WebSession.Credentials = netCred; + + // supplying a credential overrides the UseDefaultCredentials setting + WebSession.UseDefaultCredentials = false; } - else if (WebSession.Proxy != null) + else if ((null != Credential || null!= Token) && Authentication != WebAuthenticationType.None) { - handler.Proxy = WebSession.Proxy; + ProcessAuthentication(); } - - if (null != WebSession.Certificates) + else if (UseDefaultCredentials) { - handler.ClientCertificates.AddRange(WebSession.Certificates); + WebSession.UseDefaultCredentials = true; } - if (SkipCertificateCheck) - { - handler.ServerCertificateCustomValidationCallback = HttpClientHandler.DangerousAcceptAnyServerCertificateValidator; - handler.ClientCertificateOptions = ClientCertificateOption.Manual; - } - // This indicates GetResponse will handle redirects. - if (handleRedirect) - { - handler.AllowAutoRedirect = false; - } - else if (WebSession.MaximumRedirection > -1) + if (null != CertificateThumbprint) { - if (WebSession.MaximumRedirection == 0) + X509Store store = new X509Store(StoreName.My, StoreLocation.CurrentUser); + store.Open(OpenFlags.ReadOnly | OpenFlags.OpenExistingOnly); + X509Certificate2Collection collection = (X509Certificate2Collection)store.Certificates; + X509Certificate2Collection tbCollection = (X509Certificate2Collection)collection.Find(X509FindType.FindByThumbprint, CertificateThumbprint, false); + if (tbCollection.Count == 0) { - handler.AllowAutoRedirect = false; + CryptographicException ex = new CryptographicException(WebCmdletStrings.ThumbprintNotFound); + throw ex; } - else + foreach (X509Certificate2 tbCert in tbCollection) { - handler.MaxAutomaticRedirections = WebSession.MaximumRedirection; + X509Certificate certificate = (X509Certificate)tbCert; + WebSession.AddCertificate(certificate); } } - handler.SslProtocols = (SslProtocols)SslProtocol; - - - HttpClient httpClient = new HttpClient(handler); - - // check timeout setting (in seconds instead of milliseconds as in HttpWebRequest) - if (TimeoutSec == 0) - { - // A zero timeout means infinite - httpClient.Timeout = TimeSpan.FromMilliseconds(Timeout.Infinite); - } - else if (TimeoutSec > 0) - { - httpClient.Timeout = new TimeSpan(0, 0, TimeoutSec); - } - - return httpClient; - } - - internal virtual HttpRequestMessage GetRequest(Uri uri, bool stripAuthorization) - { - Uri requestUri = PrepareUri(uri); - HttpMethod httpMethod = null; - - switch (ParameterSetName) - { - case "StandardMethodNoProxy": - goto case "StandardMethod"; - case "StandardMethod": - // set the method if the parameter was provided - httpMethod = GetHttpMethod(Method); - break; - case "CustomMethodNoProxy": - goto case "CustomMethod"; - case "CustomMethod": - if (!string.IsNullOrEmpty(CustomMethod)) - { - // set the method if the parameter was provided - httpMethod = new HttpMethod(CustomMethod.ToString().ToUpperInvariant()); - } - break; - } - - // create the base WebRequest object - var request = new HttpRequestMessage(httpMethod, requestUri); - - // pull in session data - if (WebSession.Headers.Count > 0) + if (null != Certificate) { - WebSession.ContentHeaders.Clear(); - foreach (var entry in WebSession.Headers) - { - if (HttpKnownHeaderNames.ContentHeaders.Contains(entry.Key)) - { - WebSession.ContentHeaders.Add(entry.Key, entry.Value); - } - else - { - if (stripAuthorization - && - String.Equals(entry.Key, HttpKnownHeaderNames.Authorization.ToString(), StringComparison.OrdinalIgnoreCase) - ) - { - continue; - } - - if (SkipHeaderValidation) - { - request.Headers.TryAddWithoutValidation(entry.Key, entry.Value); - } - else - { - request.Headers.Add(entry.Key, entry.Value); - } - } - } + WebSession.AddCertificate(Certificate); } - // Set 'Transfer-Encoding: chunked' if 'Transfer-Encoding' is specified - if (WebSession.Headers.ContainsKey(HttpKnownHeaderNames.TransferEncoding)) + // + // handle the user agent + // + if (null != UserAgent) { - request.Headers.TransferEncodingChunked = true; + // store the UserAgent string + WebSession.UserAgent = UserAgent; } - // Set 'User-Agent' if WebSession.Headers doesn't already contain it - string userAgent = null; - if (WebSession.Headers.TryGetValue(HttpKnownHeaderNames.UserAgent, out userAgent)) - { - WebSession.UserAgent = userAgent; - } - else + if (null != Proxy) { - if (SkipHeaderValidation) + WebProxy webProxy = new WebProxy(Proxy); + webProxy.BypassProxyOnLocal = false; + if (null != ProxyCredential) { - request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.UserAgent, WebSession.UserAgent); + webProxy.Credentials = ProxyCredential.GetNetworkCredential(); } - else + else if (ProxyUseDefaultCredentials) { - request.Headers.Add(HttpKnownHeaderNames.UserAgent, WebSession.UserAgent); + // If both ProxyCredential and ProxyUseDefaultCredentials are passed, + // UseDefaultCredentials will overwrite the supplied credentials. + webProxy.UseDefaultCredentials = true; } - + WebSession.Proxy = webProxy; } - // Set 'Keep-Alive' to false. This means set the Connection to 'Close'. - if (DisableKeepAlive) + if (-1 < MaximumRedirection) { - request.Headers.Add(HttpKnownHeaderNames.Connection, "Close"); + WebSession.MaximumRedirection = MaximumRedirection; } - // Set 'Transfer-Encoding' - if (TransferEncoding != null) + // store the other supplied headers + if (null != Headers) { - request.Headers.TransferEncodingChunked = true; - var headerValue = new TransferCodingHeaderValue(TransferEncoding); - if (!request.Headers.TransferEncoding.Contains(headerValue)) + foreach (string key in Headers.Keys) { - request.Headers.TransferEncoding.Add(headerValue); + // add the header value (or overwrite it if already present) + WebSession.Headers[key] = Headers[key].ToString(); } } + } - // Some web sites (e.g. Twitter) will return exception on POST when Expect100 is sent - // Default behavior is continue to send body content anyway after a short period - // Here it send the two part as a whole. - request.Headers.ExpectContinue = false; - - return (request); + internal virtual void UpdateSession(HttpResponseMessage response) + { + if (response == null) { throw new ArgumentNullException("response"); } } - internal virtual void FillRequestStream(HttpRequestMessage request) + internal virtual void ValidateParameters() { - if (null == request) { throw new ArgumentNullException("request"); } + // sessions + if ((null != WebSession) && (null != SessionVariable)) + { + ErrorRecord error = GetValidationError(WebCmdletStrings.SessionConflict, + "WebCmdletSessionConflictException"); + ThrowTerminatingError(error); + } - // set the content type - if (ContentType != null) + // Authentication + if (UseDefaultCredentials && (Authentication != WebAuthenticationType.None)) { - WebSession.ContentHeaders[HttpKnownHeaderNames.ContentType] = ContentType; - //request + ErrorRecord error = GetValidationError(WebCmdletStrings.AuthenticationConflict, + "WebCmdletAuthenticationConflictException"); + ThrowTerminatingError(error); } - // ContentType == null - else if (Method == WebRequestMethod.Post || (IsCustomMethodSet() && CustomMethod.ToUpperInvariant() == "POST")) + if ((Authentication != WebAuthenticationType.None) && (null != Token) && (null != Credential)) { - // Win8:545310 Invoke-WebRequest does not properly set MIME type for POST - string contentType = null; - WebSession.ContentHeaders.TryGetValue(HttpKnownHeaderNames.ContentType, out contentType); - if (string.IsNullOrEmpty(contentType)) - { - WebSession.ContentHeaders[HttpKnownHeaderNames.ContentType] = "application/x-www-form-urlencoded"; - } + ErrorRecord error = GetValidationError(WebCmdletStrings.AuthenticationTokenConflict, + "WebCmdletAuthenticationTokenConflictException"); + ThrowTerminatingError(error); } - - // coerce body into a usable form - if (Body != null) + if ((Authentication == WebAuthenticationType.Basic) && (null == Credential)) { - object content = Body; - - // make sure we're using the base object of the body, not the PSObject wrapper - PSObject psBody = Body as PSObject; - if (psBody != null) - { - content = psBody.BaseObject; - } - - if (content is FormObject) - { - FormObject form = content as FormObject; - SetRequestContent(request, form.Fields); - } - else if (content is IDictionary && request.Method != HttpMethod.Get) - { - IDictionary dictionary = content as IDictionary; - SetRequestContent(request, dictionary); - } - else if (content is XmlNode) - { - XmlNode xmlNode = content as XmlNode; - SetRequestContent(request, xmlNode); - } - else if (content is Stream) - { - Stream stream = content as Stream; - SetRequestContent(request, stream); - } - else if (content is byte[]) - { - byte[] bytes = content as byte[]; - SetRequestContent(request, bytes); - } - else if (content is MultipartFormDataContent multipartFormDataContent) - { - WebSession.ContentHeaders.Clear(); - SetRequestContent(request, multipartFormDataContent); - } - else - { - SetRequestContent(request, - (string)LanguagePrimitives.ConvertTo(content, typeof(string), CultureInfo.InvariantCulture)); - } + ErrorRecord error = GetValidationError(WebCmdletStrings.AuthenticationCredentialNotSupplied, + "WebCmdletAuthenticationCredentialNotSuppliedException"); + ThrowTerminatingError(error); } - else if (InFile != null) // copy InFile data + if ((Authentication == WebAuthenticationType.OAuth || Authentication == WebAuthenticationType.Bearer) && (null == Token)) { - try - { - // open the input file - SetRequestContent(request, new FileStream(InFile, FileMode.Open)); - } - catch (UnauthorizedAccessException) - { - string msg = string.Format(CultureInfo.InvariantCulture, WebCmdletStrings.AccessDenied, - _originalFilePath); - throw new UnauthorizedAccessException(msg); - } + ErrorRecord error = GetValidationError(WebCmdletStrings.AuthenticationTokenNotSupplied, + "WebCmdletAuthenticationTokenNotSuppliedException"); + ThrowTerminatingError(error); } - - // Add the content headers - if (request.Content != null) + if (!AllowUnencryptedAuthentication && (Authentication != WebAuthenticationType.None) && (Uri.Scheme != "https")) { - foreach (var entry in WebSession.ContentHeaders) - { - request.Content.Headers.Add(entry.Key, entry.Value); - } + ErrorRecord error = GetValidationError(WebCmdletStrings.AllowUnencryptedAuthenticationRequired, + "WebCmdletAllowUnencryptedAuthenticationRequiredException"); + ThrowTerminatingError(error); } - } - - // Returns true if the status code is one of the supported redirection codes. - static bool IsRedirectCode(HttpStatusCode code) - { - int intCode = (int) code; - return - ( - (intCode >= 300 && intCode < 304) - || - intCode == 307 - ); - } - - // Returns true if the status code is a redirection code and the action requires switching from POST to GET on redirection. - // NOTE: Some of these status codes map to the same underlying value but spelling them out for completeness. - static bool IsRedirectToGet(HttpStatusCode code) - { - return - ( - code == HttpStatusCode.Found - || - code == HttpStatusCode.Moved - || - code == HttpStatusCode.Redirect - || - code == HttpStatusCode.RedirectMethod - || - code == HttpStatusCode.TemporaryRedirect - || - code == HttpStatusCode.RedirectKeepVerb - || - code == HttpStatusCode.SeeOther - ); - } - - internal virtual HttpResponseMessage GetResponse(HttpClient client, HttpRequestMessage request, bool stripAuthorization) - { - if (client == null) { throw new ArgumentNullException("client"); } - if (request == null) { throw new ArgumentNullException("request"); } - - _cancelToken = new CancellationTokenSource(); - HttpResponseMessage response = client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, _cancelToken.Token).GetAwaiter().GetResult(); - - if (stripAuthorization && IsRedirectCode(response.StatusCode)) + if (!AllowUnencryptedAuthentication && (null != Credential || UseDefaultCredentials) && (Uri.Scheme != "https")) { - _cancelToken.Cancel(); - _cancelToken = null; - - // if explicit count was provided, reduce it for this redirection. - if (WebSession.MaximumRedirection > 0) - { - WebSession.MaximumRedirection--; - } - // For selected redirects that used POST, GET must be used with the - // redirected Location. - // Since GET is the default; POST only occurs when -Method POST is used. - if (Method == WebRequestMethod.Post && IsRedirectToGet(response.StatusCode)) - { - // See https://msdn.microsoft.com/en-us/library/system.net.httpstatuscode(v=vs.110).aspx - Method = WebRequestMethod.Get; - } - - // recreate the HttpClient with redirection enabled since the first call suppressed redirection - using (client = GetHttpClient(false)) - using (HttpRequestMessage redirectRequest = GetRequest(response.Headers.Location, stripAuthorization:true)) - { - FillRequestStream(redirectRequest); - _cancelToken = new CancellationTokenSource(); - response = client.SendAsync(redirectRequest, HttpCompletionOption.ResponseHeadersRead, _cancelToken.Token).GetAwaiter().GetResult(); - } + ErrorRecord error = GetValidationError(WebCmdletStrings.AllowUnencryptedAuthenticationRequired, + "WebCmdletAllowUnencryptedAuthenticationRequiredException"); + ThrowTerminatingError(error); } - return response; - } - internal virtual void UpdateSession(HttpResponseMessage response) - { - if (response == null) { throw new ArgumentNullException("response"); } - } - - #endregion Virtual Methods + // credentials + if (UseDefaultCredentials && (null != Credential)) + { + ErrorRecord error = GetValidationError(WebCmdletStrings.CredentialConflict, + "WebCmdletCredentialConflictException"); + ThrowTerminatingError(error); + } - #region Overrides + // Proxy server + if (ProxyUseDefaultCredentials && (null != ProxyCredential)) + { + ErrorRecord error = GetValidationError(WebCmdletStrings.ProxyCredentialConflict, + "WebCmdletProxyCredentialConflictException"); + ThrowTerminatingError(error); + } + else if ((null == Proxy) && ((null != ProxyCredential) || ProxyUseDefaultCredentials)) + { + ErrorRecord error = GetValidationError(WebCmdletStrings.ProxyUriNotSupplied, + "WebCmdletProxyUriNotSuppliedException"); + ThrowTerminatingError(error); + } - /// - /// the main execution method for cmdlets derived from WebRequestPSCmdlet. - /// - protected override void ProcessRecord() - { - try + // request body content + if ((null != Body) && (null != InFile)) { - // Set cmdlet context for write progress - ValidateParameters(); - PrepareSession(); + ErrorRecord error = GetValidationError(WebCmdletStrings.BodyConflict, + "WebCmdletBodyConflictException"); + ThrowTerminatingError(error); + } - // if the request contains an authorization header and PreserveAuthorizationOnRedirect is not set, - // it needs to be stripped on the first redirect. - bool stripAuthorization = null != WebSession - && - null != WebSession.Headers - && - !PreserveAuthorizationOnRedirect.IsPresent - && - WebSession.Headers.ContainsKey(HttpKnownHeaderNames.Authorization.ToString()); + // validate InFile path + if (InFile != null) + { + ProviderInfo provider = null; + ErrorRecord errorRecord = null; - using (HttpClient client = GetHttpClient(stripAuthorization)) + try { - int followedRelLink = 0; - Uri uri = Uri; - do + Collection providerPaths = GetResolvedProviderPathFromPSPath(InFile, out provider); + + if (!provider.Name.Equals(FileSystemProvider.ProviderName, StringComparison.OrdinalIgnoreCase)) { - if (followedRelLink > 0) + errorRecord = GetValidationError(WebCmdletStrings.NotFilesystemPath, + "WebCmdletInFileNotFilesystemPathException", InFile); + } + else + { + if (providerPaths.Count > 1) { - string linkVerboseMsg = string.Format(CultureInfo.CurrentCulture, - WebCmdletStrings.FollowingRelLinkVerboseMsg, - uri.AbsoluteUri); - WriteVerbose(linkVerboseMsg); + errorRecord = GetValidationError(WebCmdletStrings.MultiplePathsResolved, + "WebCmdletInFileMultiplePathsResolvedException", InFile); } - - using (HttpRequestMessage request = GetRequest(uri, stripAuthorization:false)) + else if (providerPaths.Count == 0) { - FillRequestStream(request); - try - { - long requestContentLength = 0; - if (request.Content != null) - requestContentLength = request.Content.Headers.ContentLength.Value; - - string reqVerboseMsg = String.Format(CultureInfo.CurrentCulture, - WebCmdletStrings.WebMethodInvocationVerboseMsg, - request.Method, - request.RequestUri, - requestContentLength); - WriteVerbose(reqVerboseMsg); - - HttpResponseMessage response = GetResponse(client, request, stripAuthorization); - - string contentType = ContentHelper.GetContentType(response); - string respVerboseMsg = string.Format(CultureInfo.CurrentCulture, - WebCmdletStrings.WebResponseVerboseMsg, - response.Content.Headers.ContentLength, - contentType); - WriteVerbose(respVerboseMsg); - - if (!response.IsSuccessStatusCode) - { - string message = String.Format(CultureInfo.CurrentCulture, WebCmdletStrings.ResponseStatusCodeFailure, - (int)response.StatusCode, response.ReasonPhrase); - HttpResponseException httpEx = new HttpResponseException(message, response); - ErrorRecord er = new ErrorRecord(httpEx, "WebCmdletWebResponseException", ErrorCategory.InvalidOperation, request); - string detailMsg = ""; - StreamReader reader = null; - try - { - reader = new StreamReader(StreamHelper.GetResponseStream(response)); - // remove HTML tags making it easier to read - detailMsg = System.Text.RegularExpressions.Regex.Replace(reader.ReadToEnd(), "<[^>]*>",""); - } - catch (Exception) - { - // catch all - } - finally - { - if (reader != null) - { - reader.Dispose(); - } - } - if (!String.IsNullOrEmpty(detailMsg)) - { - er.ErrorDetails = new ErrorDetails(detailMsg); - } - ThrowTerminatingError(er); - } - - if (_parseRelLink || _followRelLink) - { - ParseLinkHeader(response, uri); - } - ProcessResponse(response); - UpdateSession(response); - - // If we hit our maximum redirection count, generate an error. - // Errors with redirection counts of greater than 0 are handled automatically by .NET, but are - // impossible to detect programmatically when we hit this limit. By handling this ourselves - // (and still writing out the result), users can debug actual HTTP redirect problems. - if (WebSession.MaximumRedirection == 0) // Indicate "HttpClientHandler.AllowAutoRedirect == false" - { - if (response.StatusCode == HttpStatusCode.Found || - response.StatusCode == HttpStatusCode.Moved || - response.StatusCode == HttpStatusCode.MovedPermanently) - { - ErrorRecord er = new ErrorRecord(new InvalidOperationException(), "MaximumRedirectExceeded", ErrorCategory.InvalidOperation, request); - er.ErrorDetails = new ErrorDetails(WebCmdletStrings.MaximumRedirectionCountExceeded); - WriteError(er); - } - } - } - catch (HttpRequestException ex) - { - ErrorRecord er = new ErrorRecord(ex, "WebCmdletWebResponseException", ErrorCategory.InvalidOperation, request); - if (ex.InnerException != null) - { - er.ErrorDetails = new ErrorDetails(ex.InnerException.Message); - } - ThrowTerminatingError(er); - } - - if (_followRelLink) + errorRecord = GetValidationError(WebCmdletStrings.NoPathResolved, + "WebCmdletInFileNoPathResolvedException", InFile); + } + else + { + if (Directory.Exists(providerPaths[0])) { - if (!_relationLink.ContainsKey("next")) - { - return; - } - uri = new Uri(_relationLink["next"]); - followedRelLink++; + errorRecord = GetValidationError(WebCmdletStrings.DirectoryPathSpecified, + "WebCmdletInFileNotFilePathException", InFile); } + _originalFilePath = InFile; + InFile = providerPaths[0]; } } - while (_followRelLink && (followedRelLink < _maximumFollowRelLink)); + } + catch (ItemNotFoundException pathNotFound) + { + errorRecord = new ErrorRecord(pathNotFound.ErrorRecord, pathNotFound); + } + catch (ProviderNotFoundException providerNotFound) + { + errorRecord = new ErrorRecord(providerNotFound.ErrorRecord, providerNotFound); + } + catch (System.Management.Automation.DriveNotFoundException driveNotFound) + { + errorRecord = new ErrorRecord(driveNotFound.ErrorRecord, driveNotFound); + } + + if (errorRecord != null) + { + ThrowTerminatingError(errorRecord); } } - catch (CryptographicException ex) - { - ErrorRecord er = new ErrorRecord(ex, "WebCmdletCertificateException", ErrorCategory.SecurityError, null); - ThrowTerminatingError(er); - } - catch (NotSupportedException ex) + + // output ?? + if (PassThru && (OutFile == null)) { - ErrorRecord er = new ErrorRecord(ex, "WebCmdletIEDomNotSupportedException", ErrorCategory.NotImplemented, null); - ThrowTerminatingError(er); + ErrorRecord error = GetValidationError(WebCmdletStrings.OutFileMissing, + "WebCmdletOutFileMissingException"); + ThrowTerminatingError(error); } } - /// - /// Implementing ^C, after start the BeginGetResponse - /// - protected override void StopProcessing() + #endregion Virtual Methods + + #region Helper Internal Methods + + internal void ParseLinkHeader(HttpResponseMessage response, System.Uri requestUri) { - if (_cancelToken != null) + if (_relationLink == null) { - _cancelToken.Cancel(); + _relationLink = new Dictionary(); + } + else + { + _relationLink.Clear(); } - } - - #endregion Overrides - #region Helper Methods + // we only support the URL in angle brackets and `rel`, other attributes are ignored + // user can still parse it themselves via the Headers property + string pattern = "<(?.*?)>;\\srel=\"(?.*?)\""; + IEnumerable links; + if (response.Headers.TryGetValues("Link", out links)) + { + foreach (string linkHeader in links) + { + foreach (string link in linkHeader.Split(",")) + { + Match match = Regex.Match(link, pattern); + if (match.Success) + { + string url = match.Groups["url"].Value; + string rel = match.Groups["rel"].Value; + if (url != String.Empty && rel != String.Empty && !_relationLink.ContainsKey(rel)) + { + Uri absoluteUri = new Uri(requestUri, url); + _relationLink.Add(rel, absoluteUri.AbsoluteUri.ToString()); + } + } + } + } + } + } /// /// Sets the ContentLength property of the request and writes the specified content to the request's RequestStream. @@ -1553,43 +1426,162 @@ internal long SetRequestContent(HttpRequestMessage request, IDictionary content) } - internal void ParseLinkHeader(HttpResponseMessage response, System.Uri requestUri) + #endregion Helper Internal Methods + + #region Helper Private Methods + + private Uri CheckProtocol(Uri uri) { - if (_relationLink == null) + if (null == uri) { throw new ArgumentNullException("uri"); } + + if (!uri.IsAbsoluteUri) { - _relationLink = new Dictionary(); + uri = new Uri("http://" + uri.OriginalString); } - else + return (uri); + } + + private string FormatDictionary(IDictionary content) + { + if (content == null) + throw new ArgumentNullException("content"); + + StringBuilder bodyBuilder = new StringBuilder(); + foreach (string key in content.Keys) { - _relationLink.Clear(); + if (0 < bodyBuilder.Length) + { + bodyBuilder.Append("&"); + } + + object value = content[key]; + + // URLEncode the key and value + string encodedKey = WebUtility.UrlEncode(key); + string encodedValue = String.Empty; + if (null != value) + { + encodedValue = WebUtility.UrlEncode(value.ToString()); + } + + bodyBuilder.AppendFormat("{0}={1}", encodedKey, encodedValue); + } + return bodyBuilder.ToString(); + } + + private string GetBasicAuthorizationHeader() + { + string unencoded = String.Format("{0}:{1}", Credential.UserName, Credential.GetNetworkCredential().Password); + Byte[] bytes = Encoding.UTF8.GetBytes(unencoded); + return String.Format("Basic {0}", Convert.ToBase64String(bytes)); + } + + private string GetBearerAuthorizationHeader() + { + return String.Format("Bearer {0}", new NetworkCredential(String.Empty, Token).Password); + } + + private HttpMethod GetHttpMethod(WebRequestMethod method) + { + switch (Method) + { + case WebRequestMethod.Default: + case WebRequestMethod.Get: + return HttpMethod.Get; + case WebRequestMethod.Head: + return HttpMethod.Head; + case WebRequestMethod.Post: + return HttpMethod.Post; + case WebRequestMethod.Put: + return HttpMethod.Put; + case WebRequestMethod.Delete: + return HttpMethod.Delete; + case WebRequestMethod.Trace: + return HttpMethod.Trace; + case WebRequestMethod.Options: + return HttpMethod.Options; + default: + // Merge and Patch + return new HttpMethod(Method.ToString().ToUpperInvariant()); } + } - // we only support the URL in angle brackets and `rel`, other attributes are ignored - // user can still parse it themselves via the Headers property - string pattern = "<(?.*?)>;\\srel=\"(?.*?)\""; - IEnumerable links; - if (response.Headers.TryGetValues("Link", out links)) + private ErrorRecord GetValidationError(string msg, string errorId) + { + var ex = new ValidationMetadataException(msg); + var error = new ErrorRecord(ex, errorId, ErrorCategory.InvalidArgument, this); + return (error); + } + + private ErrorRecord GetValidationError(string msg, string errorId, params object[] args) + { + msg = string.Format(CultureInfo.InvariantCulture, msg, args); + var ex = new ValidationMetadataException(msg); + var error = new ErrorRecord(ex, errorId, ErrorCategory.InvalidArgument, this); + return (error); + } + + private bool IsCustomMethodSet() + { + return (ParameterSetName == "CustomMethod"); + } + + private bool IsStandardMethodSet() + { + return (ParameterSetName == "StandardMethod"); + } + + private Uri PrepareUri(Uri uri) + { + uri = CheckProtocol(uri); + + // before creating the web request, + // preprocess Body if content is a dictionary and method is GET (set as query) + IDictionary bodyAsDictionary; + LanguagePrimitives.TryConvertTo(Body, out bodyAsDictionary); + if ((null != bodyAsDictionary) + && ((IsStandardMethodSet() && (Method == WebRequestMethod.Default || Method == WebRequestMethod.Get)) + || (IsCustomMethodSet() && CustomMethod.ToUpperInvariant() == "GET"))) { - foreach (string linkHeader in links) + UriBuilder uriBuilder = new UriBuilder(uri); + if (uriBuilder.Query != null && uriBuilder.Query.Length > 1) { - foreach (string link in linkHeader.Split(",")) - { - Match match = Regex.Match(link, pattern); - if (match.Success) - { - string url = match.Groups["url"].Value; - string rel = match.Groups["rel"].Value; - if (url != String.Empty && rel != String.Empty && !_relationLink.ContainsKey(rel)) - { - Uri absoluteUri = new Uri(requestUri, url); - _relationLink.Add(rel, absoluteUri.AbsoluteUri.ToString()); - } - } - } + uriBuilder.Query = uriBuilder.Query.Substring(1) + "&" + FormatDictionary(bodyAsDictionary); + } + else + { + uriBuilder.Query = FormatDictionary(bodyAsDictionary); } + uri = uriBuilder.Uri; + // set body to null to prevent later FillRequestStream + Body = null; + } + + return uri; + } + + private void ProcessAuthentication() + { + if(Authentication == WebAuthenticationType.Basic) + { + WebSession.Headers["Authorization"] = GetBasicAuthorizationHeader(); + } + else if (Authentication == WebAuthenticationType.Bearer || Authentication == WebAuthenticationType.OAuth) + { + WebSession.Headers["Authorization"] = GetBearerAuthorizationHeader(); + } + else + { + Diagnostics.Assert(false, String.Format("Unrecognized Authentication value: {0}", Authentication)); } } - #endregion Helper Methods + private string QualifyFilePath(string path) + { + string resolvedFilePath = PathUtils.ResolveFilePath(path, this, false); + return resolvedFilePath; + } + + #endregion Helper Private Methods } } diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebResponseObject.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebResponseObject.Common.cs index ee29d80e6f2..ca385d68ce8 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebResponseObject.Common.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebResponseObject.Common.cs @@ -3,11 +3,11 @@ --********************************************************************/ using System; +using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.IO; -using System.Text; using System.Net.Http; -using System.Collections.Generic; +using System.Text; namespace Microsoft.PowerShell.Commands { @@ -16,6 +16,30 @@ namespace Microsoft.PowerShell.Commands /// public partial class WebResponseObject { + #region Constructors + + /// + /// Constructor for WebResponseObject + /// + /// + public WebResponseObject(HttpResponseMessage response) + : this(response, null) + { } + + /// + /// Constructor for WebResponseObject with contentStream + /// + /// + /// + public WebResponseObject(HttpResponseMessage response, Stream contentStream) + { + SetResponse(response, contentStream); + InitializeContent(); + InitializeRawContent(response); + } + + #endregion Constructors + #region Properties /// @@ -62,53 +86,6 @@ public long RawContentLength /// public string RawContent { get; protected set; } - #endregion Properties - - #region Methods - - /// - /// Reads the response content from the web response. - /// - private void InitializeContent() - { - this.Content = this.RawContentStream.ToArray(); - } - - private bool IsPrintable(char c) - { - return (Char.IsLetterOrDigit(c) || Char.IsPunctuation(c) || Char.IsSeparator(c) || Char.IsSymbol(c) || Char.IsWhiteSpace(c)); - } - - /// - /// Returns the string representation of this web response. - /// - /// The string representation of this web response. - public sealed override string ToString() - { - char[] stringContent = System.Text.Encoding.ASCII.GetChars(Content); - for (int counter = 0; counter < stringContent.Length; counter++) - { - if (!IsPrintable(stringContent[counter])) - { - stringContent[counter] = '.'; - } - } - - return new string(stringContent); - } - - #endregion Methods - } - - // TODO: Merge Partials - - /// - /// WebResponseObject - /// - public partial class WebResponseObject - { - #region Properties - /// /// gets or sets the BaseResponse property /// @@ -137,34 +114,36 @@ public Dictionary> Headers /// public Dictionary RelationLink { get; internal set; } - #endregion + #endregion Properties - #region Constructors + #region Methods /// - /// Constructor for WebResponseObject + /// Returns the string representation of this web response. /// - /// - public WebResponseObject(HttpResponseMessage response) - : this(response, null) - { } + /// The string representation of this web response. + public sealed override string ToString() + { + char[] stringContent = System.Text.Encoding.ASCII.GetChars(Content); + for (int counter = 0; counter < stringContent.Length; counter++) + { + if (!IsPrintable(stringContent[counter])) + { + stringContent[counter] = '.'; + } + } + + return new string(stringContent); + } /// - /// Constructor for WebResponseObject with contentStream + /// Reads the response content from the web response. /// - /// - /// - public WebResponseObject(HttpResponseMessage response, Stream contentStream) + private void InitializeContent() { - SetResponse(response, contentStream); - InitializeContent(); - InitializeRawContent(response); + this.Content = this.RawContentStream.ToArray(); } - #endregion Constructors - - #region Methods - private void InitializeRawContent(HttpResponseMessage baseResponse) { StringBuilder raw = ContentHelper.GetRawContentHeader(baseResponse); @@ -178,6 +157,11 @@ private void InitializeRawContent(HttpResponseMessage baseResponse) this.RawContent = raw.ToString(); } + private bool IsPrintable(char c) + { + return (Char.IsLetterOrDigit(c) || Char.IsPunctuation(c) || Char.IsSeparator(c) || Char.IsSymbol(c) || Char.IsWhiteSpace(c)); + } + private void SetResponse(HttpResponseMessage response, Stream contentStream) { if (null == response) { throw new ArgumentNullException("response"); } @@ -209,6 +193,6 @@ private void SetResponse(HttpResponseMessage response, Stream contentStream) _rawContentStream.Position = 0; } - #endregion + #endregion Methods } }